test-ffi-cblas.scm 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. ; -*- mode: scheme; coding: utf-8 -*-
  2. ; Tests for (ffi cblas).
  3. ; (c) Daniel Llorens - 2014, 2017, 2019
  4. ; This library is free software; you can redistribute it and/or modify it under
  5. ; the terms of the GNU Lesser General Public License as published by the Free
  6. ; Software Foundation; either version 3 of the License, or (at your option) any
  7. ; later version.
  8. (import (ffi cblas) (srfi srfi-64) (srfi srfi-1) (ice-9 match) (srfi srfi-8))
  9. (include "common.scm")
  ;; Outermost srfi-64 group for this whole file; it is closed by the matching
  ;; (test-end "ffi-cblas") at the end of the file.
  10. (test-begin "ffi-cblas")
  11. ; ---------------------------------
  12. ; Test types
  13. ; ---------------------------------
  ;; Element-wise approximate comparison of two arrays of the same shape,
  ;; recorded as one srfi-64 test per element inside a group named TAG
  ;; (a string, or a symbol which is converted to a string).
  ;; Pairs of reals go straight to test-approximate; for anything else
  ;; (e.g. complex elements) the magnitude of the difference must be within
  ;; ERR of 0.
  (define* (test-approximate-array tag expected val err)
    (define group-name
      (if (symbol? tag) (symbol->string tag) tag))
    (define (check-element e v)
      (cond ((and (real? e) (real? v))
             (test-approximate e v err))
            (else
             (test-approximate 0. (magnitude (- e v)) err))))
    (test-begin group-name)
    (array-for-each check-element expected val)
    (test-end group-name))
  23. ; -----------------------------
  24. ; srotg drotg crotg zrotg. Some versions of CBLAS don't provide these...
  25. ; -----------------------------
  ;; ?rotg tests, run only when the bindings exist. The guard tests for a
  ;; binding this block actually calls (cblas-srotg). The unprefixed 'srotg is
  ;; not referenced anywhere in this file, so guarding on it would most likely
  ;; skip these tests everywhere.
  (when (defined? 'cblas-srotg)
    (test-begin "rotg")
    ;; Each case is (a b c-expected s-expected). Every case is checked against
    ;; all four variants, with a looser tolerance (1e-7) for s/c than for
    ;; d/z (1e-15). for-each: the loops run only for their test side effects.
    (for-each
     (match-lambda
       ((a b cref sref)
        (for-each
         (match-lambda
           ((rotg eps)
            (receive (c s) (rotg a b)
              (test-approximate cref c eps)
              (test-approximate sref s eps))))
         `((,cblas-srotg 1e-7)
           (,cblas-crotg 1e-7)
           (,cblas-drotg 1e-15)
           (,cblas-zrotg 1e-15)))))
     `((1. 0. 1. 0.)
       (0. 1. 0. 1.)
       (1. 1. ,(sqrt .5) ,(sqrt .5))
       (-1. -1. ,(sqrt .5) ,(sqrt .5))))
    (test-end "rotg"))
  44. ; ---------------------------------
  45. ; isamax idamax icamax izamax
  46. ; ---------------------------------
  ;; Reference for cblas-i?amax on a plain list: the index of the first element
  ;; with the largest |Re| + |Im|, or -1 for the empty list.
  (define (list-iamax a)
    (let scan ((rest a) (i 0) (best-i -1) (best-m -inf.0))
      (if (null? rest)
          best-i
          (let* ((x (car rest))
                 (m (+ (magnitude (real-part x)) (magnitude (imag-part x)))))
            ;; strict > keeps the earliest index on ties
            (if (> m best-m)
                (scan (cdr rest) (+ i 1) i m)
                (scan (cdr rest) (+ i 1) best-i best-m))))))
  ;; Fill a vector of element type STYPE made by MAKE-A (via fill-A1!) and check
  ;; that IAMAX (one of cblas-i?amax) returns the same index as list-iamax.
  ;; The check is recorded in a test group named after IAMAX and MAKE-A.
  (define (test-iamax stype iamax make-A)
    (let* ((tag (format #f "~a" (procedure-name make-A)))
           (case-name (format #f "~a, ~a" (procedure-name iamax) tag))
           (A (fill-A1! (make-A stype))))
      (test-begin case-name)
      ;; SRFI 64 is (test-equal expected test-expr). The reference value goes
      ;; first so that failure reports label expected/actual correctly. The call
      ;; under test is then the test-expr, which is the argument the srfi-64
      ;; reference implementation evaluates under its error guard.
      (test-equal (list-iamax (array->list A)) (iamax A))
      (test-end case-name)))
  ;; Test each cblas-i?amax on compact, offset and strided vectors of its
  ;; element type. The loop runs only for its side effects, hence for-each
  ;; (map neither guarantees order nor is meant for effects).
  (for-each
   (match-lambda
     ((stype iamax)
      (for-each
       (lambda (make-X)
         (test-iamax stype iamax make-X))
       (list make-v-compact make-v-offset make-v-strided))))
   `((f64 ,cblas-idamax)
     (f32 ,cblas-isamax)
     (c64 ,cblas-izamax)
     (c32 ,cblas-icamax)))
  74. ; ---------------------------------
  75. ; saxpy daxpy caxpy zaxpy
  76. ; ---------------------------------
  ;; Check AXPY! (one of cblas-?axpy!) on vectors of element type STYPE made by
  ;; MAKE-A and MAKE-B. Two updates B := alpha*A + B are applied in sequence
  ;; (alpha = 3, then 1.9). After each, B is compared with values computed from
  ;; the original contents, and A must be unchanged.
  (define (test-axpy stype axpy! make-A make-B)
    (let* ((tag (format #f "~a:~a" (procedure-name make-A) (procedure-name make-B)))
           (case-name (format #f "~a, ~a" (procedure-name axpy!) tag))
           (A (fill-A1! (make-A stype)))
           (B (fill-B1! (make-B stype))))
      (let ((Alist (array->list A))
            (Blist (array->list B)))
        (test-begin case-name)
        ;; Expected values go first: SRFI 64 is (test-equal expected test-expr).
        (axpy! 3 A B)
        (test-equal (list->typed-array stype 1 (map (lambda (a b) (+ (* 3 a) b)) Alist Blist))
                    B)
        (test-equal (list->typed-array stype 1 Alist) A)
        ;; After the second update, B = (3 + 1.9)*A + B_original.
        (axpy! 1.9 A B)
        ;; NOTE(review): 1e-14 is tight for f32/c32, where 1.9 is rounded to
        ;; single precision. Confirm it holds for the fill-A1!/fill-B1! values.
        (test-approximate-array "approximate array"
                                (list->typed-array stype 1 (map (lambda (a b) (+ (* a (+ 3 1.9)) b)) Alist Blist))
                                B
                                1e-14)
        (test-equal (list->typed-array stype 1 Alist) A)
        (test-end case-name))))
  ;; Test each cblas-?axpy! on every combination of compact, offset and strided
  ;; layouts for X and Y. The loops run only for their side effects, hence
  ;; for-each rather than map.
  (for-each
   (match-lambda
     ((stype axpy!)
      (for-each
       (lambda (make-X)
         (for-each
          (lambda (make-Y)
            (test-axpy stype axpy! make-X make-Y))
          (list make-v-compact make-v-offset make-v-strided)))
       (list make-v-compact make-v-offset make-v-strided))))
   `((f64 ,cblas-daxpy!)
     (f32 ,cblas-saxpy!)
     (c64 ,cblas-zaxpy!)
     (c32 ,cblas-caxpy!)))
  107. ; ---------------------------------
  108. ; scopy dcopy ccopy zcopy
  109. ; ---------------------------------
  ;; Check COPY! (one of cblas-?copy!) from a vector made by MAKE-A into one made
  ;; by MAKE-B, both of element type STYPE. B is pre-filled with fill-B1!.
  ;; Afterwards B must equal the original A, and A must be unchanged.
  (define (test-copy stype copy! make-A make-B)
    (let* ((tag (format #f "~a:~a" (procedure-name make-A) (procedure-name make-B)))
           (case-name (format #f "~a, ~a" (procedure-name copy!) tag))
           (A (fill-A1! (make-A stype)))
           (B (fill-B1! (make-B stype))))
      (let ((Alist (array->list A)))
        (test-begin case-name)
        (copy! A B)
        ;; SRFI 64 order: (test-equal expected test-expr).
        (test-equal (list->typed-array stype 1 Alist) B)
        (test-equal (list->typed-array stype 1 Alist) A)
        (test-end case-name))))
  ;; Test each cblas-?copy! on every combination of compact, offset and strided
  ;; layouts for X and Y. The loops run only for their side effects, hence
  ;; for-each rather than map.
  (for-each
   (match-lambda
     ((stype copy!)
      (for-each
       (lambda (make-X)
         (for-each
          (lambda (make-Y)
            (test-copy stype copy! make-X make-Y))
          (list make-v-compact make-v-offset make-v-strided)))
       (list make-v-compact make-v-offset make-v-strided))))
   `((f64 ,cblas-dcopy!)
     (f32 ,cblas-scopy!)
     (c64 ,cblas-zcopy!)
     (c32 ,cblas-ccopy!)))
  136. ; ---------------------------------
  137. ; sswap dswap cswap zswap
  138. ; ---------------------------------
  ;; Check SWAP! (one of cblas-?swap!) on vectors of element type STYPE made by
  ;; MAKE-A and MAKE-B. Afterwards each vector must hold the other's original
  ;; contents.
  (define (test-swap stype swap! make-A make-B)
    (let* ((tag (format #f "~a:~a" (procedure-name make-A) (procedure-name make-B)))
           (case-name (format #f "~a, ~a" (procedure-name swap!) tag))
           (A (fill-A1! (make-A stype)))
           (B (fill-B1! (make-B stype))))
      (let ((Alist (array->list A))
            (Blist (array->list B)))
        (test-begin case-name)
        (swap! A B)
        ;; SRFI 64 order: (test-equal expected test-expr).
        (test-equal (list->typed-array stype 1 Alist) B)
        (test-equal (list->typed-array stype 1 Blist) A)
        (test-end case-name))))
  ;; Test each cblas-?swap! on every combination of compact, offset and strided
  ;; layouts for X and Y. The loops run only for their side effects, hence
  ;; for-each rather than map.
  (for-each
   (match-lambda
     ((stype swap!)
      (for-each
       (lambda (make-X)
         (for-each
          (lambda (make-Y)
            (test-swap stype swap! make-X make-Y))
          (list make-v-compact make-v-offset make-v-strided)))
       (list make-v-compact make-v-offset make-v-strided))))
   `((f64 ,cblas-dswap!)
     (f32 ,cblas-sswap!)
     (c64 ,cblas-zswap!)
     (c32 ,cblas-cswap!)))
  165. ; ---------------------------------
  166. ; sgemv dgemv cgemv zgemv
  167. ; ---------------------------------
  ;; Naive reference for gemv: Y := alpha*A*X + beta*Y, where A is a rank-2
  ;; array of dimensions (M N). Y is updated in place, one element at a time
  ;; (so typed arrays round at every step), and is returned.
  (define (ref-gemv! alpha A X beta Y)
    (match (array-dimensions A)
      ((M N)
       (let row-loop ((i 0))
         (when (< i M)
           (array-set! Y (* beta (array-ref Y i)) i)
           (let col-loop ((j 0))
             (when (< j N)
               (array-set! Y
                           (+ (array-ref Y i)
                              (* alpha (array-ref A i j) (array-ref X j)))
                           i)
               (col-loop (+ j 1))))
           (row-loop (+ i 1))))
       Y)))
  ;; Check GEMV! (one of cblas-?gemv!) against ref-gemv! for Y := 2*A*X + 3*Y.
  ;; A, X and Y have element type STYPE and are made by MAKE-A, MAKE-X, MAKE-Y.
  ;; Both computations run on copies of Y, and A and X must be left unchanged.
  (define (test-gemv stype gemv! make-A make-X make-Y)
    (let* ((tag (format #f "~a:~a:~a" (procedure-name make-A) (procedure-name make-X) (procedure-name make-Y)))
           (case-name (format #f "~a, ~a" (procedure-name gemv!) tag))
           (A (fill-A2! (make-A stype)))
           (X (fill-A1! (make-X stype)))
           (Y (fill-B1! (make-Y stype))))
      (let ((A1 (array-copy A))
            (X1 (array-copy X)))
        (test-begin case-name)
        (let ((Y1 (array-copy Y))
              (Y2 (array-copy Y)))
          (gemv! 2. A CblasNoTrans X 3. Y1) ; TODO Test other values of transA
          (ref-gemv! 2. A X 3. Y2)
          ;; SRFI 64 order: (test-equal expected test-expr). The reference
          ;; result and the saved copies are the expected values.
          (test-equal Y2 Y1)
          (test-equal A1 A)
          (test-equal X1 X)
          (test-end case-name)))))
  ;; Test each cblas-?gemv! over every combination of matrix layout for A and
  ;; vector layout for X and Y.
  (let ((matrix-makers (list make-M-c-order make-M-fortran-order make-M-offset make-M-strided))
        (vector-makers (list make-v-compact make-v-offset make-v-strided))
        (cases `((f64 ,cblas-dgemv!)
                 (f32 ,cblas-sgemv!)
                 (c64 ,cblas-zgemv!)
                 (c32 ,cblas-cgemv!))))
    (for-each
     (match-lambda
       ((stype gemv!)
        (for-each
         (match-lambda
           ((make-A make-X make-Y)
            (test-gemv stype gemv! make-A make-X make-Y)))
         (list-product matrix-makers vector-makers vector-makers))))
     cases))
  207. ; ---------------------------------
  208. ; TODO snrm2, dnrm2, cnrm2, znrm2
  209. ; ---------------------------------
  210. ; ---------------------------------
  211. ; TODO sasum, dasum, casum, zasum
  212. ; ---------------------------------
  213. ; ---------------------------------
  214. ; TODO sscal, dscal, cscal, zscal, csscal, zdscal
  215. ; ---------------------------------
  216. ; ---------------------------------
  217. ; TODO sger, dger, cgeru, cgerc, zgeru, zgerc
  218. ; ---------------------------------
  219. ; ---------------------------------
  220. ; isamax idamax icamax izamax: already tested in the i?amax section above
  221. ; ---------------------------------
  222. ; ---------------------------------
  223. ; sgemm dgemm cgemm zgemm
  224. ; ---------------------------------
  ;; Run GEMM! (one of cblas-?gemm!) on a copy of C and a naive reference on
  ;; another copy, then compare the two in the test group TAG. Both compute
  ;; alpha * op(A) op(B) + beta * C, where op transposes its argument when
  ;; (ffi cblas)'s private tr? says TRANSA/TRANSB request it. A and B are inputs
  ;; only, so they must come back unchanged.
  (define (test-gemm tag gemm! alpha A transA B transB beta C)
    ;; alpha * sum_k(A_{ik}*B_{kj}) + beta * C_{ij} -> C_{ij}
    (define (ref-gemm! alpha A transA B transB beta C)
      (let* ((A (if ((@@ (ffi cblas) tr?) transA) (transpose-array A 1 0) A))
             (B (if ((@@ (ffi cblas) tr?) transB) (transpose-array B 1 0) B))
             (M (first (array-dimensions C)))
             (N (second (array-dimensions C)))
             (K (first (array-dimensions B))))
        (do ((i 0 (+ i 1))) ((= i M))
          (do ((j 0 (+ j 1))) ((= j N))
            (array-set! C (* beta (array-ref C i j)) i j)
            (do ((k 0 (+ k 1))) ((= k K))
              (array-set! C (+ (array-ref C i j) (* alpha (array-ref A i k) (array-ref B k j))) i j))))))
    (let ((C1 (array-copy C))
          (C2 (array-copy C))
          (AA (array-copy A))
          (BB (array-copy B)))
      (gemm! alpha A transA B transB beta C1)
      (ref-gemm! alpha A transA B transB beta C2)
      ;; (test-approximate-array tag C1 C2 1e-15) ; TODO as a single test.
      (test-begin tag)
      ;; SRFI 64 order: (test-equal expected test-expr). The reference result
      ;; and the saved input copies are the expected values.
      (test-equal C2 C1)
      (test-equal AA A)
      (test-equal BB B)
      (test-end tag)))
  ;; For every element type, test cblas-?gemm! in three ways:
  ;;   - on non-square operands in C order;
  ;;   - on the same operands in transposed storage;
  ;;   - on every combination of matrix layouts for A, B, C and of transA/transB.
  (for-each
   (match-lambda
     ((stype gemm!)
      ;; Non-square operands, A: 4x3, B: 3x5, C: 4x5. Each of them serves once
      ;; as the output argument.
      (let ((A (fill-A2! (make-typed-array stype *unspecified* 4 3)))
            (B (fill-A2! (make-typed-array stype *unspecified* 3 5)))
            (C (fill-A2! (make-typed-array stype *unspecified* 4 5))))
        (test-gemm "gemm-1" gemm! 1. A CblasNoTrans B CblasNoTrans 1. C)
        (test-gemm "gemm-2" gemm! 1. A CblasTrans C CblasNoTrans 1. B)
        (test-gemm "gemm-3" gemm! 1. C CblasNoTrans B CblasTrans 1. A))
      ;; The same products with the operands stored transposed. These use STYPE
      ;; and GEMM! like the other cases, so every element type is covered rather
      ;; than repeating one f64 run per loop iteration.
      (let ((A (fill-A2! (transpose-array (make-typed-array stype *unspecified* 4 3) 1 0)))
            (B (fill-A2! (transpose-array (make-typed-array stype *unspecified* 3 5) 1 0)))
            (C (fill-A2! (transpose-array (make-typed-array stype *unspecified* 4 5) 1 0))))
        (test-gemm "gemm-4" gemm! 1. A CblasTrans B CblasTrans 1. (transpose-array C 1 0))
        (test-gemm "gemm-5" gemm! 1. A CblasNoTrans C CblasTrans 1. (transpose-array B 1 0))
        (test-gemm "gemm-6" gemm! 1. C CblasTrans B CblasNoTrans 1. (transpose-array A 1 0)))
      (for-each
       (match-lambda
         ((make-A make-B make-C transA transB)
          (test-gemm (format #f "gemm:~a:~a:~a:~a:~a:~a" stype (procedure-name make-A)
                             (procedure-name make-B) (procedure-name make-C)
                             transA transB)
                     gemm! 3. (fill-A2! (make-A stype)) transA
                     (fill-A2! (make-B stype)) transB
                     2. (fill-A2! (make-C stype)))))
       (list-product
        (list make-M-c-order make-M-fortran-order make-M-offset make-M-strided)
        (list make-M-c-order make-M-fortran-order make-M-offset make-M-strided)
        (list make-M-c-order make-M-fortran-order make-M-offset make-M-strided)
        ;; TODO Conj, etc. for c32/c64.
        (list CblasTrans CblasNoTrans)
        (list CblasTrans CblasNoTrans)))))
   `((f32 ,cblas-sgemm!)
     (f64 ,cblas-dgemm!)
     (c32 ,cblas-cgemm!)
     (c64 ,cblas-zgemm!)))
  ;; Read the failure count while the outermost "ffi-cblas" group is still open,
  ;; then close that group.
  (define error-count (test-runner-fail-count (test-runner-current)))
  (test-end "ffi-cblas")
  ;; Exit 0 only when nothing failed. The raw failure count is not used as the
  ;; status for two reasons: POSIX keeps only its low 8 bits (256 failures
  ;; would exit 0), and automake's test driver reads 77 as SKIP and 99 as a
  ;; hard error.
  (exit (if (zero? error-count) 0 1))