test-ffi-cblas.scm 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. ; -*- mode: scheme; coding: utf-8 -*-
  2. ; Tests for (ffi cblas).
  3. ; (c) Daniel Llorens - 2014, 2017, 2019
  4. ; This library is free software; you can redistribute it and/or modify it under
  5. ; the terms of the GNU Lesser General Public License as published by the Free
  6. ; Software Foundation; either version 3 of the License, or (at your option) any
  7. ; later version.
  8. (import (ffi cblas) (srfi srfi-64) (srfi srfi-1) (ice-9 match) (srfi srfi-8))
  9. (include "common.scm")
  10. (set! test-log-to-file #f)
  11. (test-begin "ffi-cblas")
  12. ; ---------------------------------
; Test helpers
  14. ; ---------------------------------
  15. (define* (test-approximate-array tag test expected err)
  16. (test-begin tag)
  17. (array-for-each (lambda (test expected) (test-approximate test expected err))
  18. test expected)
  19. (test-end tag))
  20. ; -----------------------------
  21. ; srotg drotg crotg zrotg. Some versions of CBLAS don't provide these...
  22. ; -----------------------------
  23. (when (defined? 'srotg)
  24. (test-begin "rotg")
  25. (map (match-lambda
  26. ((a b cref sref)
  27. (map (match-lambda
  28. ((rotg eps)
  29. (receive (c s) (rotg a b)
  30. (test-approximate cref c eps)
  31. (test-approximate sref s eps))))
  32. `((,cblas-srotg 1e-7)
  33. (,cblas-crotg 1e-7)
  34. (,cblas-drotg 1e-15)
  35. (,cblas-zrotg 1e-15)))))
  36. `((1. 0. 1. 0.)
  37. (0. 1. 0. 1.)
  38. (1. 1. ,(sqrt .5) ,(sqrt .5))
  39. (-1. -1. ,(sqrt .5) ,(sqrt .5))))
  40. (test-end "rotg"))
  41. ; ---------------------------------
  42. ; isamax idamax icamax izamax
  43. ; ---------------------------------
  44. (define (list-iamax a)
  45. (cdr
  46. (fold (lambda (a i mi)
  47. (let ((m (+ (magnitude (real-part a)) (magnitude (imag-part a)))))
  48. (if (> m (car mi))
  49. (cons m i)
  50. mi)))
  51. (cons -inf.0 -1)
  52. a
  53. (iota (length a)))))
  54. (define (test-iamax stype iamax make-A)
  55. (let* ((tag (format #f "~a" (procedure-name make-A)))
  56. (case-name (format #f "~a, ~a" (procedure-name iamax) tag))
  57. (A (fill-A1! (make-A stype))))
  58. (test-begin case-name)
  59. (test-equal (iamax A) (list-iamax (array->list A)))
  60. (test-end case-name)))
  61. (map (match-lambda
  62. ((stype iamax)
  63. (for-each
  64. (lambda (make-X)
  65. (test-iamax stype iamax make-X))
  66. (list make-v-compact make-v-offset make-v-strided))))
  67. `((f64 ,cblas-idamax)
  68. (f32 ,cblas-isamax)
  69. (c64 ,cblas-izamax)
  70. (c32 ,cblas-icamax)))
  71. ; ---------------------------------
  72. ; saxpy daxpy caxpy zaxpy
  73. ; ---------------------------------
  74. (define (test-axpy stype axpy! make-A make-B)
  75. (let* ((tag (format #f "~a:~a" (procedure-name make-A) (procedure-name make-B)))
  76. (case-name (format #f "~a, ~a" (procedure-name axpy!) tag))
  77. (A (fill-A1! (make-A stype)))
  78. (B (fill-B1! (make-B stype))))
  79. (let ((Alist (array->list A))
  80. (Blist (array->list B)))
  81. (test-begin case-name)
  82. (axpy! 3 A B)
  83. (test-equal B (list->typed-array stype 1 (map (lambda (a b) (+ (* 3 a) b)) Alist Blist)))
  84. (test-equal A (list->typed-array stype 1 Alist))
  85. (axpy! 1.9 A B)
  86. (test-approximate-array "approximate array"
  87. B (list->typed-array stype 1 (map (lambda (a b) (+ (* a (+ 3 1.9)) b)) Alist Blist)) 1e-14)
  88. (test-equal A (list->typed-array stype 1 Alist))
  89. (test-end case-name))))
  90. (map
  91. (match-lambda
  92. ((stype axpy!)
  93. (for-each
  94. (lambda (make-X)
  95. (for-each
  96. (lambda (make-Y)
  97. (test-axpy stype axpy! make-X make-Y))
  98. (list make-v-compact make-v-offset make-v-strided)))
  99. (list make-v-compact make-v-offset make-v-strided))))
  100. `((f64 ,cblas-daxpy!)
  101. (f32 ,cblas-saxpy!)
  102. (c64 ,cblas-zaxpy!)
  103. (c32 ,cblas-caxpy!)))
  104. ; ---------------------------------
  105. ; scopy dcopy ccopy zcopy
  106. ; ---------------------------------
  107. (define (test-copy stype copy! make-A make-B)
  108. (let* ((tag (format #f "~a:~a" (procedure-name make-A) (procedure-name make-B)))
  109. (case-name (format #f "~a, ~a" (procedure-name copy!) tag))
  110. (A (fill-A1! (make-A stype)))
  111. (B (fill-B1! (make-B stype))))
  112. (let ((Alist (array->list A))
  113. (Blist (array->list B)))
  114. (test-begin case-name)
  115. (copy! A B)
  116. (test-equal B (list->typed-array stype 1 Alist))
  117. (test-equal A (list->typed-array stype 1 Alist))
  118. (test-end case-name))))
  119. (map
  120. (match-lambda
  121. ((stype copy!)
  122. (for-each
  123. (lambda (make-X)
  124. (for-each
  125. (lambda (make-Y)
  126. (test-copy stype copy! make-X make-Y))
  127. (list make-v-compact make-v-offset make-v-strided)))
  128. (list make-v-compact make-v-offset make-v-strided))))
  129. `((f64 ,cblas-dcopy!)
  130. (f32 ,cblas-scopy!)
  131. (c64 ,cblas-zcopy!)
  132. (c32 ,cblas-ccopy!)))
  133. ; ---------------------------------
; sswap dswap cswap zswap
  135. ; ---------------------------------
  136. (define (test-swap stype swap! make-A make-B)
  137. (let* ((tag (format #f "~a:~a" (procedure-name make-A) (procedure-name make-B)))
  138. (case-name (format #f "~a, ~a" (procedure-name swap!) tag))
  139. (A (fill-A1! (make-A stype)))
  140. (B (fill-B1! (make-B stype))))
  141. (let ((Alist (array->list A))
  142. (Blist (array->list B)))
  143. (test-begin case-name)
  144. (swap! A B)
  145. (test-equal B (list->typed-array stype 1 Alist))
  146. (test-equal A (list->typed-array stype 1 Blist))
  147. (test-end case-name))))
  148. (map
  149. (match-lambda
  150. ((stype swap!)
  151. (for-each
  152. (lambda (make-X)
  153. (for-each
  154. (lambda (make-Y)
  155. (test-swap stype swap! make-X make-Y))
  156. (list make-v-compact make-v-offset make-v-strided)))
  157. (list make-v-compact make-v-offset make-v-strided))))
  158. `((f64 ,cblas-dswap!)
  159. (f32 ,cblas-sswap!)
  160. (c64 ,cblas-zswap!)
  161. (c32 ,cblas-cswap!)))
  162. ; ---------------------------------
  163. ; sgemv dgemv cgemv zgemv
  164. ; ---------------------------------
  165. (define (ref-gemv! alpha A X beta Y)
  166. (match (array-dimensions A)
  167. ((M N)
  168. (do ((i 0 (+ i 1))) ((= i M))
  169. (array-set! Y (* beta (array-ref Y i)) i)
  170. (do ((j 0 (+ j 1))) ((= j N))
  171. (array-set! Y (+ (array-ref Y i) (* alpha (array-ref A i j) (array-ref X j))) i)))
  172. Y)))
  173. (define (test-gemv stype gemv! make-A make-X make-Y)
  174. (let* ((tag (format #f "~a:~a:~a" (procedure-name make-A) (procedure-name make-X) (procedure-name make-Y)))
  175. (case-name (format #f "~a, ~a" (procedure-name gemv!) tag))
  176. (A (fill-A2! (make-A stype)))
  177. (X (fill-A1! (make-X stype)))
  178. (Y (fill-B1! (make-Y stype))))
  179. (let ((A1 (array-copy A))
  180. (X1 (array-copy X)))
  181. (test-begin case-name)
  182. (let ((Y1 (array-copy Y))
  183. (Y2 (array-copy Y)))
  184. (gemv! 2. A CblasNoTrans X 3. Y1) ; TODO Test other values of transA
  185. (ref-gemv! 2. A X 3. Y2)
  186. (test-equal Y1 Y2)
  187. (test-equal A A1)
  188. (test-equal X X1)
  189. (test-end case-name)))))
  190. (for-each
  191. (match-lambda
  192. ((stype gemv!)
  193. (for-each
  194. (match-lambda ((make-A make-X make-Y)
  195. (test-gemv stype gemv! make-A make-X make-Y)))
  196. (list-product
  197. (list make-M-c-order make-M-fortran-order make-M-offset make-M-strided)
  198. (list make-v-compact make-v-offset make-v-strided)
  199. (list make-v-compact make-v-offset make-v-strided)))))
  200. `((f64 ,cblas-dgemv!)
  201. (f32 ,cblas-sgemv!)
  202. (c64 ,cblas-zgemv!)
  203. (c32 ,cblas-cgemv!)))
  204. ; ---------------------------------
; TODO snrm2, dnrm2, scnrm2, dznrm2
  206. ; ---------------------------------
  207. ; ---------------------------------
; TODO sasum, dasum, scasum, dzasum
  209. ; ---------------------------------
  210. ; ---------------------------------
  211. ; TODO sscal, dscal, cscal, zscal, csscal, zdscal
  212. ; ---------------------------------
  213. ; ---------------------------------
  214. ; TODO sger, dger, cgeru, cgerc, zgeru, zgerc
  215. ; ---------------------------------
  216. ; ---------------------------------
; isamax idamax icamax izamax: already tested in the iamax section above.
  218. ; ---------------------------------
  219. ; ---------------------------------
  220. ; sgemm dgemm cgemm zgemm
  221. ; ---------------------------------
  222. ; alpha * sum_k(A_{ik}*B_{kj}) + beta * C_{ij} -> C_{ij}
  223. (define (ref-gemm! alpha A transA B transB beta C)
  224. (let* ((A (if ((@@ (ffi cblas) tr?) transA) (transpose-array A 1 0) A))
  225. (B (if ((@@ (ffi cblas) tr?) transB) (transpose-array B 1 0) B))
  226. (M (first (array-dimensions C)))
  227. (N (second (array-dimensions C)))
  228. (K (first (array-dimensions B))))
  229. (do ((i 0 (+ i 1))) ((= i M))
  230. (do ((j 0 (+ j 1))) ((= j N))
  231. (array-set! C (* beta (array-ref C i j)) i j)
  232. (do ((k 0 (+ k 1))) ((= k K))
  233. (array-set! C (+ (array-ref C i j) (* alpha (array-ref A i k) (array-ref B k j))) i j))))))
  234. (define (test-gemm tag gemm! alpha A transA B transB beta C)
  235. (let ((C1 (array-copy C))
  236. (C2 (array-copy C))
  237. (AA (array-copy A))
  238. (BB (array-copy B)))
  239. (gemm! alpha A transA B transB beta C1)
  240. (ref-gemm! alpha A transA B transB beta C2)
  241. ;; (test-approximate-array tag C1 C2 1e-15) ; TODO as a single test.
  242. (test-begin tag)
  243. (test-equal C1 C2)
  244. (test-end tag)))
  245. (for-each
  246. (match-lambda
  247. ((stype gemm!)
  248. ; some extra tests with non-square matrices.
  249. (let ((A (fill-A2! (make-typed-array stype *unspecified* 4 3)))
  250. (B (fill-A2! (make-typed-array stype *unspecified* 3 5)))
  251. (C (fill-A2! (make-typed-array stype *unspecified* 4 5))))
  252. (test-gemm "gemm-1" gemm! 1. A CblasNoTrans B CblasNoTrans 1. C)
  253. (test-gemm "gemm-2" gemm! 1. A CblasTrans C CblasNoTrans 1. B)
  254. (test-gemm "gemm-3" gemm! 1. C CblasNoTrans B CblasTrans 1. A))
  255. (let ((A (fill-A2! (transpose-array (make-typed-array 'f64 *unspecified* 4 3) 1 0)))
  256. (B (fill-A2! (transpose-array (make-typed-array 'f64 *unspecified* 3 5) 1 0)))
  257. (C (fill-A2! (transpose-array (make-typed-array 'f64 *unspecified* 4 5) 1 0))))
  258. (test-gemm "gemm-4" cblas-dgemm! 1. A CblasTrans B CblasTrans 1. (transpose-array C 1 0))
  259. (test-gemm "gemm-5" cblas-dgemm! 1. A CblasNoTrans C CblasTrans 1. (transpose-array B 1 0))
  260. (test-gemm "gemm-6" cblas-dgemm! 1. C CblasTrans B CblasNoTrans 1. (transpose-array A 1 0)))
  261. (for-each
  262. (match-lambda ((make-A make-B make-C transA transB)
  263. (test-gemm (format #f "gemm:~a:~a:~a:~a:~a:~a" stype (procedure-name make-A)
  264. (procedure-name make-B) (procedure-name make-C)
  265. transA transB)
  266. gemm! 3. (fill-A2! (make-A stype)) transA
  267. (fill-A2! (make-B stype)) transB
  268. 2. (fill-A2! (make-C stype)))))
  269. (list-product
  270. (list make-M-c-order make-M-fortran-order make-M-offset make-M-strided)
  271. (list make-M-c-order make-M-fortran-order make-M-offset make-M-strided)
  272. (list make-M-c-order make-M-fortran-order make-M-offset make-M-strided)
  273. ; TODO Conj, etc. for c32/c64.
  274. (list CblasTrans CblasNoTrans)
  275. (list CblasTrans CblasNoTrans)))))
  276. `((f32 ,cblas-sgemm!)
  277. (f64 ,cblas-dgemm!)
  278. (c32 ,cblas-cgemm!)
  279. (c64 ,cblas-zgemm!)))
  280. (define error-count (test-runner-fail-count (test-runner-current)))
  281. (test-end "ffi-cblas")
  282. (exit error-count)