bench-gemm.cc 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. // -*- mode: c++; coding: utf-8 -*-
  2. /// @file bench-gemm.cc
  3. /// @brief Benchmark for BLAS-3 type ops
  4. // (c) Daniel Llorens - 2016-2017
  5. // This library is free software; you can redistribute it and/or modify it under
  6. // the terms of the GNU Lesser General Public License as published by the Free
  7. // Software Foundation; either version 3 of the License, or (at your option) any
  8. // later version.
  9. // These operations aren't really part of the ET framework, just standalone functions.
  10. // Cf bench-gemv.cc for BLAS-2 type ops.
  11. // FIXME Benchmark w/o allocation.
  12. #include <iostream>
  13. #include <iomanip>
  14. #include "ra/test.hh"
  15. #include "ra/complex.hh"
  16. #include "ra/ra.hh"
  17. #include "ra/bench.hh"
  18. using std::cout, std::endl, std::setw, std::setprecision, ra::TestRecorder;
  19. using ra::Small, ra::View, ra::Unique, ra::dim_t;
  20. using real = double;
  21. // -------------------
  22. // variants of the defaults, should be slower if the default is well picked.
  23. // -------------------
  24. template <class A, class B, class C> inline void
  25. gemm_block_3(ra::View<A, 2> const & a, ra::View<B, 2> const & b, ra::View<C, 2> c)
  26. {
  27. dim_t const m = a.size(0);
  28. dim_t const p = a.size(1);
  29. dim_t const n = b.size(1);
  30. // terminal, using reduce_k, see below
  31. if (max(m, max(p, n))<=64) {
  32. for_each(ra::wrank<1, 1, 2>(ra::wrank<1, 0, 1>([](auto && c, auto && a, auto && b) { c += a*b; })),
  33. c, a, b);
  34. // split a's rows
  35. } else if (m>=max(p, n)) {
  36. gemm_block_3(a(ra::iota(m/2)), b, c(ra::iota(m/2)));
  37. gemm_block_3(a(ra::iota(m-m/2, m/2)), b, c(ra::iota(m-m/2, m/2)));
  38. // split b's columns
  39. } else if (n>=max(m, p)) {
  40. gemm_block_3(a, b(ra::all, ra::iota(n/2)), c(ra::all, ra::iota(n/2)));
  41. gemm_block_3(a, b(ra::all, ra::iota(n-n/2, n/2)), c(ra::all, ra::iota(n-n/2, n/2)));
  42. // split a's columns and b's rows
  43. } else {
  44. gemm_block_3(a(ra::all, ra::iota(p/2)), b(ra::iota(p/2)), c);
  45. gemm_block_3(a(ra::all, ra::iota(p-p/2, p/2)), b(ra::iota(p-p/2, p/2)), c);
  46. }
  47. }
  48. #if RA_USE_BLAS==1
  49. extern "C" {
  50. #include <cblas.h>
  51. }
  52. inline constexpr CBLAS_TRANSPOSE fliptr(CBLAS_TRANSPOSE t)
  53. {
  54. if (t==CblasTrans) {
  55. return CblasNoTrans;
  56. } else if (t==CblasNoTrans) {
  57. return CblasTrans;
  58. } else {
  59. assert(0 && "BLAS doesn't support this transpose");
  60. }
  61. }
  62. inline constexpr bool istr(CBLAS_TRANSPOSE t)
  63. {
  64. return (t==CblasTrans) || (t==CblasConjTrans);
  65. }
  66. template <class A> inline void
  67. lead_and_order(A const & a, int & ld, CBLAS_ORDER & order)
  68. {
  69. if (a.stride(1)==1) {
  70. order = CblasRowMajor;
  71. ld = a.stride(0);
  72. } else if (a.stride(0)==1) {
  73. order = CblasColMajor;
  74. ld = a.stride(1);
  75. } else {
  76. assert(0 && "not a BLAS-supported array");
  77. }
  78. }
  79. inline void
  80. gemm_blas_3(ra::View<double, 2> const & A, ra::View<double, 2> const & B, ra::View<double, 2> C)
  81. {
  82. CBLAS_TRANSPOSE ta = CblasNoTrans;
  83. CBLAS_TRANSPOSE tb = CblasNoTrans;
  84. int ldc, lda, ldb;
  85. CBLAS_ORDER orderc, ordera, orderb;
  86. lead_and_order(C, ldc, orderc);
  87. lead_and_order(A, lda, ordera);
  88. lead_and_order(B, ldb, orderb);
  89. int K = A.size(1-istr(ta));
  90. assert(K==B.size(istr(tb)) && "mismatched A/B");
  91. assert(C.size(0)==A.size(istr(ta)) && "mismatched C/A");
  92. assert(C.size(1)==B.size(1-istr(tb)) && "mismatched C/B");
  93. if (ordera!=orderc) {
  94. ta = fliptr(ta);
  95. }
  96. if (orderb!=orderc) {
  97. tb = fliptr(tb);
  98. }
  99. if (C.size()>0) {
  100. cblas_dgemm(orderc, ta, tb, C.size(0), C.size(1), K, 1., A.data(), lda, B.data(), ldb, 0, C.data(), ldc);
  101. }
  102. }
  103. inline auto
  104. gemm_blas(ra::View<double, 2> const & a, ra::View<double, 2> const & b)
  105. {
  106. ra::Big<decltype(a(0, 0)*b(0, 0)), 2> c({a.size(0), b.size(1)}, 0);
  107. gemm_blas_3(a, b, c);
  108. return c;
  109. }
  110. #endif // RA_USE_BLAS
  111. int main()
  112. {
  113. TestRecorder tr(std::cout);
  114. auto gemm_block = [&](auto const & a, auto const & b)
  115. {
  116. ra::Big<decltype(a(0, 0)*b(0, 0)), 2> c({a.size(0), b.size(1)}, 0);
  117. gemm_block_3(a, b, c);
  118. return c;
  119. };
  120. auto gemm_k = [&](auto const & a, auto const & b)
  121. {
  122. dim_t const M = a.size(0);
  123. dim_t const N = b.size(1);
  124. ra::Big<decltype(a(0, 0)*b(0, 0)), 2> c({M, N}, ra::none);
  125. for (dim_t i=0; i<M; ++i) {
  126. for (dim_t j=0; j<N; ++j) {
  127. c(i, j) = dot(a(i), b(ra::all, j));
  128. }
  129. }
  130. return c;
  131. };
  132. // See test/wrank.cc "outer product variants" for the logic.
  133. // TODO based on this, allow a Blitz++ like notation C(i, j) = sum(A(i, k)*B(k, j), k). Maybe using TensorIndex now that that works with ply_ravel.
  134. auto gemm_reduce_k = [&](auto const & a, auto const & b)
  135. {
  136. dim_t const M = a.size(0);
  137. dim_t const N = b.size(1);
  138. using T = decltype(a(0, 0)*b(0, 0));
  139. ra::Big<T, 2> c({M, N}, T());
  140. for_each(ra::wrank<1, 1, 2>(ra::wrank<1, 0, 1>([](auto && c, auto && a, auto && b) { c += a*b; })),
  141. c, a, b);
  142. return c;
  143. };
  144. #define DEFINE_GEMM_RESTRICT(NAME_K, NAME_IJ, RESTRICT) \
  145. auto NAME_K = [&](auto const & a, auto const & b) \
  146. { \
  147. dim_t const M = a.size(0); \
  148. dim_t const N = b.size(1); \
  149. dim_t const K = a.size(1); \
  150. using T = decltype(a(0, 0)*b(0, 0)); \
  151. ra::Big<T, 2> c({M, N}, T()); \
  152. T * RESTRICT cc = c.data(); \
  153. T const * RESTRICT aa = a.data(); \
  154. T const * RESTRICT bb = b.data(); \
  155. for (dim_t i=0; i<M; ++i) { \
  156. for (dim_t j=0; j<N; ++j) { \
  157. for (dim_t k=0; k<K; ++k) { \
  158. cc[i*N+j] += aa[i*K+k] * bb[k*N+j]; \
  159. } \
  160. } \
  161. } \
  162. return c; \
  163. }; \
  164. \
  165. auto NAME_IJ = [&](auto const & a, auto const & b) \
  166. { \
  167. dim_t const M = a.size(0); \
  168. dim_t const N = b.size(1); \
  169. dim_t const K = a.size(1); \
  170. using T = decltype(a(0, 0)*b(0, 0)); \
  171. ra::Big<T, 2> c({M, N}, T()); \
  172. T * RESTRICT cc = c.data(); \
  173. T const * RESTRICT aa = a.data(); \
  174. T const * RESTRICT bb = b.data(); \
  175. for (dim_t k=0; k<K; ++k) { \
  176. for (dim_t i=0; i<M; ++i) { \
  177. for (dim_t j=0; j<N; ++j) { \
  178. cc[i*N+j] += aa[i*K+k] * bb[k*N+j]; \
  179. } \
  180. } \
  181. } \
  182. return c; \
  183. };
  184. DEFINE_GEMM_RESTRICT(gemm_k_raw, gemm_ij_raw, /* */)
  185. DEFINE_GEMM_RESTRICT(gemm_k_raw_restrict, gemm_ij_raw_restrict, __restrict__)
  186. #undef DEFINE_GEMM_RESTRICT
  187. auto bench_all = [&](int k, int m, int p, int n, int reps)
  188. {
  189. auto bench = [&](auto && f, char const * tag)
  190. {
  191. ra::Big<real, 2> a({m, p}, ra::_0-ra::_1);
  192. ra::Big<real, 2> b({p, n}, ra::_1-2*ra::_0);
  193. ra::Big<real, 2> ref = gemm(a, b);
  194. ra::Big<real, 2> c;
  195. auto bv = Benchmark().repeats(reps).runs(3).run([&]() { c = f(a, b); });
  196. tr.info(std::setw(5), std::fixed, Benchmark::avg(bv)/(m*n*p)/1e-9, " ns [",
  197. Benchmark::stddev(bv)/(m*n*p)/1e-9 ,"] ", tag).test_eq(ref, c);
  198. };
  199. tr.section(m, " (", p, ") ", n, " times ", reps);
  200. // some variants are way too slow to check with larger arrays.
  201. if (k>2) {
  202. bench(gemm_k, "k");
  203. }
  204. if (k>1) {
  205. bench(gemm_k_raw, "k_raw");
  206. bench(gemm_k_raw_restrict, "k_raw_restrict");
  207. }
  208. if (k>0) {
  209. bench(gemm_reduce_k, "reduce_k");
  210. bench(gemm_ij_raw, "ij_raw");
  211. bench(gemm_ij_raw_restrict, "ij_raw_restrict");
  212. }
  213. bench(gemm_block, "block");
  214. #if RA_USE_BLAS==1
  215. bench(gemm_blas, "blas");
  216. #endif
  217. bench([&](auto const & a, auto const & b) { return gemm(a, b); }, "default");
  218. };
  219. bench_all(3, 10, 10, 10, 10000);
  220. bench_all(2, 100, 100, 100, 100);
  221. bench_all(2, 500, 400, 500, 1);
  222. bench_all(1, 10000, 10, 1000, 1);
  223. bench_all(1, 1000, 10, 10000, 1);
  224. bench_all(1, 100000, 10, 100, 1);
  225. bench_all(1, 100, 10, 100000, 1);
  226. return tr.summary();
  227. }