inner.cpp 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. // Do the inner product of tensors.
  2. #include "stdafx.h"
  3. #include "defs.h"
  4. static void inner_f(void);
  5. void
  6. eval_inner(void)
  7. {
  8. p1 = cdr(p1);
  9. push(car(p1));
  10. eval();
  11. p1 = cdr(p1);
  12. while (iscons(p1)) {
  13. push(car(p1));
  14. eval();
  15. inner();
  16. p1 = cdr(p1);
  17. }
  18. }
  19. void
  20. inner(void)
  21. {
  22. save();
  23. p2 = pop();
  24. p1 = pop();
  25. if (istensor(p1) && istensor(p2))
  26. inner_f();
  27. else {
  28. push(p1);
  29. push(p2);
  30. if (istensor(p1))
  31. tensor_times_scalar();
  32. else if (istensor(p2))
  33. scalar_times_tensor();
  34. else
  35. multiply();
  36. }
  37. restore();
  38. }
  39. // inner product of tensors p1 and p2
  40. static void
  41. inner_f(void)
  42. {
  43. int ak, bk, i, j, k, n, ndim;
  44. U **a, **b, **c;
  45. n = p1->u.tensor->dim[p1->u.tensor->ndim - 1];
  46. if (n != p2->u.tensor->dim[0])
  47. stop("inner: tensor dimension check");
  48. ndim = p1->u.tensor->ndim + p2->u.tensor->ndim - 2;
  49. if (ndim > MAXDIM)
  50. stop("inner: rank of result exceeds maximum");
  51. a = p1->u.tensor->elem;
  52. b = p2->u.tensor->elem;
  53. //---------------------------------------------------------------------
  54. //
  55. // ak is the number of rows in tensor A
  56. //
  57. // bk is the number of columns in tensor B
  58. //
  59. // Example:
  60. //
  61. // A[3][3][4] B[4][4][3]
  62. //
  63. // 3 3 ak = 3 * 3 = 9
  64. //
  65. // 4 3 bk = 4 * 3 = 12
  66. //
  67. //---------------------------------------------------------------------
  68. ak = 1;
  69. for (i = 0; i < p1->u.tensor->ndim - 1; i++)
  70. ak *= p1->u.tensor->dim[i];
  71. bk = 1;
  72. for (i = 1; i < p2->u.tensor->ndim; i++)
  73. bk *= p2->u.tensor->dim[i];
  74. p3 = alloc_tensor(ak * bk);
  75. c = p3->u.tensor->elem;
  76. // new method copied from ginac
  77. #if 1
  78. for (i = 0; i < ak; i++) {
  79. for (j = 0; j < n; j++) {
  80. if (iszero(a[i * n + j]))
  81. continue;
  82. for (k = 0; k < bk; k++) {
  83. push(a[i * n + j]);
  84. push(b[j * bk + k]);
  85. multiply();
  86. push(c[i * bk + k]);
  87. add();
  88. c[i * bk + k] = pop();
  89. }
  90. }
  91. }
  92. #else
  93. for (i = 0; i < ak; i++) {
  94. for (j = 0; j < bk; j++) {
  95. push(zero);
  96. for (k = 0; k < n; k++) {
  97. push(a[i * n + k]);
  98. push(b[k * bk + j]);
  99. multiply();
  100. add();
  101. }
  102. c[i * bk + j] = pop();
  103. }
  104. }
  105. #endif
  106. //---------------------------------------------------------------------
  107. //
  108. // Note on understanding "k * bk + j"
  109. //
  110. // k * bk because each element of a column is bk locations apart
  111. //
  112. // + j because the beginnings of all columns are in the first bk
  113. // locations
  114. //
  115. // Example: n = 2, bk = 6
  116. //
  117. // b111 <- 1st element of 1st column
  118. // b112 <- 1st element of 2nd column
  119. // b113 <- 1st element of 3rd column
  120. // b121 <- 1st element of 4th column
  121. // b122 <- 1st element of 5th column
  122. // b123 <- 1st element of 6th column
  123. //
  124. // b211 <- 2nd element of 1st column
  125. // b212 <- 2nd element of 2nd column
  126. // b213 <- 2nd element of 3rd column
  127. // b221 <- 2nd element of 4th column
  128. // b222 <- 2nd element of 5th column
  129. // b223 <- 2nd element of 6th column
  130. //
  131. //---------------------------------------------------------------------
  132. if (ndim == 0)
  133. push(p3->u.tensor->elem[0]);
  134. else {
  135. p3->u.tensor->ndim = ndim;
  136. for (i = 0; i < p1->u.tensor->ndim - 1; i++)
  137. p3->u.tensor->dim[i] = p1->u.tensor->dim[i];
  138. j = i;
  139. for (i = 0; i < p2->u.tensor->ndim - 1; i++)
  140. p3->u.tensor->dim[j + i] = p2->u.tensor->dim[i + 1];
  141. push(p3);
  142. }
  143. }
  144. #if SELFTEST
  145. static char *s[] = {
  146. "inner(a,b)",
  147. "a*b",
  148. "inner(a,(b1,b2))",
  149. "(a*b1,a*b2)",
  150. "inner((a1,a2),b)",
  151. "(a1*b,a2*b)",
  152. "inner(((a11,a12),(a21,a22)),(x1,x2))",
  153. "(a11*x1+a12*x2,a21*x1+a22*x2)",
  154. "inner((1,2),(3,4))",
  155. "11",
  156. "inner(inner((1,2),((3,4),(5,6))),(7,8))",
  157. "219",
  158. "inner((1,2),inner(((3,4),(5,6)),(7,8)))",
  159. "219",
  160. "inner((1,2),((3,4),(5,6)),(7,8))",
  161. "219",
  162. };
  163. void
  164. test_inner(void)
  165. {
  166. test(__FILE__, s, sizeof s / sizeof (char *));
  167. }
  168. #endif