matmul_i2.c 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. /* Implementation of the MATMUL intrinsic
  2. Copyright (C) 2002-2015 Free Software Foundation, Inc.
  3. Contributed by Paul Brook <paul@nowt.org>
  4. This file is part of the GNU Fortran runtime library (libgfortran).
  5. Libgfortran is free software; you can redistribute it and/or
  6. modify it under the terms of the GNU General Public
  7. License as published by the Free Software Foundation; either
  8. version 3 of the License, or (at your option) any later version.
  9. Libgfortran is distributed in the hope that it will be useful,
  10. but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. GNU General Public License for more details.
  13. Under Section 7 of GPL version 3, you are granted additional
  14. permissions described in the GCC Runtime Library Exception, version
  15. 3.1, as published by the Free Software Foundation.
  16. You should have received a copy of the GNU General Public License and
  17. a copy of the GCC Runtime Library Exception along with this program;
  18. see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
  19. <http://www.gnu.org/licenses/>. */
  20. #include "libgfortran.h"
  21. #include <stdlib.h>
  22. #include <string.h>
  23. #include <assert.h>
  24. #if defined (HAVE_GFC_INTEGER_2)
  25. /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
  26. passed to us by the front-end, in which case we'll call it for large
  27. matrices. */
  28. typedef void (*blas_call)(const char *, const char *, const int *, const int *,
  29. const int *, const GFC_INTEGER_2 *, const GFC_INTEGER_2 *,
  30. const int *, const GFC_INTEGER_2 *, const int *,
  31. const GFC_INTEGER_2 *, GFC_INTEGER_2 *, const int *,
  32. int, int);
/* The order of loops is different in the case of plain matrix
   multiplication C=MATMUL(A,B), and in the frequent special case where
   the argument A is the temporary result of a TRANSPOSE intrinsic:
   C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
   looking at their strides.

   The equivalent Fortran pseudo-code is:

   DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
   IF (.NOT.IS_TRANSPOSED(A)) THEN
     C = 0
     DO J=1,N
       DO K=1,COUNT
         DO I=1,M
           C(I,J) = C(I,J)+A(I,K)*B(K,J)
         END DO
       END DO
     END DO
   ELSE
     DO J=1,N
       DO I=1,M
         S = 0
         DO K=1,COUNT
           S = S+A(I,K)*B(K,J)
         END DO
         C(I,J) = S
       END DO
     END DO
   ENDIF
*/
  55. /* If try_blas is set to a nonzero value, then the matmul function will
  56. see if there is a way to perform the matrix multiplication by a call
  57. to the BLAS gemm function. */
  58. extern void matmul_i2 (gfc_array_i2 * const restrict retarray,
  59. gfc_array_i2 * const restrict a, gfc_array_i2 * const restrict b, int try_blas,
  60. int blas_limit, blas_call gemm);
  61. export_proto(matmul_i2);
  62. void
  63. matmul_i2 (gfc_array_i2 * const restrict retarray,
  64. gfc_array_i2 * const restrict a, gfc_array_i2 * const restrict b, int try_blas,
  65. int blas_limit, blas_call gemm)
  66. {
  67. const GFC_INTEGER_2 * restrict abase;
  68. const GFC_INTEGER_2 * restrict bbase;
  69. GFC_INTEGER_2 * restrict dest;
  70. index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  71. index_type x, y, n, count, xcount, ycount;
  72. assert (GFC_DESCRIPTOR_RANK (a) == 2
  73. || GFC_DESCRIPTOR_RANK (b) == 2);
  74. /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
  75. Either A or B (but not both) can be rank 1:
  76. o One-dimensional argument A is implicitly treated as a row matrix
  77. dimensioned [1,count], so xcount=1.
  78. o One-dimensional argument B is implicitly treated as a column matrix
  79. dimensioned [count, 1], so ycount=1.
  80. */
  81. if (retarray->base_addr == NULL)
  82. {
  83. if (GFC_DESCRIPTOR_RANK (a) == 1)
  84. {
  85. GFC_DIMENSION_SET(retarray->dim[0], 0,
  86. GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
  87. }
  88. else if (GFC_DESCRIPTOR_RANK (b) == 1)
  89. {
  90. GFC_DIMENSION_SET(retarray->dim[0], 0,
  91. GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
  92. }
  93. else
  94. {
  95. GFC_DIMENSION_SET(retarray->dim[0], 0,
  96. GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
  97. GFC_DIMENSION_SET(retarray->dim[1], 0,
  98. GFC_DESCRIPTOR_EXTENT(b,1) - 1,
  99. GFC_DESCRIPTOR_EXTENT(retarray,0));
  100. }
  101. retarray->base_addr
  102. = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_2));
  103. retarray->offset = 0;
  104. }
  105. else if (unlikely (compile_options.bounds_check))
  106. {
  107. index_type ret_extent, arg_extent;
  108. if (GFC_DESCRIPTOR_RANK (a) == 1)
  109. {
  110. arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
  111. ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
  112. if (arg_extent != ret_extent)
  113. runtime_error ("Incorrect extent in return array in"
  114. " MATMUL intrinsic: is %ld, should be %ld",
  115. (long int) ret_extent, (long int) arg_extent);
  116. }
  117. else if (GFC_DESCRIPTOR_RANK (b) == 1)
  118. {
  119. arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
  120. ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
  121. if (arg_extent != ret_extent)
  122. runtime_error ("Incorrect extent in return array in"
  123. " MATMUL intrinsic: is %ld, should be %ld",
  124. (long int) ret_extent, (long int) arg_extent);
  125. }
  126. else
  127. {
  128. arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
  129. ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
  130. if (arg_extent != ret_extent)
  131. runtime_error ("Incorrect extent in return array in"
  132. " MATMUL intrinsic for dimension 1:"
  133. " is %ld, should be %ld",
  134. (long int) ret_extent, (long int) arg_extent);
  135. arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
  136. ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
  137. if (arg_extent != ret_extent)
  138. runtime_error ("Incorrect extent in return array in"
  139. " MATMUL intrinsic for dimension 2:"
  140. " is %ld, should be %ld",
  141. (long int) ret_extent, (long int) arg_extent);
  142. }
  143. }
  144. if (GFC_DESCRIPTOR_RANK (retarray) == 1)
  145. {
  146. /* One-dimensional result may be addressed in the code below
  147. either as a row or a column matrix. We want both cases to
  148. work. */
  149. rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
  150. }
  151. else
  152. {
  153. rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
  154. rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
  155. }
  156. if (GFC_DESCRIPTOR_RANK (a) == 1)
  157. {
  158. /* Treat it as a a row matrix A[1,count]. */
  159. axstride = GFC_DESCRIPTOR_STRIDE(a,0);
  160. aystride = 1;
  161. xcount = 1;
  162. count = GFC_DESCRIPTOR_EXTENT(a,0);
  163. }
  164. else
  165. {
  166. axstride = GFC_DESCRIPTOR_STRIDE(a,0);
  167. aystride = GFC_DESCRIPTOR_STRIDE(a,1);
  168. count = GFC_DESCRIPTOR_EXTENT(a,1);
  169. xcount = GFC_DESCRIPTOR_EXTENT(a,0);
  170. }
  171. if (count != GFC_DESCRIPTOR_EXTENT(b,0))
  172. {
  173. if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
  174. runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
  175. }
  176. if (GFC_DESCRIPTOR_RANK (b) == 1)
  177. {
  178. /* Treat it as a column matrix B[count,1] */
  179. bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
  180. /* bystride should never be used for 1-dimensional b.
  181. in case it is we want it to cause a segfault, rather than
  182. an incorrect result. */
  183. bystride = 0xDEADBEEF;
  184. ycount = 1;
  185. }
  186. else
  187. {
  188. bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
  189. bystride = GFC_DESCRIPTOR_STRIDE(b,1);
  190. ycount = GFC_DESCRIPTOR_EXTENT(b,1);
  191. }
  192. abase = a->base_addr;
  193. bbase = b->base_addr;
  194. dest = retarray->base_addr;
  195. /* Now that everything is set up, we're performing the multiplication
  196. itself. */
  197. #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
  198. if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
  199. && (bxstride == 1 || bystride == 1)
  200. && (((float) xcount) * ((float) ycount) * ((float) count)
  201. > POW3(blas_limit)))
  202. {
  203. const int m = xcount, n = ycount, k = count, ldc = rystride;
  204. const GFC_INTEGER_2 one = 1, zero = 0;
  205. const int lda = (axstride == 1) ? aystride : axstride,
  206. ldb = (bxstride == 1) ? bystride : bxstride;
  207. if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
  208. {
  209. assert (gemm != NULL);
  210. gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
  211. &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
  212. return;
  213. }
  214. }
  215. if (rxstride == 1 && axstride == 1 && bxstride == 1)
  216. {
  217. const GFC_INTEGER_2 * restrict bbase_y;
  218. GFC_INTEGER_2 * restrict dest_y;
  219. const GFC_INTEGER_2 * restrict abase_n;
  220. GFC_INTEGER_2 bbase_yn;
  221. if (rystride == xcount)
  222. memset (dest, 0, (sizeof (GFC_INTEGER_2) * xcount * ycount));
  223. else
  224. {
  225. for (y = 0; y < ycount; y++)
  226. for (x = 0; x < xcount; x++)
  227. dest[x + y*rystride] = (GFC_INTEGER_2)0;
  228. }
  229. for (y = 0; y < ycount; y++)
  230. {
  231. bbase_y = bbase + y*bystride;
  232. dest_y = dest + y*rystride;
  233. for (n = 0; n < count; n++)
  234. {
  235. abase_n = abase + n*aystride;
  236. bbase_yn = bbase_y[n];
  237. for (x = 0; x < xcount; x++)
  238. {
  239. dest_y[x] += abase_n[x] * bbase_yn;
  240. }
  241. }
  242. }
  243. }
  244. else if (rxstride == 1 && aystride == 1 && bxstride == 1)
  245. {
  246. if (GFC_DESCRIPTOR_RANK (a) != 1)
  247. {
  248. const GFC_INTEGER_2 *restrict abase_x;
  249. const GFC_INTEGER_2 *restrict bbase_y;
  250. GFC_INTEGER_2 *restrict dest_y;
  251. GFC_INTEGER_2 s;
  252. for (y = 0; y < ycount; y++)
  253. {
  254. bbase_y = &bbase[y*bystride];
  255. dest_y = &dest[y*rystride];
  256. for (x = 0; x < xcount; x++)
  257. {
  258. abase_x = &abase[x*axstride];
  259. s = (GFC_INTEGER_2) 0;
  260. for (n = 0; n < count; n++)
  261. s += abase_x[n] * bbase_y[n];
  262. dest_y[x] = s;
  263. }
  264. }
  265. }
  266. else
  267. {
  268. const GFC_INTEGER_2 *restrict bbase_y;
  269. GFC_INTEGER_2 s;
  270. for (y = 0; y < ycount; y++)
  271. {
  272. bbase_y = &bbase[y*bystride];
  273. s = (GFC_INTEGER_2) 0;
  274. for (n = 0; n < count; n++)
  275. s += abase[n*axstride] * bbase_y[n];
  276. dest[y*rystride] = s;
  277. }
  278. }
  279. }
  280. else if (axstride < aystride)
  281. {
  282. for (y = 0; y < ycount; y++)
  283. for (x = 0; x < xcount; x++)
  284. dest[x*rxstride + y*rystride] = (GFC_INTEGER_2)0;
  285. for (y = 0; y < ycount; y++)
  286. for (n = 0; n < count; n++)
  287. for (x = 0; x < xcount; x++)
  288. /* dest[x,y] += a[x,n] * b[n,y] */
  289. dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
  290. }
  291. else if (GFC_DESCRIPTOR_RANK (a) == 1)
  292. {
  293. const GFC_INTEGER_2 *restrict bbase_y;
  294. GFC_INTEGER_2 s;
  295. for (y = 0; y < ycount; y++)
  296. {
  297. bbase_y = &bbase[y*bystride];
  298. s = (GFC_INTEGER_2) 0;
  299. for (n = 0; n < count; n++)
  300. s += abase[n*axstride] * bbase_y[n*bxstride];
  301. dest[y*rxstride] = s;
  302. }
  303. }
  304. else
  305. {
  306. const GFC_INTEGER_2 *restrict abase_x;
  307. const GFC_INTEGER_2 *restrict bbase_y;
  308. GFC_INTEGER_2 *restrict dest_y;
  309. GFC_INTEGER_2 s;
  310. for (y = 0; y < ycount; y++)
  311. {
  312. bbase_y = &bbase[y*bystride];
  313. dest_y = &dest[y*rystride];
  314. for (x = 0; x < xcount; x++)
  315. {
  316. abase_x = &abase[x*axstride];
  317. s = (GFC_INTEGER_2) 0;
  318. for (n = 0; n < count; n++)
  319. s += abase_x[n*aystride] * bbase_y[n*bxstride];
  320. dest_y[x*rxstride] = s;
  321. }
  322. }
  323. }
  324. }
  325. #endif