matmul.m4 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. `/* Implementation of the MATMUL intrinsic
  2. Copyright (C) 2002-2015 Free Software Foundation, Inc.
  3. Contributed by Paul Brook <paul@nowt.org>
  4. This file is part of the GNU Fortran runtime library (libgfortran).
  5. Libgfortran is free software; you can redistribute it and/or
  6. modify it under the terms of the GNU General Public
  7. License as published by the Free Software Foundation; either
  8. version 3 of the License, or (at your option) any later version.
  9. Libgfortran is distributed in the hope that it will be useful,
  10. but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. GNU General Public License for more details.
  13. Under Section 7 of GPL version 3, you are granted additional
  14. permissions described in the GCC Runtime Library Exception, version
  15. 3.1, as published by the Free Software Foundation.
  16. You should have received a copy of the GNU General Public License and
  17. a copy of the GCC Runtime Library Exception along with this program;
  18. see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
  19. <http://www.gnu.org/licenses/>. */
  20. #include "libgfortran.h"
  21. #include <stdlib.h>
  22. #include <string.h>
  23. #include <assert.h>'
  24. include(iparm.m4)dnl
  25. `#if defined (HAVE_'rtype_name`)
  26. /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
  27. passed to us by the front-end, in which case we''`ll call it for large
  28. matrices. Argument order as used by the call in this file: transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, then two trailing ints (passed as 1, 1; presumably the lengths of the two one-character flag strings -- TODO confirm). */
  29. typedef void (*blas_call)(const char *, const char *, const int *, const int *,
  30. const int *, const 'rtype_name` *, const 'rtype_name` *,
  31. const int *, const 'rtype_name` *, const int *,
  32. const 'rtype_name` *, 'rtype_name` *, const int *,
  33. int, int);
  34. /* The order of loops is different in the case of plain matrix
  35. multiplication C=MATMUL(A,B), and in the frequent special case where
  36. the argument A is the temporary result of a TRANSPOSE intrinsic:
  37. C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
  38. looking at their strides.
  39. The equivalent Fortran pseudo-code is:
  40. DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
  41. IF (.NOT.IS_TRANSPOSED(A)) THEN
  42. C = 0
  43. DO J=1,N
  44. DO K=1,COUNT
  45. DO I=1,M
  46. C(I,J) = C(I,J)+A(I,K)*B(K,J)
  47. ELSE
  48. DO J=1,N
  49. DO I=1,M
  50. S = 0
  51. DO K=1,COUNT
  52. S = S+A(I,K)*B(K,J)
  53. C(I,J) = S
  54. ENDIF
  55. */
  56. /* Compute retarray = MATMUL (a, b). If retarray->base_addr is NULL, the bounds of retarray are derived from a and b and its storage is allocated here.
  57. If try_blas is set to a nonzero value, then the matmul function will see if there is a way to perform the matrix multiplication by a call
  58. to the BLAS gemm function; that requires suitable unit strides and xcount*ycount*count greater than blas_limit cubed, otherwise the C loops below are used. */
  59. extern void matmul_'rtype_code` ('rtype` * const restrict retarray,
  60. 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
  61. int blas_limit, blas_call gemm);
  62. export_proto(matmul_'rtype_code`);
  63. void
  64. matmul_'rtype_code` ('rtype` * const restrict retarray,
  65. 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
  66. int blas_limit, blas_call gemm)
  67. {
  68. const 'rtype_name` * restrict abase;
  69. const 'rtype_name` * restrict bbase;
  70. 'rtype_name` * restrict dest;
  71. index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  72. index_type x, y, n, count, xcount, ycount;
  73. assert (GFC_DESCRIPTOR_RANK (a) == 2
  74. || GFC_DESCRIPTOR_RANK (b) == 2); /* At least one operand must have rank 2. */
  75. /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
  76. Either A or B (but not both) can be rank 1:
  77. o One-dimensional argument A is implicitly treated as a row matrix
  78. dimensioned [1,count], so xcount=1.
  79. o One-dimensional argument B is implicitly treated as a column matrix
  80. dimensioned [count, 1], so ycount=1.
  81. */
  82. if (retarray->base_addr == NULL) /* Result has no storage yet: set its bounds and allocate it. */
  83. {
  84. if (GFC_DESCRIPTOR_RANK (a) == 1) /* Vector times matrix: result extent is EXTENT(b,1). */
  85. {
  86. GFC_DIMENSION_SET(retarray->dim[0], 0,
  87. GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
  88. }
  89. else if (GFC_DESCRIPTOR_RANK (b) == 1) /* Matrix times vector: result extent is EXTENT(a,0). */
  90. {
  91. GFC_DIMENSION_SET(retarray->dim[0], 0,
  92. GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
  93. }
  94. else /* Matrix times matrix: result is EXTENT(a,0) by EXTENT(b,1). */
  95. {
  96. GFC_DIMENSION_SET(retarray->dim[0], 0,
  97. GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
  98. GFC_DIMENSION_SET(retarray->dim[1], 0,
  99. GFC_DESCRIPTOR_EXTENT(b,1) - 1,
  100. GFC_DESCRIPTOR_EXTENT(retarray,0));
  101. }
  102. retarray->base_addr
  103. = xmallocarray (size0 ((array_t *) retarray), sizeof ('rtype_name`));
  104. retarray->offset = 0;
  105. }
  106. else if (unlikely (compile_options.bounds_check)) /* Result already has storage: optionally verify its extents. */
  107. {
  108. index_type ret_extent, arg_extent;
  109. if (GFC_DESCRIPTOR_RANK (a) == 1)
  110. {
  111. arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
  112. ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
  113. if (arg_extent != ret_extent)
  114. runtime_error ("Incorrect extent in return array in"
  115. " MATMUL intrinsic: is %ld, should be %ld",
  116. (long int) ret_extent, (long int) arg_extent);
  117. }
  118. else if (GFC_DESCRIPTOR_RANK (b) == 1)
  119. {
  120. arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
  121. ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
  122. if (arg_extent != ret_extent)
  123. runtime_error ("Incorrect extent in return array in"
  124. " MATMUL intrinsic: is %ld, should be %ld",
  125. (long int) ret_extent, (long int) arg_extent);
  126. }
  127. else
  128. {
  129. arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
  130. ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
  131. if (arg_extent != ret_extent)
  132. runtime_error ("Incorrect extent in return array in"
  133. " MATMUL intrinsic for dimension 1:"
  134. " is %ld, should be %ld",
  135. (long int) ret_extent, (long int) arg_extent);
  136. arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
  137. ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
  138. if (arg_extent != ret_extent)
  139. runtime_error ("Incorrect extent in return array in"
  140. " MATMUL intrinsic for dimension 2:"
  141. " is %ld, should be %ld",
  142. (long int) ret_extent, (long int) arg_extent);
  143. }
  144. }
  145. '
  146. sinclude(`matmul_asm_'rtype_code`.m4')dnl
  147. `
  148. if (GFC_DESCRIPTOR_RANK (retarray) == 1)
  149. {
  150. /* One-dimensional result may be addressed in the code below
  151. either as a row or a column matrix. We want both cases to
  152. work. */
  153. rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
  154. }
  155. else
  156. {
  157. rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
  158. rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
  159. }
  160. if (GFC_DESCRIPTOR_RANK (a) == 1)
  161. {
  162. /* Treat it as a row matrix A[1,count]. */
  163. axstride = GFC_DESCRIPTOR_STRIDE(a,0);
  164. aystride = 1;
  165. xcount = 1;
  166. count = GFC_DESCRIPTOR_EXTENT(a,0);
  167. }
  168. else
  169. {
  170. axstride = GFC_DESCRIPTOR_STRIDE(a,0);
  171. aystride = GFC_DESCRIPTOR_STRIDE(a,1);
  172. count = GFC_DESCRIPTOR_EXTENT(a,1);
  173. xcount = GFC_DESCRIPTOR_EXTENT(a,0);
  174. }
  175. if (count != GFC_DESCRIPTOR_EXTENT(b,0)) /* Inner extents must agree; a mismatch is tolerated only when neither is positive. */
  176. {
  177. if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
  178. runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
  179. }
  180. if (GFC_DESCRIPTOR_RANK (b) == 1)
  181. {
  182. /* Treat it as a column matrix B[count,1] */
  183. bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
  184. /* bystride should never be used for 1-dimensional b.
  185. In case it is, we want it to cause a segfault, rather than
  186. an incorrect result. */
  187. bystride = 0xDEADBEEF;
  188. ycount = 1;
  189. }
  190. else
  191. {
  192. bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
  193. bystride = GFC_DESCRIPTOR_STRIDE(b,1);
  194. ycount = GFC_DESCRIPTOR_EXTENT(b,1);
  195. }
  196. abase = a->base_addr;
  197. bbase = b->base_addr;
  198. dest = retarray->base_addr;
  199. /* Now that everything is set up, we''`re performing the multiplication
  200. itself. */
  201. #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
  202. if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
  203. && (bxstride == 1 || bystride == 1)
  204. && (((float) xcount) * ((float) ycount) * ((float) count)
  205. > POW3(blas_limit))) /* The size product and the cubed limit are compared in float arithmetic. */
  206. {
  207. const int m = xcount, n = ycount, k = count, ldc = rystride;
  208. const 'rtype_name` one = 1, zero = 0;
  209. const int lda = (axstride == 1) ? aystride : axstride,
  210. ldb = (bxstride == 1) ? bystride : bxstride;
  211. if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1) /* Otherwise gemm is skipped and control falls through to the loops below. */
  212. {
  213. assert (gemm != NULL);
  214. gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
  215. &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1); /* An operand whose first stride is not 1 is flagged "T". */
  216. return;
  217. }
  218. }
  219. if (rxstride == 1 && axstride == 1 && bxstride == 1) /* All first strides are 1: zero dest, then accumulate into it column by column. */
  220. {
  221. const 'rtype_name` * restrict bbase_y;
  222. 'rtype_name` * restrict dest_y;
  223. const 'rtype_name` * restrict abase_n;
  224. 'rtype_name` bbase_yn;
  225. if (rystride == xcount) /* Result columns are adjacent in memory: clear everything with one memset. */
  226. memset (dest, 0, (sizeof ('rtype_name`) * xcount * ycount));
  227. else
  228. {
  229. for (y = 0; y < ycount; y++)
  230. for (x = 0; x < xcount; x++)
  231. dest[x + y*rystride] = ('rtype_name`)0;
  232. }
  233. for (y = 0; y < ycount; y++)
  234. {
  235. bbase_y = bbase + y*bystride;
  236. dest_y = dest + y*rystride;
  237. for (n = 0; n < count; n++)
  238. {
  239. abase_n = abase + n*aystride;
  240. bbase_yn = bbase_y[n];
  241. for (x = 0; x < xcount; x++)
  242. {
  243. dest_y[x] += abase_n[x] * bbase_yn;
  244. }
  245. }
  246. }
  247. }
  248. else if (rxstride == 1 && aystride == 1 && bxstride == 1) /* a has unit stride along its second dimension: compute each element as a dot product. */
  249. {
  250. if (GFC_DESCRIPTOR_RANK (a) != 1)
  251. {
  252. const 'rtype_name` *restrict abase_x;
  253. const 'rtype_name` *restrict bbase_y;
  254. 'rtype_name` *restrict dest_y;
  255. 'rtype_name` s;
  256. for (y = 0; y < ycount; y++)
  257. {
  258. bbase_y = &bbase[y*bystride];
  259. dest_y = &dest[y*rystride];
  260. for (x = 0; x < xcount; x++)
  261. {
  262. abase_x = &abase[x*axstride];
  263. s = ('rtype_name`) 0;
  264. for (n = 0; n < count; n++)
  265. s += abase_x[n] * bbase_y[n];
  266. dest_y[x] = s;
  267. }
  268. }
  269. }
  270. else /* a has rank 1: one dot product per column of b. */
  271. {
  272. const 'rtype_name` *restrict bbase_y;
  273. 'rtype_name` s;
  274. for (y = 0; y < ycount; y++)
  275. {
  276. bbase_y = &bbase[y*bystride];
  277. s = ('rtype_name`) 0;
  278. for (n = 0; n < count; n++)
  279. s += abase[n*axstride] * bbase_y[n];
  280. dest[y*rystride] = s;
  281. }
  282. }
  283. }
  284. else if (axstride < aystride) /* General strides with axstride < aystride: zero dest, then accumulate. */
  285. {
  286. for (y = 0; y < ycount; y++)
  287. for (x = 0; x < xcount; x++)
  288. dest[x*rxstride + y*rystride] = ('rtype_name`)0;
  289. for (y = 0; y < ycount; y++)
  290. for (n = 0; n < count; n++)
  291. for (x = 0; x < xcount; x++)
  292. /* dest[x,y] += a[x,n] * b[n,y] */
  293. dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
  294. }
  295. else if (GFC_DESCRIPTOR_RANK (a) == 1) /* General strides, rank-1 a: one dot product per column of b. */
  296. {
  297. const 'rtype_name` *restrict bbase_y;
  298. 'rtype_name` s;
  299. for (y = 0; y < ycount; y++)
  300. {
  301. bbase_y = &bbase[y*bystride];
  302. s = ('rtype_name`) 0;
  303. for (n = 0; n < count; n++)
  304. s += abase[n*axstride] * bbase_y[n*bxstride];
  305. dest[y*rxstride] = s;
  306. }
  307. }
  308. else /* Remaining general-stride case: one dot product per result element. */
  309. {
  310. const 'rtype_name` *restrict abase_x;
  311. const 'rtype_name` *restrict bbase_y;
  312. 'rtype_name` *restrict dest_y;
  313. 'rtype_name` s;
  314. for (y = 0; y < ycount; y++)
  315. {
  316. bbase_y = &bbase[y*bystride];
  317. dest_y = &dest[y*rystride];
  318. for (x = 0; x < xcount; x++)
  319. {
  320. abase_x = &abase[x*axstride];
  321. s = ('rtype_name`) 0;
  322. for (n = 0; n < count; n++)
  323. s += abase_x[n*aystride] * bbase_y[n*bxstride];
  324. dest_y[x*rxstride] = s;
  325. }
  326. }
  327. }
  328. }
  329. #endif'