outer.cpp 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. // Outer product of tensors
#include "stdafx.h"
#include <limits.h>
#include "defs.h"
  4. void
  5. eval_outer(void)
  6. {
  7. p1 = cdr(p1);
  8. push(car(p1));
  9. eval();
  10. p1 = cdr(p1);
  11. while (iscons(p1)) {
  12. push(car(p1));
  13. eval();
  14. outer();
  15. p1 = cdr(p1);
  16. }
  17. }
  18. void
  19. outer(void)
  20. {
  21. save();
  22. p2 = pop();
  23. p1 = pop();
  24. if (istensor(p1) && istensor(p2))
  25. yyouter();
  26. else {
  27. push(p1);
  28. push(p2);
  29. if (istensor(p1))
  30. tensor_times_scalar();
  31. else if (istensor(p2))
  32. scalar_times_tensor();
  33. else
  34. multiply();
  35. }
  36. restore();
  37. }
  38. void
  39. yyouter(void)
  40. {
  41. int i, j, k, ndim, nelem;
  42. ndim = p1->u.tensor->ndim + p2->u.tensor->ndim;
  43. if (ndim > MAXDIM)
  44. stop("outer: rank of result exceeds maximum");
  45. nelem = p1->u.tensor->nelem * p2->u.tensor->nelem;
  46. p3 = alloc_tensor(nelem);
  47. p3->u.tensor->ndim = ndim;
  48. for (i = 0; i < p1->u.tensor->ndim; i++)
  49. p3->u.tensor->dim[i] = p1->u.tensor->dim[i];
  50. j = i;
  51. for (i = 0; i < p2->u.tensor->ndim; i++)
  52. p3->u.tensor->dim[j + i] = p2->u.tensor->dim[i];
  53. k = 0;
  54. for (i = 0; i < p1->u.tensor->nelem; i++)
  55. for (j = 0; j < p2->u.tensor->nelem; j++) {
  56. push(p1->u.tensor->elem[i]);
  57. push(p2->u.tensor->elem[j]);
  58. multiply();
  59. p3->u.tensor->elem[k++] = pop();
  60. }
  61. push(p3);
  62. }
  63. #if SELFTEST
  64. static char *s[] = {
  65. "outer(a,b)",
  66. "a*b",
  67. "outer(a,(b1,b2))",
  68. "(a*b1,a*b2)",
  69. "outer((a1,a2),b)",
  70. "(a1*b,a2*b)",
  71. "H33=hilbert(3)",
  72. "",
  73. "H44=hilbert(4)",
  74. "",
  75. "H55=hilbert(5)",
  76. "",
  77. "H3344=outer(H33,H44)",
  78. "",
  79. "H4455=outer(H44,H55)",
  80. "",
  81. "H33444455=outer(H33,H44,H44,H55)",
  82. "",
  83. "simplify(inner(H3344,H4455)-contract(H33444455,4,5))",
  84. "0",
  85. };
  86. void
  87. test_outer(void)
  88. {
  89. test(__FILE__, s, sizeof s / sizeof (char *));
  90. }
  91. #endif