contract.cpp 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. // Contract across tensor indices
  2. #include "stdafx.h"
  3. #include "defs.h"
  4. void
  5. eval_contract(void)
  6. {
  7. push(cadr(p1));
  8. eval();
  9. if (cddr(p1) == symbol(NIL)) {
  10. push_integer(1);
  11. push_integer(2);
  12. } else {
  13. push(caddr(p1));
  14. eval();
  15. push(cadddr(p1));
  16. eval();
  17. }
  18. contract();
  19. }
  20. void
  21. contract(void)
  22. {
  23. save();
  24. yycontract();
  25. restore();
  26. }
  27. void
  28. yycontract(void)
  29. {
  30. int h, i, j, k, l, m, n, ndim, nelem;
  31. int ai[MAXDIM], an[MAXDIM];
  32. U **a, **b;
  33. p3 = pop();
  34. p2 = pop();
  35. p1 = pop();
  36. if (!istensor(p1)) {
  37. if (!iszero(p1))
  38. stop("contract: tensor expected, 1st arg is not a tensor");
  39. push(zero);
  40. return;
  41. }
  42. push(p2);
  43. l = pop_integer();
  44. push(p3);
  45. m = pop_integer();
  46. ndim = p1->u.tensor->ndim;
  47. if (l < 1 || l > ndim || m < 1 || m > ndim || l == m
  48. || p1->u.tensor->dim[l - 1] != p1->u.tensor->dim[m - 1])
  49. stop("contract: index out of range");
  50. l--;
  51. m--;
  52. n = p1->u.tensor->dim[l];
  53. // nelem is the number of elements in "b"
  54. nelem = 1;
  55. for (i = 0; i < ndim; i++)
  56. if (i != l && i != m)
  57. nelem *= p1->u.tensor->dim[i];
  58. p2 = alloc_tensor(nelem);
  59. p2->u.tensor->ndim = ndim - 2;
  60. j = 0;
  61. for (i = 0; i < ndim; i++)
  62. if (i != l && i != m)
  63. p2->u.tensor->dim[j++] = p1->u.tensor->dim[i];
  64. a = p1->u.tensor->elem;
  65. b = p2->u.tensor->elem;
  66. for (i = 0; i < ndim; i++) {
  67. ai[i] = 0;
  68. an[i] = p1->u.tensor->dim[i];
  69. }
  70. for (i = 0; i < nelem; i++) {
  71. push(zero);
  72. for (j = 0; j < n; j++) {
  73. ai[l] = j;
  74. ai[m] = j;
  75. h = 0;
  76. for (k = 0; k < ndim; k++)
  77. h = (h * an[k]) + ai[k];
  78. push(a[h]);
  79. add();
  80. }
  81. b[i] = pop();
  82. for (j = ndim - 1; j >= 0; j--) {
  83. if (j == l || j == m)
  84. continue;
  85. if (++ai[j] < an[j])
  86. break;
  87. ai[j] = 0;
  88. }
  89. }
  90. if (nelem == 1)
  91. push(b[0]);
  92. else
  93. push(p2);
  94. }