bench-gemv.C 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. // (c) Daniel Llorens - 2017
  2. // This library is free software; you can redistribute it and/or modify it under
  3. // the terms of the GNU Lesser General Public License as published by the Free
  4. // Software Foundation; either version 3 of the License, or (at your option) any
  5. // later version.
  6. /// @file bench-gemv.C
  7. /// @brief Benchmark for BLAS-2 type ops
  8. // These operations aren't really part of the ET framework, just standalone
  9. // functions.
  10. // Cf bench-gemm.C for BLAS-3 type ops.
  11. #include <iostream>
  12. #include <iomanip>
  13. #include "ra/test.H"
  14. #include "ra/complex.H"
  15. #include "ra/format.H"
  16. #include "ra/big.H"
  17. #include "ra/operators.H"
  18. #include "ra/io.H"
  19. #include "ra/bench.H"
  20. using std::cout, std::endl, std::setw, std::setprecision;
  21. using ra::Small, ra::View, ra::Unique, ra::ra_traits;
  22. using real = double;
  23. // -------------------
  24. // variants of the defaults, should be slower if the default is well picked.
  25. // TODO compare with external GEMV/GEVM
  26. // -------------------
  // Selects whether a benchmark uses the matrix as stored (NOTRANS, reported
  // with the tag " [N]") or its transpose (TRANS, reported with " [T]").
  enum trans_t { NOTRANS = 0, TRANS = 1 };
  28. int main()
  29. {
  30. TestRecorder tr(std::cout);
  31. auto gemv_i = [&](auto const & a, auto const & b)
  32. {
  33. int const M = a.size(0);
  34. ra::Big<decltype(a(0, 0)*b(0)), 1> c({M}, ra::none);
  35. for (int i=0; i<M; ++i) {
  36. c(i) = dot(a(i), b);
  37. }
  38. return c;
  39. };
  40. auto gemv_j = [&](auto const & a, auto const & b)
  41. {
  42. int const M = a.size(0);
  43. int const N = a.size(1);
  44. ra::Big<decltype(a(0, 0)*b(0)), 1> c({M}, 0.);
  45. for (int j=0; j<N; ++j) {
  46. c += a(ra::all, j)*b(j);
  47. }
  48. return c;
  49. };
  50. auto gevm_j = [&](auto const & b, auto const & a)
  51. {
  52. int const N = a.size(1);
  53. ra::Big<decltype(b(0)*a(0, 0)), 1> c({N}, ra::none);
  54. for (int j=0; j<N; ++j) {
  55. c(j) = dot(b, a(ra::all, j));
  56. }
  57. return c;
  58. };
  59. auto gevm_i = [&](auto const & b, auto const & a)
  60. {
  61. int const M = a.size(0);
  62. int const N = a.size(1);
  63. ra::Big<decltype(b(0)*a(0, 0)), 1> c({N}, 0.);
  64. for (int i=0; i<M; ++i) {
  65. c += b(i)*a(i);
  66. }
  67. return c;
  68. };
  69. auto bench_all = [&](int k, int m, int n, int reps)
  70. {
  71. auto bench_mv = [&tr, &m, &n, &reps](auto && f, char const * tag, trans_t t)
  72. {
  73. ra::Big<real, 2> aa({m, n}, ra::_0-ra::_1);
  74. auto a = t==TRANS ? transpose<1, 0>(aa) : aa();
  75. ra::Big<real, 1> b({a.size(1)}, 1-2*ra::_0);
  76. ra::Big<real, 1> ref = gemv(a, b);
  77. ra::Big<real, 1> c;
  78. auto bv = Benchmark().repeats(reps).runs(3).run([&]() { c = f(a, b); });
  79. tr.info(std::setw(5), std::fixed, Benchmark::avg(bv)/(m*n)/1e-9, " ns [",
  80. Benchmark::stddev(bv)/(m*n)/1e-9 ,"] ", tag, t==TRANS ? " [T]" : " [N]").test_eq(ref, c);
  81. };
  82. auto bench_vm = [&tr, &m, &n, &reps](auto && f, char const * tag, trans_t t)
  83. {
  84. ra::Big<real, 2> aa({m, n}, ra::_0-ra::_1);
  85. auto a = t==TRANS ? transpose<1, 0>(aa) : aa();
  86. ra::Big<real, 1> b({a.size(0)}, 1-2*ra::_0);
  87. ra::Big<real, 1> ref = gevm(b, a);
  88. ra::Big<real, 1> c;
  89. auto bv = Benchmark().repeats(reps).runs(4).run([&]() { c = f(b, a); });
  90. tr.info(std::setw(5), std::fixed, Benchmark::avg(bv)/(m*n)/1e-9, " ns [",
  91. Benchmark::stddev(bv)/(m*n)/1e-9 ,"] ", tag, t==TRANS ? " [T]" : " [N]").test_eq(ref, c);
  92. };
  93. tr.section(m, " x ", n, " times ", reps);
  94. // some variants are way too slow to check with larger arrays.
  95. if (k>0) {
  96. bench_mv(gemv_i, "mv i", NOTRANS);
  97. bench_mv(gemv_i, "mv i", TRANS);
  98. bench_mv(gemv_j, "mv j", NOTRANS);
  99. bench_mv(gemv_j, "mv j", TRANS);
  100. bench_mv([&](auto const & a, auto const & b) { return gemv(a, b); }, "mv default", NOTRANS);
  101. bench_mv([&](auto const & a, auto const & b) { return gemv(a, b); }, "mv default", TRANS);
  102. bench_vm(gevm_i, "vm i", NOTRANS);
  103. bench_vm(gevm_i, "vm i", TRANS);
  104. bench_vm(gevm_j, "vm j", NOTRANS);
  105. bench_vm(gevm_j, "vm j", TRANS);
  106. bench_vm([&](auto const & a, auto const & b) { return gevm(a, b); }, "vm default", NOTRANS);
  107. bench_vm([&](auto const & a, auto const & b) { return gevm(a, b); }, "vm default", TRANS);
  108. }
  109. };
  110. bench_all(3, 10, 10, 10000);
  111. bench_all(3, 100, 100, 100);
  112. bench_all(3, 500, 500, 1);
  113. bench_all(3, 10000, 1000, 1);
  114. bench_all(3, 1000, 10000, 1);
  115. bench_all(3, 100000, 100, 1);
  116. bench_all(3, 100, 100000, 1);
  117. return tr.summary();
  118. }