bench-gemm.cc 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra/bench - BLAS-3 type ops.
  3. // (c) Daniel Llorens - 2016-2017
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. // These operations aren't really part of the ET framework, just standalone functions.
  9. // Cf bench-gemv.cc for BLAS-2 type ops.
  10. // FIXME Benchmark w/o allocation.
  11. #include <iostream>
  12. #include <iomanip>
  13. #include "ra/bench.hh"
  14. using std::cout, std::endl, std::setw, std::setprecision, ra::TestRecorder;
  15. using ra::Small, ra::ViewBig, ra::Unique, ra::dim_t, ra::all;
  16. using real = double;
  17. void
  18. gemm1(auto && a, auto && b, auto & c)
  19. {
  20. for_each(ra::wrank<1, 2, 1>(ra::wrank<0, 1, 1>([](auto && a, auto && b, auto & c) { ra::maybe_fma(a, b, c); })),
  21. RA_FWD(a), RA_FWD(b), RA_FWD(c));
  22. }
  23. void
  24. gemm2(auto && a, auto && b, auto & c)
  25. {
  26. dim_t K=a.len(1);
  27. for (int k=0; k<K; ++k) {
  28. c += from(std::multiplies<>(), a(all, k), b(k)); // FIXME fma
  29. }
  30. }
  31. void
  32. gemm3(auto && a, auto && b, auto & c)
  33. {
  34. dim_t K=a.len(1);
  35. for (int k=0; k<K; ++k) {
  36. for_each(ra::wrank<0, 1, 1>([](auto && a, auto && b, auto & c) { ra::maybe_fma(a, b, c); }), a(all, k), b(k), c);
  37. }
  38. }
  39. void
  40. gemm4(auto && a, auto && b, auto & c)
  41. {
  42. dim_t M=a.len(0), N=b.len(1);
  43. for (int i=0; i<M; ++i) {
  44. for (int j=0; j<N; ++j) {
  45. c(i, j) = dot(a(i), b(all, j));
  46. }
  47. }
  48. }
  49. // -------------------
  50. // variants of the defaults, should be slower if the default is well picked.
  51. // -------------------
  52. template <class A, class B, class C> inline void
  53. gemm_block(ra::ViewBig<A, 2> const & a, ra::ViewBig<B, 2> const & b, ra::ViewBig<C, 2> c)
  54. {
  55. dim_t const m = a.len(0);
  56. dim_t const p = a.len(1);
  57. dim_t const n = b.len(1);
  58. // terminal, using reduce_k, see below
  59. if (max(m, max(p, n))<=64) {
  60. gemm(a, b, c);
  61. // split a's rows
  62. } else if (m>=max(p, n)) {
  63. gemm_block(a(ra::iota(m/2)), b, c(ra::iota(m/2)));
  64. gemm_block(a(ra::iota(m-m/2, m/2)), b, c(ra::iota(m-m/2, m/2)));
  65. // split b's columns
  66. } else if (n>=max(m, p)) {
  67. gemm_block(a, b(all, ra::iota(n/2)), c(all, ra::iota(n/2)));
  68. gemm_block(a, b(all, ra::iota(n-n/2, n/2)), c(all, ra::iota(n-n/2, n/2)));
  69. // split a's columns and b's rows
  70. } else {
  71. gemm_block(a(all, ra::iota(p/2)), b(ra::iota(p/2)), c);
  72. gemm_block(a(all, ra::iota(p-p/2, p/2)), b(ra::iota(p-p/2, p/2)), c);
  73. }
  74. }
  75. #if RA_USE_BLAS==1
  76. extern "C" {
  77. #include <cblas.h>
  78. }
  79. constexpr CBLAS_TRANSPOSE
  80. fliptr(CBLAS_TRANSPOSE t)
  81. {
  82. if (t==CblasTrans) {
  83. return CblasNoTrans;
  84. } else if (t==CblasNoTrans) {
  85. return CblasTrans;
  86. } else {
  87. assert(0 && "BLAS doesn't support this transpose");
  88. abort();
  89. }
  90. }
  91. constexpr bool
  92. istr(CBLAS_TRANSPOSE t)
  93. {
  94. return (t==CblasTrans) || (t==CblasConjTrans);
  95. }
  96. template <class A> inline void
  97. lead_and_order(A const & a, int & ld, CBLAS_ORDER & order)
  98. {
  99. if (a.step(1)==1) {
  100. order = CblasRowMajor;
  101. ld = a.step(0);
  102. } else if (a.step(0)==1) {
  103. order = CblasColMajor;
  104. ld = a.step(1);
  105. } else {
  106. order = CblasRowMajor;
  107. ld = 0;
  108. assert(0 && "not a BLAS-supported array");
  109. }
  110. }
  111. inline void
  112. gemm_blas(ra::ViewBig<double, 2> const & A, ra::ViewBig<double, 2> const & B, ra::ViewBig<double, 2> C)
  113. {
  114. CBLAS_TRANSPOSE ta = CblasNoTrans;
  115. CBLAS_TRANSPOSE tb = CblasNoTrans;
  116. int ldc, lda, ldb;
  117. CBLAS_ORDER orderc, ordera, orderb;
  118. lead_and_order(C, ldc, orderc);
  119. lead_and_order(A, lda, ordera);
  120. lead_and_order(B, ldb, orderb);
  121. int K = A.len(1-istr(ta));
  122. assert(K==B.len(istr(tb)) && "mismatched A/B");
  123. assert(C.len(0)==A.len(istr(ta)) && "mismatched C/A");
  124. assert(C.len(1)==B.len(1-istr(tb)) && "mismatched C/B");
  125. if (ordera!=orderc) {
  126. ta = fliptr(ta);
  127. }
  128. if (orderb!=orderc) {
  129. tb = fliptr(tb);
  130. }
  131. if (C.size()>0) {
  132. cblas_dgemm(orderc, ta, tb, C.len(0), C.len(1), K, 1., A.data(), lda, B.data(), ldb, 0, C.data(), ldc);
  133. }
  134. }
  135. #endif // RA_USE_BLAS
  136. int main()
  137. {
  138. TestRecorder tr(std::cout);
  139. cout << "RA_DO_FMA is " << RA_DO_FMA << endl;
  140. auto gemm_k = [&](auto const & a, auto const & b, auto & c)
  141. {
  142. dim_t const M = a.len(0);
  143. dim_t const N = b.len(1);
  144. for (dim_t i=0; i<M; ++i) {
  145. for (dim_t j=0; j<N; ++j) {
  146. c(i, j) = dot(a(i), b(all, j));
  147. }
  148. }
  149. return c;
  150. };
  151. #define DEFINE_GEMM_RESTRICT(NAME_K, NAME_IJ, RESTRICT) \
  152. auto NAME_K = [&](auto const & a, auto const & b, auto & c) \
  153. { \
  154. dim_t const M = a.len(0); \
  155. dim_t const N = b.len(1); \
  156. dim_t const K = a.len(1); \
  157. auto * RESTRICT cc = c.data(); \
  158. auto const * RESTRICT aa = a.data(); \
  159. auto const * RESTRICT bb = b.data(); \
  160. for (dim_t i=0; i<M; ++i) { \
  161. for (dim_t j=0; j<N; ++j) { \
  162. for (dim_t k=0; k<K; ++k) { \
  163. cc[i*N+j] += aa[i*K+k] * bb[k*N+j]; \
  164. } \
  165. } \
  166. } \
  167. }; \
  168. auto NAME_IJ = [&](auto const & a, auto const & b, auto & c) \
  169. { \
  170. dim_t const M = a.len(0); \
  171. dim_t const N = b.len(1); \
  172. dim_t const K = a.len(1); \
  173. auto * RESTRICT cc = c.data(); \
  174. auto const * RESTRICT aa = a.data(); \
  175. auto const * RESTRICT bb = b.data(); \
  176. for (dim_t k=0; k<K; ++k) { \
  177. for (dim_t i=0; i<M; ++i) { \
  178. for (dim_t j=0; j<N; ++j) { \
  179. cc[i*N+j] += aa[i*K+k] * bb[k*N+j]; \
  180. } \
  181. } \
  182. } \
  183. };
  184. DEFINE_GEMM_RESTRICT(gemm_k_raw, gemm_ij_raw, /* */)
  185. DEFINE_GEMM_RESTRICT(gemm_k_raw_restrict, gemm_ij_raw_restrict, __restrict__)
  186. #undef DEFINE_GEMM_RESTRICT
  187. auto bench_all = [&](int k, int m, int p, int n, int reps)
  188. {
  189. auto bench = [&](auto && f, char const * tag)
  190. {
  191. ra::Big<real, 2> a({m, p}, ra::_0-ra::_1);
  192. ra::Big<real, 2> b({p, n}, ra::_1-2*ra::_0);
  193. ra::Big<real, 2> ref = gemm(a, b);
  194. ra::Big<real, 2> c({m, n}, 0.);
  195. auto bv = Benchmark().repeats(reps).runs(3).run([&]() { f(a, b, c); });
  196. tr.info(std::setw(5), std::fixed, Benchmark::avg(bv)/(m*n*p)/1e-9, " ns [",
  197. Benchmark::stddev(bv)/(m*n*p)/1e-9 ,"] ", tag).test_eq(ref, c);
  198. };
  199. tr.section(m, " (", p, ") ", n, " times ", reps);
  200. #define ZEROFIRST(GEMM) [&](auto const & a, auto const & b, auto & c) { c = 0; GEMM(a, b, c); }
  201. #define NOTZEROFIRST(GEMM) [&](auto const & a, auto const & b, auto & c) { GEMM(a, b, c); }
  202. // some variants are way too slow to check with larger arrays.
  203. if (k>2) {
  204. bench(NOTZEROFIRST(gemm_k), "k");
  205. }
  206. if (k>1) {
  207. bench(ZEROFIRST(gemm_k_raw), "k_raw");
  208. bench(ZEROFIRST(gemm_k_raw_restrict), "k_raw_restrict");
  209. }
  210. if (k>0) {
  211. bench(ZEROFIRST(gemm_ij_raw), "ij_raw");
  212. bench(ZEROFIRST(gemm_ij_raw_restrict), "ij_raw_restrict");
  213. }
  214. bench(ZEROFIRST(gemm_block), "block");
  215. bench(ZEROFIRST(gemm1), "gemm1");
  216. bench(ZEROFIRST(gemm2), "gemm2");
  217. bench(ZEROFIRST(gemm3), "gemm3");
  218. bench(ZEROFIRST(gemm4), "gemm4");
  219. #if RA_USE_BLAS==1
  220. bench(ZEROFIRST(gemm_blas), "blas");
  221. #endif
  222. bench(ZEROFIRST(gemm), "default");
  223. };
  224. bench_all(3, 4, 4, 4, 100);
  225. bench_all(3, 10, 10, 10, 10);
  226. bench_all(2, 100, 100, 100, 10);
  227. bench_all(2, 500, 400, 500, 1);
  228. bench_all(1, 10000, 10, 1000, 1);
  229. bench_all(1, 1000, 10, 10000, 1);
  230. bench_all(1, 100000, 10, 100, 1);
  231. bench_all(1, 100, 10, 100000, 1);
  232. return tr.summary();
  233. }