reduction.cc 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra/test - Array reductions.
  3. // (c) Daniel Llorens - 2014-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #include <iostream>
  9. #include <iterator>
  10. #include "ra/test.hh"
  11. #include "mpdebug.hh"
  12. using std::cout, std::endl, std::flush, std::tuple, ra::TestRecorder;
  13. using real = double;
  14. using complex = std::complex<double>;
  15. using ra::sqrm;
  16. int main()
  17. {
  18. TestRecorder tr(std::cout);
  19. cout << "RA_DO_FMA is " << RA_DO_FMA << endl;
  20. tr.section("amax with different expr types");
  21. {
  22. auto test_amax_expr = [&tr](auto && a, auto && b)
  23. {
  24. a = ra::Small<real, 2, 2> {1, 2, 9, -10};
  25. tr.test_eq(amax(a), 9);
  26. b = ra::Small<real, 2, 2> {1, 1, 1, 1};
  27. tr.test_eq(amax(a+b), 10);
  28. };
  29. test_amax_expr(ra::Unique<real, 2>({2, 2}, 0.), ra::Unique<real, 2>({2, 2}, 0.));
  30. test_amax_expr(ra::Small<real, 2, 2>(), ra::Small<real, 2, 2>());
  31. // failed in gcc 5.1 when amax() took its args by plain auto (now auto &&).
  32. test_amax_expr(ra::Unique<real, 2>({2, 2}, 0.), ra::Small<real, 2, 2>());
  33. }
  34. tr.section("every / any");
  35. {
  36. tr.test(every(ra::Unique<real, 2>({4, 4}, 10+ra::_0-ra::_1)));
  37. tr.test(any(ra::Unique<real, 2>({4, 4}, ra::_0-ra::_1)));
  38. tr.test(ra::every(true));
  39. tr.test(!ra::every(false));
  40. tr.test(ra::any(true));
  41. tr.test(!ra::any(false));
  42. tr.test(every(ra::Unique<int, 1> {5, 5}==5));
  43. tr.test(!every(ra::Unique<int, 1> {2, 5}==5));
  44. tr.test(!every(ra::Unique<int, 1> {5, 2}==5));
  45. tr.test(!every(ra::Unique<int, 1> {2, 3}==5));
  46. tr.test(any(ra::Unique<int, 1> {5, 5}==5));
  47. tr.test(any(ra::Unique<int, 1> {2, 5}==5));
  48. tr.test(any(ra::Unique<int, 1> {5, 2}==5));
  49. tr.test(!any(ra::Unique<int, 1> {2, 3}==5));
  50. }
  51. tr.section("norm2");
  52. {
  53. ra::Small<real, 2> a {1, 2};
  54. tr.test_abs(std::sqrt(5.), norm2(a), 1e-15);
  55. ra::Small<float, 2> b {1, 2};
  56. tr.test_abs(std::sqrt(5.f), norm2(b), 4e-8);
  57. tr.info("type of norm2(floats)").test(std::is_same_v<float, decltype(norm2(b))>);
  58. tr.info("type of reduce_sqrm(floats)").test(std::is_same_v<float, decltype(reduce_sqrm(b))>);
  59. tr.info("type of sqrm(floats)").test(std::is_same_v<float, decltype(sqrm(b[0]))>);
  60. ra::Small<complex, 2> c {1, 2};
  61. tr.info("type of norm2(complex<double>)").test(std::is_same_v<double, decltype(norm2(c))>);
  62. tr.info("type of reduce_sqrm(complex<double>)").test(std::is_same_v<double, decltype(reduce_sqrm(c))>);
  63. tr.info("type of sqrm(complex<double>)").test(std::is_same_v<double, decltype(sqrm(c[0]))>);
  64. }
  65. tr.section("normv");
  66. {
  67. ra::Small<real, 2> a {1, 2};
  68. ra::Small<real, 2> b;
  69. b = normv(a);
  70. cout << "normv of lvalue: " << b << endl;
  71. tr.test_eq(b[0], 1./sqrt(5));
  72. tr.test_eq(b[1], 2./sqrt(5));
  73. b = normv(ra::Small<real, 2> {2, 1});
  74. cout << "normv of rvalue: "<< b << endl;
  75. tr.test_eq(b[0], 2./sqrt(5));
  76. tr.test_eq(b[1], 1./sqrt(5));
  77. }
  78. tr.section("reductions");
  79. {
  80. auto test_dot = [](auto && test) // TODO Use this for other real reductions.
  81. {
  82. test(ra::Small<complex, 2>{1, 2}, ra::Small<real, 2>{3, 4});
  83. test(ra::Small<real, 2>{1, 2}, ra::Small<complex, 2>{3, 4});
  84. test(ra::Small<real, 2>{1, 2}, ra::Small<real, 2>{3, 4});
  85. test(ra::Small<complex, 2>{1, 2}, ra::Small<complex, 2>{3, 4});
  86. test(ra::Big<complex, 1>{1, 2}, ra::Big<real, 1>{3, 4});
  87. test(ra::Big<real, 1>{1, 2}, ra::Big<complex, 1>{3, 4});
  88. test(ra::Big<real, 1>{1, 2}, ra::Big<real, 1>{3, 4});
  89. test(ra::Big<complex, 1>{1, 2}, ra::Big<complex, 1>{3, 4});
  90. test(ra::Small<complex, 2>{1, 2}, ra::Big<real, 1>{3, 4});
  91. test(ra::Small<real, 2>{1, 2}, ra::Big<complex, 1>{3, 4});
  92. test(ra::Small<real, 2>{1, 2}, ra::Big<real, 1>{3, 4});
  93. test(ra::Small<complex, 2>{1, 2}, ra::Big<complex, 1>{3, 4});
  94. test(ra::Big<complex, 1>{1, 2}, ra::Small<real, 2>{3, 4});
  95. test(ra::Big<real, 1>{1, 2}, ra::Small<complex, 2>{3, 4});
  96. test(ra::Big<real, 1>{1, 2}, ra::Small<real, 2>{3, 4});
  97. test(ra::Big<complex, 1>{1, 2}, ra::Small<complex, 2>{3, 4});
  98. };
  99. test_dot([&tr](auto && a, auto && b) { tr.test_eq(11., dot(a, b)); });
  100. test_dot([&tr](auto && a, auto && b) { tr.test_eq(11., cdot(a, b)); });
  101. test_dot([&tr](auto && a, auto && b) { tr.test_eq(sqrt(8.), norm2(a-b)); });
  102. test_dot([&tr](auto && a, auto && b) { tr.test_eq(8., reduce_sqrm(a-b)); });
  103. auto test_cdot = [](auto && test)
  104. {
  105. test(ra::Small<complex, 2>{1, complex(2, 3)}, ra::Small<complex, 2>{complex(4, 5), 6});
  106. test(ra::Big<complex, 1>{1, complex(2, 3)}, ra::Small<complex, 2>{complex(4, 5), 6});
  107. test(ra::Small<complex, 2>{1, complex(2, 3)}, ra::Big<complex, 1>{complex(4, 5), 6});
  108. test(ra::Big<complex, 1>{1, complex(2, 3)}, ra::Big<complex, 1>{complex(4, 5), 6});
  109. };
  110. complex value = conj(1.)*complex(4., 5.) + conj(complex(2., 3.))*6.;
  111. tr.test_eq(value, complex(16, -13));
  112. test_cdot([&tr](auto && a, auto && b) { tr.test_eq(complex(16., -13.), cdot(a, b)); });
  113. test_cdot([&tr](auto && a, auto && b) { tr.test_eq(sqrt(59.), norm2(a-b)); });
  114. test_cdot([&tr](auto && a, auto && b) { tr.test_eq(59., reduce_sqrm(a-b)); });
  115. auto test_sum = [](auto && test)
  116. {
  117. test(ra::Small<complex, 2>{complex(4, 5), 6});
  118. test(ra::Big<complex, 1>{complex(4, 5), 6});
  119. };
  120. test_sum([&tr](auto && a) { tr.test_eq(complex(10, 5), sum(a)); });
  121. test_sum([&tr](auto && a) { tr.test_eq(complex(24, 30), prod(a)); });
  122. test_sum([&tr](auto && a) { tr.test_eq(sqrt(41.), amax(abs(a))); });
  123. test_sum([&tr](auto && a) { tr.test_eq(6., amin(abs(a))); });
  124. }
  125. tr.section("amax/amin ignore NaN");
  126. {
  127. constexpr real QNAN = std::numeric_limits<real>::quiet_NaN();
  128. tr.test_eq(std::numeric_limits<real>::lowest(), std::max(std::numeric_limits<real>::lowest(), QNAN));
  129. tr.test_eq(-std::numeric_limits<real>::infinity(), amax(ra::Small<real, 3>(QNAN)));
  130. tr.test_eq(std::numeric_limits<real>::infinity(), amin(ra::Small<real, 3>(QNAN)));
  131. }
  132. // TODO these reductions require a destination argument; there are no exprs really.
  133. tr.section("to sum columns in crude ways");
  134. {
  135. ra::Unique<real, 2> A({100, 111}, ra::_0 - ra::_1);
  136. ra::Unique<real, 1> B({100}, 0.);
  137. for (int i=0, iend=A.len(0); i<iend; ++i) {
  138. B(i) = sum(A(i));
  139. }
  140. {
  141. ra::Unique<real, 1> C({100}, 0.);
  142. for_each([](auto & c, auto a) { c += a; }, C, A);
  143. tr.test_eq(B, C);
  144. }
  145. // This depends on matching frames for += just as for any other op, which is at odds with e.g. amend.
  146. {
  147. ra::Unique<real, 1> C({100}, 0.);
  148. C += A;
  149. tr.test_eq(B, C);
  150. }
  151. // Same as above.
  152. {
  153. ra::Unique<real, 1> C({100}, 0.);
  154. C = C + A;
  155. tr.test_eq(B, C);
  156. }
  157. // It cannot work with a lhs scalar value since += must be a class member, but it will work with a rank 0 array or with ra::Scalar.
  158. {
  159. ra::Unique<real, 0> C({}, 0.);
  160. C += A(0);
  161. tr.test_eq(B(0), C);
  162. real c(0.);
  163. ra::scalar(c) += A(0);
  164. tr.test_eq(B(0), c);
  165. }
  166. // This will fail because the assumed driver (ANY) has lower actual rank than the other argument. TODO check that it fails.
  167. // {
  168. // ra::Unique<real, 2> A({2, 3}, {1, 2, 3, 4 ,5, 6});
  169. // ra::Unique<real> C({}, 0.);
  170. // C += A(0);
  171. // }
  172. }
  173. tr.section("to sum rows in crude ways");
  174. {
  175. ra::Unique<real, 2> A({100, 111}, ra::_0 - ra::_1);
  176. ra::Unique<real, 1> B({111}, 0.);
  177. for (int j=0, jend=A.len(1); j<jend; ++j) {
  178. B(j) = sum(A(ra::all, j));
  179. }
  180. {
  181. ra::Unique<real, 1> C({111}, 0.);
  182. for_each([&C](auto && a) { C += a; }, A.iter<1>());
  183. tr.info("rhs iterator of rank > 0").test_eq(B, C);
  184. }
  185. {
  186. ra::Unique<real, 1> C({111}, 0.);
  187. for_each(ra::wrank<1, 1>([](auto & c, auto && a) { c += a; }), C, A);
  188. tr.info("rank conjuction").test_eq(B, C);
  189. }
  190. {
  191. ra::Unique<real, 1> C({111}, 0.);
  192. for_each(ra::wrank<1, 1>(ra::wrank<0, 0>([](auto & c, auto a) { c += a; })), C, A);
  193. tr.info("double rank conjunction").test_eq(B, C);
  194. }
  195. {
  196. ra::Unique<real, 1> C({111}, 0.);
  197. ra::scalar(C) += A.iter<1>();
  198. tr.info("scalar() and iterators of rank > 0").test_eq(B, C);
  199. }
  200. {
  201. ra::Unique<real, 1> C({111}, 0.);
  202. C.iter<1>() += A.iter<1>();
  203. tr.info("assign to iterators of rank > 0").test_eq(B, C);
  204. }
  205. }
  206. tr.section("reductions with amax");
  207. {
  208. ra::Big<int, 2> c({2, 3}, {1, 3, 2, 7, 1, 3});
  209. tr.info("max of rows").test_eq(ra::Big<int, 1> {3, 7}, map([](auto && a) { return amax(a); }, iter<1>(c)));
  210. ra::Big<int, 1> m({3}, 0);
  211. scalar(m) = max(scalar(m), iter<1>(c)); // requires inner forward in ra.hh: DEF_NAME_OP
  212. tr.info("max of columns I").test_eq(ra::Big<int, 1> {7, 3, 3}, m);
  213. m = 0;
  214. iter<1>(m) = max(iter<1>(m), iter<1>(c)); // FIXME
  215. tr.info("max of columns III [ma113]").test_eq(ra::Big<int, 1> {7, 3, 3}, m);
  216. m = 0;
  217. for_each([&m](auto && a) { m = max(m, a); }, iter<1>(c));
  218. tr.info("max of columns II").test_eq(ra::Big<int, 1> {7, 3, 3}, m);
  219. ra::Big<double, 1> q({0}, {});
  220. tr.info("amax default").test_eq(std::numeric_limits<double>::infinity(), amin(q));
  221. tr.info("amin default").test_eq(-std::numeric_limits<double>::infinity(), amax(q));
  222. }
  223. tr.section("vector-matrix reductions");
  224. {
  225. auto test = [&tr](auto t, auto s, auto r)
  226. {
  227. using T = decltype(t);
  228. using S = decltype(s);
  229. using R = decltype(r);
  230. S x[4] = {1, 2, 3, 4};
  231. ra::Small<T, 3, 4> a = ra::_0 - ra::_1;
  232. R y[3] = {99, 99, 99};
  233. ra::start(y) = ra::gemv(a, x);
  234. auto z = ra::gemv(a, x);
  235. tr.test_eq(ra::Small<R, 3> {-20, -10, 0}, y);
  236. tr.test_eq(ra::Small<R, 3> {-20, -10, 0}, z);
  237. };
  238. test(double(0), double(0), double(0));
  239. test(std::complex<double>(0), std::complex<double>(0), std::complex<double>(0));
  240. test(int(0), int(0), int(0));
  241. test(int(0), double(0), double(0));
  242. test(double(0), int(0), double(0));
  243. }
  244. {
  245. auto test = [&tr](auto t, auto s, auto r)
  246. {
  247. using T = decltype(t);
  248. using S = decltype(s);
  249. using R = decltype(r);
  250. S x[4] = {1, 2, 3, 4};
  251. ra::Small<T, 4, 3> a = ra::_1 - ra::_0;
  252. R y[3] = {99, 99, 99};
  253. ra::start(y) = ra::gevm(x, a);
  254. auto z = ra::gevm(x, a);
  255. tr.test_eq(ra::Small<R, 3> {-20, -10, 0}, y);
  256. tr.test_eq(ra::Small<R, 3> {-20, -10, 0}, z);
  257. };
  258. test(double(0), double(0), double(0));
  259. test(std::complex<double>(0), std::complex<double>(0), std::complex<double>(0));
  260. test(int(0), int(0), int(0));
  261. test(int(0), double(0), double(0));
  262. test(double(0), int(0), double(0));
  263. }
  264. tr.section("gemm with dynamic shape, corner case");
  265. {
  266. ra::Big<double, 2> A({0, 0}, 2.);
  267. ra::Big<double, 2> B({0, 0}, 3.);
  268. auto C = gemm(A, B);
  269. tr.test_eq(0, C.len(0));
  270. tr.test_eq(0, C.len(1));
  271. }
  272. tr.section("gemm with dynamic shape");
  273. {
  274. ra::Big<complex, 2> A({3, 2}, 2.);
  275. ra::Big<complex, 2> B({2, 4}, 3.);
  276. auto C = gemm(A, B);
  277. tr.test_eq(3, C.len(0));
  278. tr.test_eq(4, C.len(1));
  279. tr.test_eq(12., C);
  280. }
  281. tr.section("gemm with static shape");
  282. {
  283. ra::Small<double, 3, 2> A = 2;
  284. ra::Small<double, 2, 4> B = 3;
  285. auto C = gemm(A, B);
  286. tr.test_eq(3, C.len_s(0));
  287. tr.test_eq(4, C.len_s(1));
  288. tr.test_eq(12, C);
  289. }
  290. tr.section("gemv with static shape");
  291. {
  292. ra::Small<double, 3, 2> A = 2;
  293. ra::Small<double, 2> B = 3;
  294. auto C = gemv(A, B);
  295. tr.test_eq(3, C.len_s(0));
  296. tr.test_eq(12, C);
  297. }
  298. tr.section("gevm with static shape");
  299. {
  300. ra::Small<double, 2> A = 3;
  301. ra::Small<double, 2, 3> B = 2;
  302. auto C = gevm(A, B);
  303. tr.test_eq(3, C.len_s(0));
  304. tr.test_eq(12, C);
  305. }
  306. tr.section("reference reductions");
  307. {
  308. ra::Big<double, 2> A({2, 3}, ra::_1 - ra::_0);
  309. double & mn = refmin(A);
  310. tr.test_eq(-1, mn);
  311. mn = -99;
  312. ra::Big<double, 2> B({2, 3}, ra::_1 - ra::_0);
  313. B(1, 0) = -99;
  314. tr.test_eq(B, A);
  315. double & mx = refmin(A, std::greater<double>());
  316. tr.test_eq(2, mx);
  317. mx = 0;
  318. B(0, 2) = 0;
  319. tr.test_eq(B, A);
  320. double & my = refmax(A);
  321. tr.test_eq(1, my);
  322. my = 77;
  323. B(0, 1) = 77;
  324. tr.test_eq(B, A);
  325. // cout << refmin(A+B) << endl; // compile error
  326. }
  327. return tr.summary();
  328. }