123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
- ; Array reductions.
- ; (c) Daniel Llorens - 2012-2013
- ; This library is free software; you can redistribute it and/or modify it under
- ; the terms of the GNU General Public License as published by the Free
- ; Software Foundation; either version 3 of the License, or (at your option) any
- ; later version.
- (define-module (ploy reduce))
- (import (ice-9 optargs) (srfi srfi-26) (srfi srfi-1) (srfi srfi-11)
- (srfi srfi-9) (srfi srfi-8) (ploy basic) (ploy assert) (ploy ploy))
- ; @todo look at J/repa/numpy/PDL/SAC?; tensordot/tensorsolve.
- ; @todo folda / foldb as verbs can be determined from the argument op. Probably
- ; want verbs to be applicable.
- ; @todo the ranks of folda / foldb should be 1 + rank of op; why does J use '_?
; over/t is a special case of fold along the first axis of a, associating to
; the left:
;   with M:     M(a0) [R M(a1) [R M(a2) ...]]
;   without M:  a0 [R a1 [R a2 ...]]
; A single item yields M(a0) (resp. a0); R is only ever called with two
; arguments, except for an empty a, which yields (R).
; TYPE is accepted for interface symmetry with the other /t procedures but is
; not used here.
; @todo Take verbs, like folda / foldb; see uses of over max v-norm ...
; Maybe leave the non-verb case for the time being, as an optimization.
(define over/t
  (case-lambda
    ((type R a)
     ;; Identical to the M case with M = identity; delegating keeps the two
     ;; arities from drifting apart.
     (over/t type R (lambda (x) x) a))
    ((type R M a)
     (let ((end (tally a)))
       (if (zero? end)
           (R)
           ;; No special case for a single item: the previous (= 1 end)
           ;; branch returned (R (M a0)), applying R unary, which disagreed
           ;; with the formula above, with the 3-arity case and with over/t*.
           (let loop ((i 1) (c (M (array-from a 0))))
             (if (< i end)
                 (loop (+ 1 i) (R c (M (array-from a i))))
                 c)))))))
; over/t*: same fold as over/t, but it visits items 1 ... n-1 of a through
; array-for-each-cell instead of carrying an explicit index.
; @TODO This form makes it easier to fold over more than one rank, but it
; should never be slower than carrying indices.
(define over/t*
  (case-lambda
    ((type R a)
     ;; The plain case is the M case with M = identity.
     (over/t* type R (lambda (x) x) a))
    ((type R M a)
     (let ((n (tally a)))
       (if (zero? n)
           (R)
           (let ((acc (M (array-from a 0))))
             ;; (from a (J (- n 1) 1)) presumably selects items 1 ... n-1 of a;
             ;; each cell visited is folded into acc from the left.
             (array-for-each-cell 1
               (lambda (cell) (set! acc (R acc (M (array-from cell)))))
               (from a (J (- n 1) 1)))
             acc))))))
(export over/t*)
; over R [M] a: over/t with the array type taken from a itself.
(define over
  (case-lambda
    ((R a)
     (let ((type (array-type* a)))
       (over/t type R a)))
    ((R M a)
     (let ((type (array-type* a)))
       (over/t type R M a)))))
(export over/t over)
; fold above ply.
; folda/t type op z a ...: left fold of op along the first axis of a ...,
; starting from z. At each step i, array-map/frame applies op to the
; accumulator and to item i of every a, over the frame that nested-op-frames
; computes. With no a, z is returned unchanged. op is wrapped with verb if it
; is not a verb already.
; @todo (folda vector3+ #(0 0 0) #(#(0 0 1) #(0 1 0))) ???
(define (folda/t type op z . a)
  (if (null? a)
      z
      (let* ((n (tally (car a)))
             (vop (if (verb? op) op (verb op)))
             ;; Give z a leading axis of length 1 so that it can be matched
             ;; with a; (from _ 0) below lowers it again.
             (z1 (apply reshape z 1 ($ z))))
        ;; match below the folding axis.
        (receive (oshape frame vop ri args) (apply nested-op-frames vop 1 z1 a)
          ;; (car args) presumably is the matched z; (cdr args) the matched a ...
          (let loop ((i 0) (acc (from (car args) 0)))
            (if (>= i n)
                acc
                (loop (+ i 1)
                      (apply array-map/frame type oshape frame vop acc
                             (map (lambda (x) (array-from x i)) (cdr args))))))))))
; folda op z a ...: folda/t with the array type taken from z.
(define (folda op z . a)
  (let ((type (array-type* z)))
    (apply folda/t type op z a)))
(export folda/t folda)
; fold below ply.
; foldb/t type op z a ...: like folda/t, but the fold runs inside the frame.
; array-map/frame is called once, and for every frame cell the procedure below
; folds op along the items of that cell's a ..., starting from the matching
; cell of z. With no a, z is returned unchanged. op is wrapped with verb if it
; is not a verb already.
(define (foldb/t type op z . a)
  (if (null? a)
      z
      (let ((op (if (verb? op) op (verb op)))
            (end (tally (car a)))
            ;; raise the rank of z so that it can be matched with a; the
            ;; (from _ 0) below lowers it again.
            (z (apply reshape z (cons 1 ($ z)))))
        ;; match below the folding axis.
        (receive (oshape f op ri a) (apply nested-op-frames op 1 z a)
          (apply array-map/frame type oshape f
                 ;; Called once per frame cell: left fold of op over items
                 ;; 0 ... end-1 of the argument cells, starting from the z cell.
                 (lambda (z-cell . cells)
                   (let loop ((i 0) (c z-cell))
                     (if (< i end)
                         (loop (+ 1 i) (apply op c (map (cut array-from <> i) cells)))
                         c)))
                 (from (car a) 0)
                 ;; move the folding axis (axis 0) below the frame, so that
                 ;; (assuming numpy-style rollaxis) it becomes the leading axis
                 ;; of each cell handed to the lambda above.
                 (map (lambda (x x-ri) (rollaxis x 0 (- (rank x) 1 x-ri)))
                      (cdr a) (cdr ri)))))))
; foldb op z a ...: foldb/t with the array type taken from z.
(define (foldb op z . a)
  (apply foldb/t (array-type* z) op z a))
; folda/t and folda are already exported after their definitions.
(export foldb/t foldb)
- ; -------------
- ; inner product
- ; -------------
- ; See more variants in test/test-reduce.scm.
; (_madd +_ *_): the verb acc x y -> (+_ acc (*_ x y)), one multiply-accumulate
; step of an inner product. The trailing '() 0 0 0 are passed to verb as-is
; (presumably output shape and argument ranks; confirm against ploy).
(define (_madd +_ *_)
  (verb (lambda (acc x y) (+_ acc (*_ x y)))
        '() 0 0 0))
; @todo See that we can do any order; e.g. as in the ZPL reference.
; @todo In w/rank ... 1 '_, 1 should be 1+ the rank of *. Look for examples.
; @todo scalar args.
; @todo folda! that accepts preallocated output.
(define* (dot +_ *_ A B #:key type)
  "dot + * A B
Inner product between the last axis of A and the first of B."
  (let ((type (or type (array-type* A)))
        ;; Build the multiply-accumulate verb once. SRFI-26 cut re-evaluates
        ;; its non-slot expressions on every call, so writing (_madd +_ *_)
        ;; inside the cut built a fresh verb each time ply/t invoked it
        ;; (presumably once per rank-1 cell of A).
        (madd (_madd +_ *_)))
    ;; Pair each rank-1 cell of A with the whole of B (w/rank ... 1 '_) and
    ;; fold madd from 0 along their first axes.
    (ply/t type
           (w/rank (verb (cut folda/t type madd 0 <> <>) #f '_ '_) 1 '_)
           A B)))
; _cmadd: the verb acc x y -> acc + conj(x) * y, the multiply-accumulate step
; used by cdot.
(define _cmadd
  (verb (lambda (acc x y) (+ acc (* (conj x) y)))
        '() 0 0 0))
; cdot A B: conjugating inner product; the fold step is _cmadd, which is a
; module-level constant, so the cut below has no per-call construction cost.
(define* (cdot A B #:key type)
  "cdot A B
Equivalent to (dot + (lambda (x y) (* (conj x) y)) A B): inner product
between the last axis of A and the first of B, conjugating the elements of A."
  (let ((type (or type (array-type* A))))
    (ply/t type
           (w/rank (verb (cut folda/t type _cmadd 0 <> <>) #f '_ '_) 1 '_)
           A B)))
(export dot cdot)
|