123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- ; -*- mode: scheme; coding: utf-8 -*-
- ; Tests for (ffi cblas).
- ; (c) Daniel Llorens - 2014, 2017, 2019
- ; This library is free software; you can redistribute it and/or modify it under
- ; the terms of the GNU Lesser General Public License as published by the Free
- ; Software Foundation; either version 3 of the License, or (at your option) any
- ; later version.
- (import (ffi cblas) (srfi srfi-64) (srfi srfi-1) (ice-9 match) (srfi srfi-8))
- (include "common.scm")
;; Open the outermost srfi-64 test group for this file; the matching
;; (test-end "ffi-cblas") is issued at the end, after the failure count is read.
- (test-begin "ffi-cblas")
- ; ---------------------------------
- ; Test types
- ; ---------------------------------
;; Compare arrays EXPECTED and VAL element by element inside a test group named
;; TAG (a string, or a symbol that is converted to a string).  ERR is the
;; absolute tolerance.  Pairs of real elements are compared with
;; test-approximate; otherwise the magnitude of their difference is compared
;; against 0.  Arrays of different shape produce one failed test-equal on the
;; shapes instead of an error escaping from array-for-each.
(define* (test-approximate-array tag expected val err)
  (let ((tag (if (symbol? tag) (symbol->string tag) tag)))
    (test-begin tag)
    ;; array-for-each raises on mismatched shapes, which would abort the whole
    ;; run and leave this group open, so check the shapes first.
    (if (equal? (array-shape expected) (array-shape val))
        (array-for-each
         (lambda (e v)
           (if (and (real? e) (real? v))
               (test-approximate e v err)
               (test-approximate 0. (magnitude (- e v)) err)))
         expected val)
        (test-equal (array-shape expected) (array-shape val)))
    (test-end tag)))
- ; -----------------------------
- ; srotg drotg crotg zrotg. Some versions of CBLAS don't provide these...
- ; -----------------------------
;; rotg tests, skipped when the binding checked below is not defined.
;; NOTE(review): the guard checks `srotg` while the calls use `cblas-srotg` etc.;
;; confirm which name (ffi cblas) binds conditionally.
(when (defined? 'srotg)
  (test-begin "rotg")
  ;; Each case is (a b c-expected s-expected).  Every rotg variant is run on it
  ;; and c, s are compared with the tolerance listed next to that variant.
  ;; for-each (not map): these loops run only for their side effects, and
  ;; for-each guarantees left-to-right order.
  (for-each
   (match-lambda
     ((a b cref sref)
      (for-each
       (match-lambda
         ((rotg eps)
          (receive (c s) (rotg a b)
            (test-approximate cref c eps)
            (test-approximate sref s eps))))
       `((,cblas-srotg 1e-7)
         (,cblas-crotg 1e-7)
         (,cblas-drotg 1e-15)
         (,cblas-zrotg 1e-15)))))
   `((1. 0. 1. 0.)
     (0. 1. 0. 1.)
     (1. 1. ,(sqrt .5) ,(sqrt .5))
     (-1. -1. ,(sqrt .5) ,(sqrt .5))))
  (test-end "rotg"))
- ; ---------------------------------
- ; isamax idamax icamax izamax
- ; ---------------------------------
;; Reference for i?amax on a plain list A: the 0-based index of the first
;; element whose |Re x| + |Im x| is strictly the largest, or -1 if A is empty.
(define (list-iamax a)
  (let scan ((xs a) (i 0) (best -inf.0) (best-i -1))
    (if (null? xs)
        best-i
        (let* ((x (car xs))
               (score (+ (magnitude (real-part x)) (magnitude (imag-part x)))))
          (if (> score best)
              (scan (cdr xs) (+ i 1) score i)
              (scan (cdr xs) (+ i 1) best best-i))))))
;; Check IAMAX against list-iamax on a vector of element type STYPE built by
;; MAKE-A and filled by fill-A1!.  The check runs in a test group named
;; "<iamax name>, <make-A name>".
(define (test-iamax stype iamax make-A)
  (define case-name
    (format #f "~a, ~a"
            (procedure-name iamax)
            (format #f "~a" (procedure-name make-A))))
  (define A (fill-A1! (make-A stype)))
  (test-begin case-name)
  (test-equal (iamax A) (list-iamax (array->list A)))
  (test-end case-name))
;; Run test-iamax for every (element type, i?amax binding) pair with each of the
;; make-v-compact / make-v-offset / make-v-strided vector makers.
;; for-each (not map): run only for side effects, in a guaranteed order, like
;; the gemv/gemm drivers.
(for-each
 (match-lambda
   ((stype iamax)
    (for-each
     (lambda (make-X)
       (test-iamax stype iamax make-X))
     (list make-v-compact make-v-offset make-v-strided))))
 `((f64 ,cblas-idamax)
   (f32 ,cblas-isamax)
   (c64 ,cblas-izamax)
   (c32 ,cblas-icamax)))
- ; ---------------------------------
- ; saxpy daxpy caxpy zaxpy
- ; ---------------------------------
;; Check AXPY! (y <- alpha*x + y) on vectors of element type STYPE made by
;; MAKE-A / MAKE-B and filled by fill-A1! / fill-B1!.  Two updates are applied
;; in turn (alpha = 3, then alpha = 1.9).  After each one, B is compared with a
;; reference computed from the original contents, and A must be unchanged.
(define (test-axpy stype axpy! make-A make-B)
  (let* ((tag (format #f "~a:~a" (procedure-name make-A) (procedure-name make-B)))
         (case-name (format #f "~a, ~a" (procedure-name axpy!) tag))
         (A (fill-A1! (make-A stype)))
         (B (fill-B1! (make-B stype)))
         (Alist (array->list A))
         (Blist (array->list B))
         ;; f32/c32 storage keeps only ~7 significant digits.  The BLAS result
         ;; and the reference (computed in double, then stored as STYPE) can
         ;; therefore differ by an ulp, so a fixed 1e-14 would only pass on
         ;; bit-identical rounding.
         (eps (if (memq stype '(f32 c32)) 1e-6 1e-14)))
    (test-begin case-name)
    (axpy! 3 A B)
    (test-equal B (list->typed-array stype 1 (map (lambda (a b) (+ (* 3 a) b)) Alist Blist)))
    (test-equal A (list->typed-array stype 1 Alist))
    (axpy! 1.9 A B)
    (let* ((ref (map (lambda (a b) (+ (* a (+ 3 1.9)) b)) Alist Blist))
           ;; Scale the tolerance by the size of the values being compared.
           (scale (fold (lambda (x m) (max m (magnitude x))) 1 ref)))
      (test-approximate-array "approximate array"
                              (list->typed-array stype 1 ref) B (* eps scale)))
    (test-equal A (list->typed-array stype 1 Alist))
    (test-end case-name)))
;; Run test-axpy for every element type over every (x, y) pairing of the
;; make-v-compact / make-v-offset / make-v-strided vector makers.
;; for-each (not map): run only for side effects, in a guaranteed order.
(for-each
 (match-lambda
   ((stype axpy!)
    (let ((makers (list make-v-compact make-v-offset make-v-strided)))
      (for-each
       (lambda (make-X)
         (for-each
          (lambda (make-Y)
            (test-axpy stype axpy! make-X make-Y))
          makers))
       makers))))
 `((f64 ,cblas-daxpy!)
   (f32 ,cblas-saxpy!)
   (c64 ,cblas-zaxpy!)
   (c32 ,cblas-caxpy!)))
- ; ---------------------------------
- ; scopy dcopy ccopy zcopy
- ; ---------------------------------
;; Check COPY! (y <- x) on vectors of element type STYPE made by MAKE-A / MAKE-B
;; and filled by fill-A1! / fill-B1!.  After copying A into B, B must hold A's
;; original contents and A must be unchanged.
(define (test-copy stype copy! make-A make-B)
  (let* ((tag (format #f "~a:~a" (procedure-name make-A) (procedure-name make-B)))
         (case-name (format #f "~a, ~a" (procedure-name copy!) tag))
         (A (fill-A1! (make-A stype)))
         (B (fill-B1! (make-B stype)))
         ;; Only A's original contents are needed; B's are overwritten.
         (Alist (array->list A)))
    (test-begin case-name)
    (copy! A B)
    (test-equal B (list->typed-array stype 1 Alist))
    (test-equal A (list->typed-array stype 1 Alist))
    (test-end case-name)))
;; Run test-copy for every element type over every (x, y) pairing of the
;; make-v-compact / make-v-offset / make-v-strided vector makers.
;; for-each (not map): run only for side effects, in a guaranteed order.
(for-each
 (match-lambda
   ((stype copy!)
    (let ((makers (list make-v-compact make-v-offset make-v-strided)))
      (for-each
       (lambda (make-X)
         (for-each
          (lambda (make-Y)
            (test-copy stype copy! make-X make-Y))
          makers))
       makers))))
 `((f64 ,cblas-dcopy!)
   (f32 ,cblas-scopy!)
   (c64 ,cblas-zcopy!)
   (c32 ,cblas-ccopy!)))
- ; ---------------------------------
; sswap dswap cswap zswap
- ; ---------------------------------
;; Check SWAP! (x <-> y) on vectors of element type STYPE made by MAKE-A / MAKE-B
;; and filled by fill-A1! / fill-B1!.  Afterwards B must hold A's original
;; contents and A must hold B's.
(define (test-swap stype swap! make-A make-B)
  (define case-name
    (format #f "~a, ~a"
            (procedure-name swap!)
            (format #f "~a:~a" (procedure-name make-A) (procedure-name make-B))))
  (define A (fill-A1! (make-A stype)))
  (define B (fill-B1! (make-B stype)))
  (define A-before (array->list A))
  (define B-before (array->list B))
  (define (as-vector lst) (list->typed-array stype 1 lst))
  (test-begin case-name)
  (swap! A B)
  (test-equal B (as-vector A-before))
  (test-equal A (as-vector B-before))
  (test-end case-name))
;; Run test-swap for every element type over every (x, y) pairing of the
;; make-v-compact / make-v-offset / make-v-strided vector makers.
;; for-each (not map): run only for side effects, in a guaranteed order.
(for-each
 (match-lambda
   ((stype swap!)
    (let ((makers (list make-v-compact make-v-offset make-v-strided)))
      (for-each
       (lambda (make-X)
         (for-each
          (lambda (make-Y)
            (test-swap stype swap! make-X make-Y))
          makers))
       makers))))
 `((f64 ,cblas-dswap!)
   (f32 ,cblas-sswap!)
   (c64 ,cblas-zswap!)
   (c32 ,cblas-cswap!)))
- ; ---------------------------------
- ; sgemv dgemv cgemv zgemv
- ; ---------------------------------
;; Naive reference GEMV for a rank-2 A (no transposition):
;; Y[i] <- beta*Y[i] + sum_j alpha*A[i,j]*X[j], updated in place; returns Y.
;; Every partial sum is stored back into Y, so it is rounded to Y's element type
;; at each step.
(define (ref-gemv! alpha A X beta Y)
  (match (array-dimensions A)
    ((rows cols)
     (let row-loop ((i 0))
       (when (< i rows)
         (array-set! Y (* beta (array-ref Y i)) i)
         (let col-loop ((j 0))
           (when (< j cols)
             (array-set! Y
                         (+ (array-ref Y i)
                            (* alpha (array-ref A i j) (array-ref X j)))
                         i)
             (col-loop (+ j 1))))
         (row-loop (+ i 1))))
     Y)))
;; Check GEMV! with CblasNoTrans against ref-gemv! using alpha = 2, beta = 3.
;; A is a matrix from MAKE-A (filled by fill-A2!); X and Y are vectors from
;; MAKE-X / MAKE-Y (filled by fill-A1! / fill-B1!); all have element type STYPE.
;; Both routines run on separate copies of Y whose results must be equal, and
;; A and X must be left unchanged.
(define (test-gemv stype gemv! make-A make-X make-Y)
  (let* ((case-name
          (format #f "~a, ~a"
                  (procedure-name gemv!)
                  (format #f "~a:~a:~a"
                          (procedure-name make-A)
                          (procedure-name make-X)
                          (procedure-name make-Y))))
         (A (fill-A2! (make-A stype)))
         (X (fill-A1! (make-X stype)))
         (Y (fill-B1! (make-Y stype)))
         (A-saved (array-copy A))
         (X-saved (array-copy X))
         (Y-blas (array-copy Y))
         (Y-ref (array-copy Y)))
    (test-begin case-name)
    (gemv! 2. A CblasNoTrans X 3. Y-blas) ; TODO Test other values of transA
    (ref-gemv! 2. A X 3. Y-ref)
    (test-equal Y-blas Y-ref)
    (test-equal A A-saved)
    (test-equal X X-saved)
    (test-end case-name)))
;; Run test-gemv for each element type over every combination of matrix maker
;; (for A) and vector makers (for X and Y), as enumerated by list-product.
(let ((matrix-makers
       (lambda () (list make-M-c-order make-M-fortran-order make-M-offset make-M-strided)))
      (vector-makers
       (lambda () (list make-v-compact make-v-offset make-v-strided))))
  (for-each
   (lambda (entry)
     (match entry
       ((stype gemv!)
        (for-each
         (lambda (combo)
           (match combo
             ((make-A make-X make-Y)
              (test-gemv stype gemv! make-A make-X make-Y))))
         (list-product (matrix-makers) (vector-makers) (vector-makers))))))
   `((f64 ,cblas-dgemv!)
     (f32 ,cblas-sgemv!)
     (c64 ,cblas-zgemv!)
     (c32 ,cblas-cgemv!))))
- ; ---------------------------------
- ; TODO snrm2, dnrm2, cnrm2, znrm2
- ; ---------------------------------
- ; ---------------------------------
- ; TODO sasum, dasum, casum, zasum
- ; ---------------------------------
- ; ---------------------------------
- ; TODO sscal, dscal, cscal, zscal, csscal, zdscal
- ; ---------------------------------
- ; ---------------------------------
- ; TODO sger, dger, cgeru, cgerc, zgeru, zgerc
- ; ---------------------------------
- ; ---------------------------------
; isamax idamax icamax izamax: already covered above by test-iamax
- ; ---------------------------------
- ; ---------------------------------
- ; sgemm dgemm cgemm zgemm
- ; ---------------------------------
;; Check GEMM! (C <- alpha*op(A)*op(B) + beta*C) against a naive reference.
;; The check runs in a test group named TAG.  Each routine works on its own
;; copy of C, and the two results must be equal.  A and B must not be modified
;; by GEMM!.  op(X) is X transposed when (ffi cblas)'s internal tr? predicate
;; accepts the corresponding trans flag.
(define (test-gemm tag gemm! alpha A transA B transB beta C)
  (define tr? (@@ (ffi cblas) tr?))
  ;; alpha * sum_k(op(A)_{ik}*op(B)_{kj}) + beta * C_{ij} -> C_{ij}
  ;; Each partial sum is stored back into C, so it is rounded to C's element type.
  (define (ref-gemm! alpha A transA B transB beta C)
    (let* ((A (if (tr? transA) (transpose-array A 1 0) A))
           (B (if (tr? transB) (transpose-array B 1 0) B))
           (M (first (array-dimensions C)))
           (N (second (array-dimensions C)))
           (K (first (array-dimensions B))))
      (do ((i 0 (+ i 1))) ((= i M))
        (do ((j 0 (+ j 1))) ((= j N))
          (array-set! C (* beta (array-ref C i j)) i j)
          (do ((k 0 (+ k 1))) ((= k K))
            (array-set! C (+ (array-ref C i j) (* alpha (array-ref A i k) (array-ref B k j))) i j))))))
  (let ((C1 (array-copy C))
        (C2 (array-copy C))
        (AA (array-copy A))
        (BB (array-copy B)))
    (gemm! alpha A transA B transB beta C1)
    (ref-gemm! alpha A transA B transB beta C2)
    ;; (test-approximate-array tag C1 C2 1e-15) ; TODO as a single test.
    (test-begin tag)
    (test-equal C1 C2)
    ;; AA/BB are the snapshots taken before the calls: gemm! must not touch its inputs.
    (test-equal A AA)
    (test-equal B BB)
    (test-end tag)))
;; Run test-gemm for every element type.  First come fixed cases with
;; non-square matrices; then every combination of matrix makers and
;; transpose flags.
(for-each
 (match-lambda
   ((stype gemm!)
    ;; Non-square matrices created directly with make-typed-array.
    (let ((A (fill-A2! (make-typed-array stype *unspecified* 4 3)))
          (B (fill-A2! (make-typed-array stype *unspecified* 3 5)))
          (C (fill-A2! (make-typed-array stype *unspecified* 4 5))))
      (test-gemm (format #f "gemm-1:~a" stype) gemm! 1. A CblasNoTrans B CblasNoTrans 1. C)
      (test-gemm (format #f "gemm-2:~a" stype) gemm! 1. A CblasTrans C CblasNoTrans 1. B)
      (test-gemm (format #f "gemm-3:~a" stype) gemm! 1. C CblasNoTrans B CblasTrans 1. A))
    ;; The same shapes, allocated and then viewed through transpose-array.
    ;; Use this iteration's STYPE and GEMM! so that every element type is
    ;; covered here, not just f64.
    (let ((A (fill-A2! (transpose-array (make-typed-array stype *unspecified* 4 3) 1 0)))
          (B (fill-A2! (transpose-array (make-typed-array stype *unspecified* 3 5) 1 0)))
          (C (fill-A2! (transpose-array (make-typed-array stype *unspecified* 4 5) 1 0))))
      (test-gemm (format #f "gemm-4:~a" stype) gemm! 1. A CblasTrans B CblasTrans 1. (transpose-array C 1 0))
      (test-gemm (format #f "gemm-5:~a" stype) gemm! 1. A CblasNoTrans C CblasTrans 1. (transpose-array B 1 0))
      (test-gemm (format #f "gemm-6:~a" stype) gemm! 1. C CblasTrans B CblasNoTrans 1. (transpose-array A 1 0)))
    (for-each
     (match-lambda ((make-A make-B make-C transA transB)
                    (test-gemm (format #f "gemm:~a:~a:~a:~a:~a:~a" stype (procedure-name make-A)
                                       (procedure-name make-B) (procedure-name make-C)
                                       transA transB)
                               gemm! 3. (fill-A2! (make-A stype)) transA
                               (fill-A2! (make-B stype)) transB
                               2. (fill-A2! (make-C stype)))))
     (list-product
      (list make-M-c-order make-M-fortran-order make-M-offset make-M-strided)
      (list make-M-c-order make-M-fortran-order make-M-offset make-M-strided)
      (list make-M-c-order make-M-fortran-order make-M-offset make-M-strided)
      ; TODO Conj, etc. for c32/c64.
      (list CblasTrans CblasNoTrans)
      (list CblasTrans CblasNoTrans)))))
 `((f32 ,cblas-sgemm!)
   (f64 ,cblas-dgemm!)
   (c32 ,cblas-cgemm!)
   (c64 ,cblas-zgemm!)))
;; Read the number of failed tests while the outermost group is still open,
;; then close it.
(define error-count (test-runner-fail-count (test-runner-current)))
(test-end "ffi-cblas")
;; Exit 0 on success and 1 on any failure.  Passing the raw count is unsafe:
;; the process status is taken modulo 256, so 256 failures would read as
;; success.  Test harnesses such as Automake's also treat 77 as "skipped" and
;; 99 as a hard error.
(exit (if (zero? error-count) 0 1))
|