wrank.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra/test - Checks operations (verbs) with cell rank>0.
  3. // (c) Daniel Llorens - 2013-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #include <atomic>
  9. #include <numeric>
  10. #include <sstream>
  11. #include <iostream>
  12. #include <iterator>
  13. #include "ra/test.hh"
  14. #include "mpdebug.hh"
  15. using std::cout, std::endl, std::flush, std::tuple, ra::dim_t, ra::TestRecorder;
  16. using real = double;
  17. // Find the driver for given axis (This isn't used anymore in ra::; see ra::Match).
  18. template <int iarg, class T>
  19. constexpr int
  20. driver(T && t, int k)
  21. {
  22. if constexpr (iarg<ra::mp::len<std::decay_t<T>>) {
  23. if (k<std::get<iarg>(t).rank()) {
  24. dim_t s = std::get<iarg>(t).len(k);
  25. if (s>=0) {
  26. return iarg;
  27. }
  28. }
  29. return driver<iarg+1>(t, k);
  30. } else {
  31. std::abort(); // no driver
  32. }
  33. }
  34. // ewv = expression-with-verb
  35. template <class V, class A, class B>
  36. void nested_wrank_demo(V && v, A && a, B && b)
  37. {
  38. std::iota(a.begin(), a.end(), 10);
  39. std::iota(b.begin(), b.end(), 1);
  40. {
  41. using FM = ra::Framematch<V, tuple<decltype(a.iter()), decltype(b.iter())>>;
  42. cout << "width of fm: " << ra::mp::len<typename FM::R> << endl;
  43. cout << ra::mp::print_int_list<typename FM::R> {} << endl;
  44. auto af0 = ra::reframe<ra::mp::ref<typename FM::R, 0>>(a.iter());
  45. auto af1 = ra::reframe<ra::mp::ref<typename FM::R, 1>>(b.iter());
  46. cout << "af0: " << sizeof(af0) << endl;
  47. cout << "af1: " << sizeof(af1) << endl;
  48. {
  49. auto ewv = ra::expr(FM::op(v), af0, af1);
  50. cout << sizeof(ewv) << endl;
  51. cout << "ewv rank I: " << ewv.rank() << endl;
  52. for (int k=0; k<ewv.rank(); ++k) {
  53. cout << ewv.len(k) << ": " << driver<0>(ewv.t, k) << endl;
  54. }
  55. // cout << ra::mp::show<decltype(ra::ewv<FM>(FM::op(v), af0, af1))>::value << endl;
  56. cout << "\nusing (ewv &):\n";
  57. ra::ply_ravel(ewv);
  58. cout << endl;
  59. cout << "\nusing (ewv &&):\n";
  60. ra::ply_ravel(ra::expr(FM::op(v), af0, af1));
  61. }
  62. {
  63. // cout << ra::mp::show<decltype(ra::expr(v, a.iter(), b.iter()))>::value << endl;
  64. auto ewv = ra::expr(v, a.iter(), b.iter());
  65. cout << "shape(ewv): " << ra::noshape << shape(ewv) << endl;
  66. #define TEST(plier) \
  67. cout << "\n\nusing " STRINGIZE(plier) " (ewv &):\n"; \
  68. ra::plier(ewv); \
  69. cout << "\n\nusing " STRINGIZE(plier) " ply (ewv &&):\n"; \
  70. ra::plier(ra::expr(v, a.iter(), b.iter()));
  71. TEST(ply_ravel);
  72. TEST(ply_fixed);
  73. }
  74. cout << "\n\n" << endl;
  75. }
  76. }
  77. int main()
  78. {
  79. TestRecorder tr;
  80. auto plus2real = [](real a, real b) { return a + b; };
  81. tr.section("declaring verbs");
  82. {
  83. auto v = ra::wrank<0, 1>(plus2real);
  84. cout << ra::mp::ref<decltype(v)::cranks, 0>::value << endl;
  85. cout << ra::mp::ref<decltype(v)::cranks, 1>::value << endl;
  86. auto vv = ra::wrank<1, 1>(v);
  87. cout << ra::mp::ref<decltype(vv)::cranks, 0>::value << endl;
  88. cout << ra::mp::ref<decltype(vv)::cranks, 1>::value << endl;
  89. static_assert(ra::is_verb<decltype(v)>);
  90. static_assert(!ra::is_verb<decltype(plus2real)>);
  91. }
  92. tr.section("using Framematch");
  93. {
  94. ra::Unique<real, 2> a({3, 2}, ra::none);
  95. ra::Unique<real, 2> b({3, 2}, ra::none);
  96. std::iota(a.begin(), a.end(), 10);
  97. std::iota(b.begin(), b.end(), 1);
  98. auto plus2real_print = [](real a, real b) { cout << (a - b) << " "; };
  99. {
  100. auto v = ra::wrank<0, 2>(plus2real_print);
  101. using FM = ra::Framematch<decltype(v), tuple<decltype(a.iter()), decltype(b.iter())>>;
  102. cout << "width of fm: " << ra::mp::len<FM::R> << endl;
  103. cout << ra::mp::print_int_list<FM::R> {} << endl;
  104. auto af0 = ra::reframe<ra::mp::ref<FM::R, 0>>(a.iter());
  105. auto af1 = ra::reframe<ra::mp::ref<FM::R, 1>>(b.iter());
  106. cout << "af0: " << sizeof(af0) << endl;
  107. cout << "af1: " << sizeof(af1) << endl;
  108. auto ewv = expr(FM::op(v), af0, af1);
  109. cout << sizeof(ewv) << "\n" << endl;
  110. cout << "ewv rank II: " << ewv.rank() << endl;
  111. for (int k=0; k<ewv.rank(); ++k) {
  112. cout << ewv.len(k) << ": " << flush << driver<0>(ewv.t, k) << endl;
  113. }
  114. ra::ply_ravel(ewv);
  115. }
  116. }
  117. tr.section("wrank tests 0-1");
  118. {
  119. auto minus2real_print = [](real a, real b) { cout << (a - b) << " "; };
  120. nested_wrank_demo(ra::wrank<0, 1>(minus2real_print),
  121. ra::Unique<real, 1>({3}, ra::none),
  122. ra::Unique<real, 1>({4}, ra::none));
  123. nested_wrank_demo(ra::wrank<0, 1>(ra::wrank<0, 0>(minus2real_print)),
  124. ra::Unique<real, 1>({3}, ra::none),
  125. ra::Unique<real, 1>({3}, ra::none));
  126. }
  127. tr.section("wrank tests 1-0");
  128. {
  129. auto minus2real_print = [](real a, real b) { cout << (a - b) << " "; };
  130. nested_wrank_demo(ra::wrank<1, 0>(minus2real_print),
  131. ra::Unique<real, 1>({3}, ra::none),
  132. ra::Unique<real, 1>({4}, ra::none));
  133. nested_wrank_demo(ra::wrank<1, 0>(ra::wrank<0, 0>(minus2real_print)),
  134. ra::Unique<real, 1>({3}, ra::none),
  135. ra::Unique<real, 1>({4}, ra::none));
  136. }
  137. tr.section("wrank tests 0-0 (nop), case 1 - exact match");
  138. {
  139. // This uses the reframe specialization for 'do nothing' (TODO if there's one).
  140. auto minus2real_print = [](real a, real b) { cout << (a - b) << " "; };
  141. nested_wrank_demo(ra::wrank<0, 0>(minus2real_print),
  142. ra::Unique<real, 1>({3}, ra::none),
  143. ra::Unique<real, 1>({3}, ra::none));
  144. }
  145. tr.section("wrank tests 0-0 (nop), case 2 - non-exact frame match");
  146. {
  147. // This uses the reframe specialization for 'do nothing' (TODO if there's one).
  148. auto minus2real_print = [](real a, real b) { cout << (a - b) << " "; };
  149. nested_wrank_demo(ra::wrank<0, 0>(minus2real_print),
  150. ra::Unique<real, 2>({3, 4}, ra::none),
  151. ra::Unique<real, 1>({3}, ra::none));
  152. nested_wrank_demo(ra::wrank<0, 0>(minus2real_print),
  153. ra::Unique<real, 1>({3}, ra::none),
  154. ra::Unique<real, 2>({3, 4}, ra::none));
  155. }
  156. tr.section("wrank tests 1-1-0, init array with outer product");
  157. {
  158. auto minus2real = [](real & c, real a, real b) { c = a-b; };
  159. ra::Unique<real, 1> a({3}, ra::none);
  160. ra::Unique<real, 1> b({4}, ra::none);
  161. std::iota(a.begin(), a.end(), 10);
  162. std::iota(b.begin(), b.end(), 1);
  163. ra::Unique<real, 2> c({3, 4}, ra::none);
  164. ra::ply(ra::expr(ra::wrank<1, 0, 1>(minus2real), c.iter(), a.iter(), b.iter()));
  165. real checkc34[3*4] = { /* 10-[1 2 3 4] */ 9, 8, 7, 6,
  166. /* 11-[1 2 3 4] */ 10, 9, 8, 7,
  167. /* 12-[1 2 3 4] */ 11, 10, 9, 8 };
  168. tr.test(std::equal(checkc34, checkc34+3*4, c.begin()));
  169. ra::Unique<real, 2> d34(ra::expr(ra::wrank<0, 1>(std::minus<real>()), a.iter(), b.iter()));
  170. tr.test(std::equal(checkc34, checkc34+3*4, d34.begin()));
  171. real checkc43[3*4] = { /* [10 11 12]-1 */ 9, 10, 11,
  172. /* [10 11 12]-2 */ 8, 9, 10,
  173. /* [10 11 12]-3 */ 7, 8, 9,
  174. /* [10 11 12]-4 */ 6, 7, 8 };
  175. ra::Unique<real, 2> d43(ra::expr(ra::wrank<1, 0>(std::minus<real>()), a.iter(), b.iter()));
  176. tr.test(d43.len(0)==4 && d43.len(1)==3);
  177. tr.test(std::equal(checkc43, checkc43+3*4, d43.begin()));
  178. }
  179. tr.section("recipe for unbeatable subscripts in _from_ operator");
  180. {
  181. ra::Unique<int, 1> a({3}, ra::none);
  182. ra::Unique<int, 1> b({4}, ra::none);
  183. std::iota(a.begin(), a.end(), 10);
  184. std::iota(b.begin(), b.end(), 1);
  185. ra::Unique<real, 2> c({100, 100}, ra::none);
  186. std::iota(c.begin(), c.end(), 0);
  187. real checkd[3*4] = { 1001, 1002, 1003, 1004, 1101, 1102, 1103, 1104, 1201, 1202, 1203, 1204 };
  188. // default auto is value, so need to speficy.
  189. #define EXPR ra::expr(ra::wrank<0, 1>([&c](int a, int b) -> decltype(auto) { return c(a, b); } ), \
  190. a.iter(), b.iter())
  191. std::ostringstream os;
  192. os << EXPR << endl;
  193. ra::Unique<real, 2> cc {};
  194. std::istringstream is(os.str());
  195. is >> cc;
  196. tr.test(std::equal(checkd, checkd+3*4, cc.begin()));
  197. ra::Unique<real, 2> d(EXPR);
  198. tr.test(std::equal(checkd, checkd+3*4, d.begin()));
  199. // Using expr as lvalue.
  200. EXPR = 7.;
  201. tr.test_eq(c, where(ra::_0>=10 && ra::_0<=12 && ra::_1>=1 && ra::_1<=4, 7, ra::_0*100+ra::_1));
  202. // looping...
  203. bool valid = true;
  204. for (int i=0; i<c.len(0); ++i) {
  205. for (int j=0; j<c.len(1); ++j) {
  206. valid = valid && ((i>=10 && i<=12 && j>=1 && j<=4 ? 7 : i*100+j) == c(i, j));
  207. }
  208. }
  209. tr.test(valid);
  210. }
  211. tr.section("rank conjunction / empty");
  212. {
  213. }
  214. tr.section("static rank() in ra::expr with reframe()d args");
  215. {
  216. ra::Unique<real, 3> a({2, 2, 2}, 1.);
  217. ra::Unique<real, 3> b({2, 2, 2}, 2.);
  218. real y = 0;
  219. auto e = ra::expr(ra::wrank<0, 0>([&y](real const a, real const b) { y += a*b; }), a.iter(), b.iter());
  220. static_assert(3==e.rank(), "bad rank in static rank expr");
  221. ra::ply_ravel(ra::expr(ra::wrank<0, 0>([&y](real const a, real const b) { y += a*b; }), a.iter(), b.iter()));
  222. tr.test_eq(16, y);
  223. }
  224. tr.section("outer product variants");
  225. {
  226. ra::Big<real, 2> a({2, 3}, ra::_0 - ra::_1);
  227. ra::Big<real, 2> b({3, 2}, ra::_1 - 2*ra::_0);
  228. ra::Big<real, 2> c1 = gemm(a, b);
  229. // matrix product as outer product + reduction (no reductions yet, so manually).
  230. {
  231. ra::Big<real, 3> d = ra::expr(ra::wrank<1, 2>(ra::wrank<0, 1>(std::multiplies<>())), start(a), start(b));
  232. ra::Big<real, 2> c2({d.len(0), d.len(2)}, 0.);
  233. for (int k=0; k<d.len(1); ++k) {
  234. c2 += d(ra::all, k, ra::all);
  235. }
  236. tr.info("d(i,k,j) = a(i,k)*b(k,j)").test_eq(c1, c2);
  237. }
  238. // do the k-reduction by plying with wrank.
  239. {
  240. ra::Big<real, 2> c2({a.len(0), b.len(1)}, 0.);
  241. ra::ply(ra::expr(ra::wrank<1, 1, 2>(ra::wrank<1, 0, 1>([](auto & c, auto && a, auto && b) { c += a*b; })),
  242. start(c2), start(a), start(b)));
  243. tr.info("sum_k a(i,k)*b(k,j)").test_eq(c1, c2);
  244. }
  245. }
  246. tr.section("stencil test for Reframe::keep_step. Reduced from test/bench-stencil2.cc");
  247. {
  248. int nx = 4;
  249. int ny = 4;
  250. int ts = 4; // must be even bc of swap
  251. auto I = ra::iota(nx-2, 1);
  252. auto J = ra::iota(ny-2, 1);
  253. constexpr ra::Small<real, 3, 3> mask = { 0, 1, 0,
  254. 1, -4, 1,
  255. 0, 1, 0 };
  256. real value = 1;
  257. auto f_raw = [&](ra::ViewBig<real, 2> & A, ra::ViewBig<real, 2> & Anext, ra::ViewBig<real, 4> & Astencil)
  258. {
  259. for (int t=0; t<ts; ++t) {
  260. for (int i=1; i+1<nx; ++i) {
  261. for (int j=1; j+1<ny; ++j) {
  262. Anext(i, j) = -4*A(i, j)
  263. + A(i+1, j) + A(i, j+1)
  264. + A(i-1, j) + A(i, j-1);
  265. }
  266. }
  267. std::swap(A.cp, Anext.cp);
  268. }
  269. };
  270. auto f_sumprod = [&](ra::ViewBig<real, 2> & A, ra::ViewBig<real, 2> & Anext, ra::ViewBig<real, 4> & Astencil)
  271. {
  272. for (int t=0; t!=ts; ++t) {
  273. Astencil.cp = A.data();
  274. Anext(I, J) = 0; // TODO miss notation for sum-of-axes without preparing destination...
  275. Anext(I, J) += map(ra::wrank<2, 2>(std::multiplies<>()), Astencil, mask);
  276. std::swap(A.cp, Anext.cp);
  277. }
  278. };
  279. auto bench = [&](auto & A, auto & Anext, auto & Astencil, auto && ref, auto && tag, auto && f)
  280. {
  281. A = value;
  282. Anext = 0.;
  283. f(A, Anext, Astencil);
  284. tr.info(tag).test_rel(ref, A, 1e-11);
  285. };
  286. ra::Big<real, 2> Aref;
  287. ra::Big<real, 2> A({nx, ny}, 1.);
  288. ra::Big<real, 2> Anext({nx, ny}, 0.);
  289. auto Astencil = stencil(A, 1, 1);
  290. #define BENCH(ref, op) bench(A, Anext, Astencil, ref, STRINGIZE(op), op);
  291. BENCH(A, f_raw);
  292. Aref = ra::Big<real, 2>(A);
  293. BENCH(Aref, f_sumprod);
  294. }
  295. tr.section("iota with dead axes");
  296. {
  297. ra::Big<int, 2> a = from([](auto && i, auto && j) { return i-j; }, ra::iota(3), ra::iota(3));
  298. tr.test_eq(ra::Big<int, 2>({3, 3}, {0, -1, -2, 1, 0, -1, 2, 1, 0}), a);
  299. }
  300. tr.section("vector with dead axes");
  301. {
  302. std::vector i = {0, 1, 2};
  303. ra::Big<int, 2> a = ra::from([](auto && i, auto && j) { return i-j; }, i, i);
  304. tr.test_eq(ra::Big<int, 2>({3, 3}, {0, -1, -2, 1, 0, -1, 2, 1, 0}), a);
  305. }
  306. tr.section("no arguments -> zero rank");
  307. {
  308. int x = ra::from([] { return 3; });
  309. tr.test_eq(3, x);
  310. }
  311. tr.section("counting ops");
  312. {
  313. std::atomic<int> i { 0 };
  314. auto fi = [&i](auto && x) { ++i; return x; };
  315. std::atomic<int> j { 0 };
  316. auto fj = [&j](auto && x) { ++j; return x; };
  317. ra::Big<int, 2> a = from(std::minus<>(), map(fi, ra::iota(7)), map(fj, ra::iota(9)));
  318. tr.test_eq(ra::_0-ra::_1, a);
  319. tr.info("FIXME").skip().test_eq(7, int(i));
  320. tr.info("FIXME").skip().test_eq(9, int(j));
  321. }
  322. tr.section("explicit agreement check with wrank");
  323. {
  324. {
  325. ra::Big<real, 1> a({3}, 2+ra::_0);
  326. ra::Big<real, 2> b({2, 3}, 4-ra::_0);
  327. ra::Big<real, 2> c = {{8, 12, 16}, {6, 9, 12}};
  328. tr.test(agree_op(ra::wrank<1, 1>(std::multiplies<>()), a, b));
  329. tr.test_eq(c, map(ra::wrank<1, 1>(std::multiplies<>()), a, b));
  330. ra::Big<real, 2> d({3, 2}, 4-ra::_0);
  331. tr.test(!agree_op(ra::wrank<1, 1>(std::multiplies<>()), a, d));
  332. }
  333. }
  334. tr.section("Reframe::at");
  335. {
  336. ra::Small<int, 3> a = ra::iota(3)-10;
  337. ra::Small<int, 4> b = 10-ra::iota(4);
  338. ra::Small<int, 3, 4> c = from(std::multiplies<>(), a, b);
  339. tr.strictshape()
  340. .test_eq(c, from([o = from(std::multiplies<>(), a, b)](auto i, auto j) { return o.at(ra::Small<dim_t, 2> {i, j}); },
  341. ra::iota(3), ra::iota(4)));
  342. }
  343. return tr.summary();
  344. }