mm.c 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. // See LICENSE for license details.
  2. #include "common.h"
  3. #include <assert.h>
  4. #include <math.h>
  5. #include <stdint.h>
  6. #include <alloca.h>
  /*
   * Smaller of two values. Function-like macro: each argument may be
   * evaluated twice, so only pass side-effect-free expressions.
   * Guarded so that a MIN already supplied by an included header
   * (e.g. <sys/param.h>) is reused rather than redefined with a
   * different spelling, which would draw a redefinition diagnostic.
   */
  #ifndef MIN
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
  #endif
  8. static void mm_naive(size_t m, size_t n, size_t p,
  9. t* a, size_t lda, t* b, size_t ldb, t* c, size_t ldc)
  10. {
  11. for (size_t i = 0; i < m; i++)
  12. {
  13. for (size_t j = 0; j < n; j++)
  14. {
  15. t s0 = c[i*ldc+j], s1 = 0, s2 = 0, s3 = 0;
  16. for (size_t k = 0; k < p/4*4; k+=4)
  17. {
  18. s0 = fma(a[i*lda+k+0], b[(k+0)*ldb+j], s0);
  19. s1 = fma(a[i*lda+k+1], b[(k+1)*ldb+j], s1);
  20. s2 = fma(a[i*lda+k+2], b[(k+2)*ldb+j], s2);
  21. s3 = fma(a[i*lda+k+3], b[(k+3)*ldb+j], s3);
  22. }
  23. for (size_t k = p/4*4; k < p; k++)
  24. s0 = fma(a[i*lda+k], b[k*ldb+j], s0);
  25. c[i*ldc+j] = (s0 + s1) + (s2 + s3);
  26. }
  27. }
  28. }
  29. static inline void mm_rb(size_t m, size_t n, size_t p,
  30. t* a, size_t lda, t* b, size_t ldb, t* c, size_t ldc)
  31. {
  32. size_t mb = m/RBM*RBM, nb = n/RBN*RBN;
  33. for (size_t i = 0; i < mb; i += RBM)
  34. {
  35. for (size_t j = 0; j < nb; j += RBN)
  36. kloop(p, a+i*lda, lda, b+j, ldb, c+i*ldc+j, ldc);
  37. mm_naive(RBM, n - nb, p, a+i*lda, lda, b+nb, ldb, c+i*ldc+nb, ldc);
  38. }
  39. mm_naive(m - mb, n, p, a+mb*lda, lda, b, ldb, c+mb*ldc, ldc);
  40. }
  41. static inline void repack(t* a, size_t lda, const t* a0, size_t lda0, size_t m, size_t p)
  42. {
  43. for (size_t i = 0; i < m; i++)
  44. {
  45. for (size_t j = 0; j < p/8*8; j+=8)
  46. {
  47. t t0 = a0[i*lda0+j+0];
  48. t t1 = a0[i*lda0+j+1];
  49. t t2 = a0[i*lda0+j+2];
  50. t t3 = a0[i*lda0+j+3];
  51. t t4 = a0[i*lda0+j+4];
  52. t t5 = a0[i*lda0+j+5];
  53. t t6 = a0[i*lda0+j+6];
  54. t t7 = a0[i*lda0+j+7];
  55. a[i*lda+j+0] = t0;
  56. a[i*lda+j+1] = t1;
  57. a[i*lda+j+2] = t2;
  58. a[i*lda+j+3] = t3;
  59. a[i*lda+j+4] = t4;
  60. a[i*lda+j+5] = t5;
  61. a[i*lda+j+6] = t6;
  62. a[i*lda+j+7] = t7;
  63. }
  64. for (size_t j = p/8*8; j < p; j++)
  65. a[i*lda+j] = a0[i*lda0+j];
  66. }
  67. }
  68. static void mm_cb(size_t m, size_t n, size_t p,
  69. t* a, size_t lda, t* b, size_t ldb, t* c, size_t ldc)
  70. {
  71. size_t nmb = m/CBM, nnb = n/CBN, npb = p/CBK;
  72. size_t mb = nmb*CBM, nb = nnb*CBN, pb = npb*CBK;
  73. //t a1[mb*pb], b1[pb*nb], c1[mb*nb];
  74. t* a1 = (t*)alloca_aligned(sizeof(t)*mb*pb, 8192);
  75. t* b1 = (t*)alloca_aligned(sizeof(t)*pb*nb, 8192);
  76. t* c1 = (t*)alloca_aligned(sizeof(t)*mb*nb, 8192);
  77. for (size_t i = 0; i < mb; i += CBM)
  78. for (size_t j = 0; j < pb; j += CBK)
  79. repack(a1 + (npb*(i/CBM) + j/CBK)*(CBM*CBK), CBK, a + i*lda + j, lda, CBM, CBK);
  80. for (size_t i = 0; i < pb; i += CBK)
  81. for (size_t j = 0; j < nb; j += CBN)
  82. repack(b1 + (nnb*(i/CBK) + j/CBN)*(CBK*CBN), CBN, b + i*ldb + j, ldb, CBK, CBN);
  83. for (size_t i = 0; i < mb; i += CBM)
  84. for (size_t j = 0; j < nb; j += CBN)
  85. repack(c1 + (nnb*(i/CBM) + j/CBN)*(CBM*CBN), CBN, c + i*ldc + j, ldc, CBM, CBN);
  86. for (size_t i = 0; i < mb; i += CBM)
  87. {
  88. for (size_t j = 0; j < nb; j += CBN)
  89. {
  90. for (size_t k = 0; k < pb; k += CBK)
  91. {
  92. mm_rb(CBM, CBN, CBK,
  93. a1 + (npb*(i/CBM) + k/CBK)*(CBM*CBK), CBK,
  94. b1 + (nnb*(k/CBK) + j/CBN)*(CBK*CBN), CBN,
  95. c1 + (nnb*(i/CBM) + j/CBN)*(CBM*CBN), CBN);
  96. }
  97. if (pb < p)
  98. {
  99. mm_rb(CBM, CBN, p - pb,
  100. a + i*lda + pb, lda,
  101. b + pb*ldb + j, ldb,
  102. c1 + (nnb*(i/CBM) + j/CBN)*(CBM*CBN), CBN);
  103. }
  104. }
  105. if (nb < n)
  106. {
  107. for (size_t k = 0; k < p; k += CBK)
  108. {
  109. mm_rb(CBM, n - nb, MIN(p - k, CBK),
  110. a + i*lda + k, lda,
  111. b + k*ldb + nb, ldb,
  112. c + i*ldc + nb, ldc);
  113. }
  114. }
  115. }
  116. if (mb < m)
  117. {
  118. for (size_t j = 0; j < n; j += CBN)
  119. {
  120. for (size_t k = 0; k < p; k += CBK)
  121. {
  122. mm_rb(m - mb, MIN(n - j, CBN), MIN(p - k, CBK),
  123. a + mb*lda + k, lda,
  124. b + k*ldb + j, ldb,
  125. c + mb*ldc + j, ldc);
  126. }
  127. }
  128. }
  129. for (size_t i = 0; i < mb; i += CBM)
  130. for (size_t j = 0; j < nb; j += CBN)
  131. repack(c + i*ldc + j, ldc, c1 + (nnb*(i/CBM) + j/CBN)*(CBM*CBN), CBN, CBM, CBN);
  132. }
  133. void mm(size_t m, size_t n, size_t p,
  134. t* a, size_t lda, t* b, size_t ldb, t* c, size_t ldc)
  135. {
  136. if (__builtin_expect(m <= 2*CBM && n <= 2*CBN && p <= 2*CBK, 1))
  137. mm_rb(m, n, p, a, lda, b, ldb, c, ldc);
  138. else
  139. mm_cb(m, n, p, a, lda, b, ldb, c, ldc);
  140. }