; reduce.scm
  1. ; Array reductions.
  2. ; (c) Daniel Llorens - 2012-2013
  3. ; This library is free software; you can redistribute it and/or modify it under
  4. ; the terms of the GNU General Public License as published by the Free
  5. ; Software Foundation; either version 3 of the License, or (at your option) any
  6. ; later version.
  7. (define-module (ploy reduce))
  8. (import (ice-9 optargs) (srfi srfi-26) (srfi srfi-1) (srfi srfi-11)
  9. (srfi srfi-9) (srfi srfi-8) (ploy basic) (ploy assert) (ploy ploy))
  10. ; @todo look at J/repa/numpy/PDL/SAC?; tensordot/tensorsolve.
  11. ; @todo folda / foldb as verbs can be determined from the argument op. Probably
  12. ; want verbs to be applicable.
  13. ; @todo the ranks of folda / foldb should be 1 + rank of op; why does J use '_?
  14. ; special case of fold: M(a0) [R M(a1) ...], or a0 [R a1 ...]
  15. ; @todo Take verbs, like folda / foldb; see uses of over max v-norm ...
  16. ; Maybe leave the non-verb case for the time being, as an optimization.
  17. (define over/t
  18. (case-lambda
  19. ((type R a)
  20. (let ((end (tally a)))
  21. (if (zero? end)
  22. (R)
  23. (let loop ((i 1) (c (array-cell-ref a 0)))
  24. (if (< i end)
  25. (loop (+ 1 i) (R c (array-cell-ref a i)))
  26. c)))))
  27. ((type R M a)
  28. (let ((end (tally a)))
  29. (cond
  30. ((zero? end)
  31. (R))
  32. ((= 1 end)
  33. (R (M (array-cell-ref a 0))))
  34. (else
  35. (let loop ((i 1) (c (M (array-cell-ref a 0))))
  36. (if (< i end)
  37. (loop (+ 1 i) (R c (M (array-cell-ref a i))))
  38. c))))))))
  39. ; @TODO In this way it's easier to fold over >1 ranks, but it should never be slower than carrying indices.
  40. (define over/t*
  41. (case-lambda
  42. ((type R a)
  43. (if (zero? (tally a))
  44. (R)
  45. (let ((c (array-cell-ref a 0)))
  46. (array-slice-for-each 1 (lambda (a) (set! c (R c (array-cell-ref a)))) (from a (J (- (tally a) 1) 1)))
  47. c)))
  48. ((type R M a)
  49. (if (zero? (tally a))
  50. (R)
  51. (let ((c (M (array-cell-ref a 0))))
  52. (array-slice-for-each 1 (lambda (a) (set! c (R c (M (array-cell-ref a))))) (from a (J (- (tally a) 1) 1)))
  53. c)))))
  54. (export over/t*)
  55. (define over
  56. (case-lambda
  57. ((R a)
  58. (over/t (array-type* a) R a))
  59. ((R M a)
  60. (over/t (array-type* a) R M a))))
  61. (export over/t over)
  62. ; fold above ply.
  63. ; @todo (folda vector3+ #(0 0 0) #(#(0 0 1) #(0 1 0))) ???
  64. (define (folda/t type op z . a)
  65. (if (null? a)
  66. z
  67. (let ((op (if (verb? op) op (verb op)))
  68. (end (tally (car a)))
  69. ; raise the rank of z so that it can be matched with a. It's lowered later.
  70. (z (apply reshape z (cons 1 ($ z)))))
  71. ; match below the folding axis.
  72. (receive (oshape f op ri a) (apply nested-op-frames op 1 z a)
  73. (let loop ((i 0) (c (from (car a) 0)))
  74. (if (< i end)
  75. (loop (+ 1 i) (apply array-map/frame type oshape f
  76. op c (map (cut array-cell-ref <> i) (cdr a))))
  77. c))))))
  78. (define (folda op z . a)
  79. (apply folda/t (array-type* z) op z a))
  80. (export folda/t folda)
  81. ; fold below ply.
  82. (define (foldb/t type op z . a)
  83. (if (null? a)
  84. z
  85. (let ((op (if (verb? op) op (verb op)))
  86. (end (tally (car a)))
  87. ; raise the rank of z so that it can be matched with a. It's lowered later.
  88. (z (apply reshape z (cons 1 ($ z)))))
  89. ; match below the folding axis.
  90. (receive (oshape f op ri a) (apply nested-op-frames op 1 z a)
  91. (apply array-map/frame type oshape f
  92. (lambda (z . a)
  93. (let loop ((i 0) (c z))
  94. (if (< i end)
  95. (loop (+ 1 i) (apply op c (map (cut array-cell-ref <> i) a)))
  96. c)))
  97. (from (car a) 0)
  98. ; move the folding axis below the frame.
  99. (map (lambda (a ri) (rollaxis a 0 (- (rank a) 1 ri)))
  100. (cdr a) (cdr ri)))))))
  101. (define (foldb op z . a)
  102. (apply foldb/t (array-type* z) op z a))
  103. (export folda/t folda foldb/t foldb)
  104. ; -------------
  105. ; inner product
  106. ; -------------
  107. ; See more variants in test/test-reduce.scm.
  108. (define (_madd +_ *_) (verb (lambda (c a b) (+_ c (*_ a b))) '() 0 0 0))
  109. ; @todo See that we can do any order; e.g. as in the ZPL reference.
  110. ; @todo In w/rank ... 1 '_, 1 should be 1+ the rank of *. Look for examples.
  111. ; @todo scalar args.
  112. ; @todo folda! that accepts preallocated output.
  113. (define* (dot +_ *_ A B #:key type)
  114. "dot + * A B
  115. Inner product between the last axis of A and the first of B."
  116. (let ((type (or type (array-type* A))))
  117. (ply/t type (w/rank (verb (cut folda/t type (_madd +_ *_) 0 <> <>) #f '_ '_) 1 '_)
  118. A B)))
  119. (define _cmadd (verb (lambda (c a b) (+ c (* (conj a) b))) '() 0 0 0))
  120. (define* (cdot A B #:key type)
  121. "cdot A B
  122. Equivalent to (dot + (* (conj x) y) A B)."
  123. (let ((type (or type (array-type* A))))
  124. (ply/t type (w/rank (verb (cut folda/t type _cmadd 0 <> <>) #f '_ '_) 1 '_)
  125. A B)))
  126. (export dot cdot)