bench-gemm.cc 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra/bench - BLAS-3 type ops.
  3. // (c) Daniel Llorens - 2016-2017
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. // These operations aren't really part of the ET framework, just standalone functions.
  9. // Cf bench-gemv.cc for BLAS-2 type ops.
  10. // FIXME Bench w/o allocation.
  11. // FIXME Bench offloading, e.g. RA_USE_BLAS=1 GOMP_DEBUG=0 CXXFLAGS="-O3 -fopenmp" LINKFLAGS="-fopenmp" scons -j6 -k bench/bench-gemm.test
  #include <cassert>
  #include <iostream>
  #include <iomanip>
  #include "ra/test.hh"
  #include <omp.h>
  using std::cout, std::endl, std::setw, std::setprecision, ra::TestRecorder, ra::Benchmark;
  using ra::Small, ra::ViewBig, ra::Unique, ra::dim_t, ra::all;
  using real = double;
  19. void
  20. gemm1(auto && a, auto && b, auto & c)
  21. {
  22. for_each(ra::wrank<1, 2, 1>(ra::wrank<0, 1, 1>([](auto && a, auto && b, auto & c) { ra::maybe_fma(a, b, c); })),
  23. RA_FWD(a), RA_FWD(b), RA_FWD(c));
  24. }
  25. void
  26. gemm2(auto && a, auto && b, auto & c)
  27. {
  28. dim_t K=a.len(1);
  29. for (int k=0; k<K; ++k) {
  30. c += from(std::multiplies<>(), a(all, k), b(k)); // FIXME fma
  31. }
  32. }
  33. void
  34. gemm3(auto && a, auto && b, auto & c)
  35. {
  36. dim_t K=a.len(1);
  37. for (int k=0; k<K; ++k) {
  38. for_each(ra::wrank<0, 1, 1>([](auto && a, auto && b, auto & c) { ra::maybe_fma(a, b, c); }), a(all, k), b(k), c);
  39. }
  40. }
  41. void
  42. gemm4(auto && a, auto && b, auto & c)
  43. {
  44. dim_t M=a.len(0), N=b.len(1);
  45. for (int i=0; i<M; ++i) {
  46. for (int j=0; j<N; ++j) {
  47. c(i, j) = dot(a(i), b(all, j));
  48. }
  49. }
  50. }
  51. // -------------------
  52. // variants of the defaults, should be slower if the default is well picked.
  53. // -------------------
  54. template <class A, class B, class C>
  55. inline void
  56. gemm_block(ra::ViewBig<A, 2> const & a, ra::ViewBig<B, 2> const & b, ra::ViewBig<C, 2> c)
  57. {
  58. dim_t const m = a.len(0);
  59. dim_t const p = a.len(1);
  60. dim_t const n = b.len(1);
  61. // terminal, using reduce_k, see below
  62. if (max(m, max(p, n))<=64) {
  63. gemm(a, b, c);
  64. // split a's rows
  65. } else if (m>=max(p, n)) {
  66. gemm_block(a(ra::iota(m/2)), b, c(ra::iota(m/2)));
  67. gemm_block(a(ra::iota(m-m/2, m/2)), b, c(ra::iota(m-m/2, m/2)));
  68. // split b's columns
  69. } else if (n>=max(m, p)) {
  70. gemm_block(a, b(all, ra::iota(n/2)), c(all, ra::iota(n/2)));
  71. gemm_block(a, b(all, ra::iota(n-n/2, n/2)), c(all, ra::iota(n-n/2, n/2)));
  72. // split a's columns and b's rows
  73. } else {
  74. gemm_block(a(all, ra::iota(p/2)), b(ra::iota(p/2)), c);
  75. gemm_block(a(all, ra::iota(p-p/2, p/2)), b(ra::iota(p-p/2, p/2)), c);
  76. }
  77. }
  78. template <class PTR, class CPTR>
  79. void
  80. gemm_k_raw(auto const & a, auto const & b, auto & c)
  81. {
  82. dim_t const M = a.len(0);
  83. dim_t const N = b.len(1);
  84. dim_t const K = a.len(1);
  85. PTR cc = c.data();
  86. CPTR aa = a.data();
  87. CPTR bb = b.data();
  88. for (dim_t k=0; k<K; ++k) {
  89. for (dim_t i=0; i<M; ++i) {
  90. for (dim_t j=0; j<N; ++j) {
  91. cc[i*N+j] += aa[i*K+k] * bb[k*N+j];
  92. }
  93. }
  94. }
  95. }
  96. template <class PTR, class CPTR>
  97. void
  98. gemm_ij_raw(auto const & a, auto const & b, auto & c)
  99. {
  100. dim_t const M = a.len(0);
  101. dim_t const N = b.len(1);
  102. dim_t const K = a.len(1);
  103. PTR cc = c.data();
  104. CPTR aa = a.data();
  105. CPTR bb = b.data();
  106. for (dim_t i=0; i<M; ++i) {
  107. for (dim_t j=0; j<N; ++j) {
  108. for (dim_t k=0; k<K; ++k) {
  109. cc[i*N+j] += aa[i*K+k] * bb[k*N+j];
  110. }
  111. }
  112. }
  113. }
  114. #if RA_USE_BLAS==1
  115. extern "C" {
  116. #include <cblas.h>
  117. }
  118. constexpr CBLAS_TRANSPOSE
  119. fliptr(CBLAS_TRANSPOSE t)
  120. {
  121. if (t==CblasTrans) {
  122. return CblasNoTrans;
  123. } else if (t==CblasNoTrans) {
  124. return CblasTrans;
  125. } else {
  126. assert(0 && "BLAS doesn't support this transpose");
  127. abort();
  128. }
  129. }
  130. constexpr bool
  131. istr(CBLAS_TRANSPOSE t)
  132. {
  133. return (t==CblasTrans) || (t==CblasConjTrans);
  134. }
  135. template <class A> inline void
  136. lead_and_order(A const & a, int & ld, CBLAS_ORDER & order)
  137. {
  138. if (a.step(1)==1) {
  139. order = CblasRowMajor;
  140. ld = a.step(0);
  141. } else if (a.step(0)==1) {
  142. order = CblasColMajor;
  143. ld = a.step(1);
  144. } else {
  145. order = CblasRowMajor;
  146. ld = 0;
  147. assert(0 && "not a BLAS-supported array");
  148. }
  149. }
  150. template <class T>
  151. void
  152. gemm_blas(ra::ViewBig<T, 2> const & A, ra::ViewBig<T, 2> const & B, ra::ViewBig<T, 2> C)
  153. {
  154. CBLAS_TRANSPOSE ta = CblasNoTrans;
  155. CBLAS_TRANSPOSE tb = CblasNoTrans;
  156. int ldc, lda, ldb;
  157. CBLAS_ORDER orderc, ordera, orderb;
  158. lead_and_order(C, ldc, orderc);
  159. lead_and_order(A, lda, ordera);
  160. lead_and_order(B, ldb, orderb);
  161. int K = A.len(1-istr(ta));
  162. assert(K==B.len(istr(tb)) && "mismatched A/B");
  163. assert(C.len(0)==A.len(istr(ta)) && "mismatched C/A");
  164. assert(C.len(1)==B.len(1-istr(tb)) && "mismatched C/B");
  165. if (ordera!=orderc) {
  166. ta = fliptr(ta);
  167. }
  168. if (orderb!=orderc) {
  169. tb = fliptr(tb);
  170. }
  171. if (C.size()>0) {
  172. if constexpr (std::is_same_v<T, double>) {
  173. cblas_dgemm(orderc, ta, tb, C.len(0), C.len(1), K, real(1.), A.data(), lda, B.data(), ldb, 0, C.data(), ldc);
  174. } else if constexpr (std::is_same_v<T, float>) {
  175. cblas_sgemm(orderc, ta, tb, C.len(0), C.len(1), K, real(1.), A.data(), lda, B.data(), ldb, 0, C.data(), ldc);
  176. } else {
  177. abort();
  178. }
  179. }
  180. }
  181. #endif // RA_USE_BLAS
  182. int main()
  183. {
  184. TestRecorder tr(std::cout);
  185. cout << "RA_DO_FMA is " << RA_DO_FMA << endl;
  186. auto gemm_k = [&](auto const & a, auto const & b, auto & c)
  187. {
  188. dim_t const M = a.len(0);
  189. dim_t const N = b.len(1);
  190. for (dim_t i=0; i<M; ++i) {
  191. for (dim_t j=0; j<N; ++j) {
  192. c(i, j) = dot(a(i), b(all, j));
  193. }
  194. }
  195. return c;
  196. };
  197. auto bench_all = [&](int k, int m, int p, int n, int reps)
  198. {
  199. auto bench = [&](auto && f, char const * tag, real rerr=0)
  200. {
  201. ra::Big<real, 2> a({m, p}, ra::_0-ra::_1);
  202. ra::Big<real, 2> b({p, n}, ra::_1-2*ra::_0);
  203. ra::Big<real, 2> ref = gemm(a, b);
  204. ra::Big<real, 2> c({m, n}, 0.);
  205. auto bv = Benchmark().repeats(reps).runs(3).run([&]() { f(a, b, c); });
  206. tr.info(std::setw(5), std::fixed, Benchmark::avg(bv)/(m*n*p)/1e-9, " ns [",
  207. Benchmark::stddev(bv)/(m*n*p)/1e-9 ,"] ", tag).test_rel(ref, c, rerr);
  208. };
  209. tr.section(m, " (", p, ") ", n, " times ", reps);
  210. #define ZEROFIRST(GEMM) [&](auto const & a, auto const & b, auto & c) { c = 0; GEMM(a, b, c); }
  211. #define NOTZEROFIRST(GEMM) [&](auto const & a, auto const & b, auto & c) { GEMM(a, b, c); }
  212. // some variants are too slow to check with larger arrays.
  213. if (k>2) {
  214. bench(NOTZEROFIRST(gemm_k), "k");
  215. }
  216. if (k>0) {
  217. bench(ZEROFIRST((gemm_k_raw<real *, real const *>)), "k_raw");
  218. bench(ZEROFIRST((gemm_k_raw<real * __restrict__, real const * __restrict__>)), "k_raw_restrict");
  219. }
  220. if (k>0) {
  221. bench(ZEROFIRST((gemm_ij_raw<real *, real const *>)), "ij_raw");
  222. bench(ZEROFIRST((gemm_ij_raw<real * __restrict__, real const * __restrict__>)), "ij_raw_restrict");
  223. }
  224. bench(ZEROFIRST(gemm_block), "block");
  225. bench(ZEROFIRST(gemm1), "gemm1");
  226. bench(ZEROFIRST(gemm2), "gemm2");
  227. bench(ZEROFIRST(gemm3), "gemm3");
  228. bench(ZEROFIRST(gemm4), "gemm4");
  229. #if RA_USE_BLAS==1
  230. bench(ZEROFIRST(gemm_blas), "blas", 100*std::numeric_limits<real>::epsilon()); // ahem
  231. #endif
  232. bench(ZEROFIRST(gemm), "default");
  233. };
  234. bench_all(3, 4, 4, 4, 100);
  235. bench_all(3, 10, 10, 10, 10);
  236. bench_all(2, 100, 100, 100, 10);
  237. bench_all(2, 500, 400, 500, 1);
  238. bench_all(1, 10000, 10, 1000, 1);
  239. bench_all(1, 1000, 10, 10000, 1);
  240. bench_all(1, 100000, 10, 100, 1);
  241. bench_all(1, 100, 10, 100000, 1);
  242. return tr.summary();
  243. }