fromb.cc 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra/test - Checks for index selectors, esp. immediate. See fromu.cc.
  3. // (c) Daniel Llorens - 2014-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #include <iostream>
  9. #include <iterator>
  10. #include "ra/test.hh"
  11. using std::cout, std::endl, std::flush, std::tuple, ra::TestRecorder;
  12. using real = double;
  13. template <int rank=ra::ANY> using Ureal = ra::Unique<real, rank>;
  14. using Vint = ra::Unique<int, 1>;
  15. int main()
  16. {
  17. TestRecorder tr(std::cout);
  18. tr.section("beating Small with static iota");
  19. {
  20. ra::Small<int, 10> a = ra::_0;
  21. {
  22. auto b = a(ra::iota(ra::ic<4>));
  23. tr.test_eq(ra::Small<int, 4>(ra::_0), b);
  24. tr.test_eq(ra::scalar(a.data()), ra::scalar(b.data()));
  25. }
  26. {
  27. auto b = a(ra::iota(ra::len-ra::ic<5>, 5));
  28. tr.test_eq(ra::Small<int, 5>(ra::_0+5), b);
  29. // FIXME see "static len is preserved" in len.cc
  30. // tr.test_eq(ra::scalar(a.data()+5), ra::scalar(b.data()));
  31. }
  32. {
  33. auto b = a(ra::iota(4));
  34. tr.test_eq(ra::Small<int, 4>(ra::_0), b);
  35. // not beaten
  36. }
  37. }
  38. {
  39. ra::Small<int, 10, 10> a = ra::_1 + 10*ra::_0;
  40. {
  41. auto b = a(3, ra::all);
  42. tr.test_eq(ra::Small<int, 10>(30+ra::_0), b);
  43. tr.test_eq(ra::scalar(a.data()+30), ra::scalar(b.data()));
  44. }
  45. {
  46. auto b = a(ra::iota(ra::ic<4>));
  47. tr.test_eq(ra::Small<int, 4, 10>(ra::_1 + 10*ra::_0), b);
  48. tr.test_eq(ra::scalar(a.data()), ra::scalar(b.data()));
  49. }
  50. {
  51. auto b = a(ra::iota(ra::ic<4>, 4));
  52. tr.test_eq(ra::Small<int, 4, 10>(ra::_1 + 10*(ra::_0+4)), b);
  53. tr.test_eq(ra::scalar(a.data()+40), ra::scalar(b.data()));
  54. }
  55. {
  56. auto b = a(3, ra::iota(ra::ic<5>, 4));
  57. tr.test_eq(ra::Small<int, 5>(30+ra::_0+4), b);
  58. tr.test_eq(ra::scalar(a.data()+30+4), ra::scalar(b.data()));
  59. }
  60. {
  61. auto b = a(ra::all, ra::iota(ra::ic<4>, 2, ra::ic<2>));
  62. tr.test_eq(ra::Small<int, 10, 4>(10*ra::_0 + 2*(1+ra::_1)), b);
  63. tr.test_eq(ra::scalar(a.data()+2), ra::scalar(b.data()));
  64. }
  65. {
  66. auto b = a(ra::iota(ra::ic<3>, 1, ra::ic<2>),
  67. ra::iota(ra::ic<2>, 2, ra::ic<3>));
  68. tr.test_eq(ra::Small<int, 3, 2>(10*(1+2*ra::_0) + 2+ra::_1*3), b);
  69. tr.test_eq(ra::scalar(a.data()+12), ra::scalar(b.data()));
  70. }
  71. {
  72. auto b = a(ra::iota(ra::ic<3>, 9, ra::ic<-2>),
  73. ra::iota(ra::ic<2>, 2, ra::ic<3>));
  74. tr.test_eq(ra::Small<int, 3, 2>(10*(9-2*ra::_0) + 2+ra::_1*3), b);
  75. tr.test_eq(ra::scalar(a.data()+92), ra::scalar(b.data()));
  76. }
  77. // FIXME the unbeaten path caused by rt iota results in a nested rank expr [ra33]
  78. {
  79. cout << a(ra::iota(4)) << endl;
  80. // tr.test_eq(ra::Small<int, 4, 10>(ra::_1 + 10*ra::_0), a(ra::iota(4)));
  81. }
  82. // FIXME the unbeaten path caused by rt iota fails bc ra::all isn't an expr, just a 'special object' for subscripts. So we can't even print.
  83. {
  84. // cout << a(ra::all, ra::iota(4)) << endl;
  85. // tr.test_eq(ra::Small<int, 10, 4>(ra::_1 + 10*ra::_0), a(ra::all, ra::iota(4)));
  86. }
  87. // FIXME static iota(expr(ra::len) ...)
  88. }
  89. tr.section("zero length iota");
  90. {
  91. // 1-past is ok but 1-before is not, so these just leave the pointer unchanged.
  92. {
  93. ra::Small<int, 10> a = ra::_0;
  94. auto b = a(ra::iota(ra::ic<0>, 10));
  95. tr.test_eq(ra::Small<int, 0>(ra::_0+10), b);
  96. tr.test_eq(ra::scalar(a.data()), ra::scalar(b.data()));
  97. }
  98. {
  99. ra::Small<int, 10> a = ra::_0;
  100. auto b = a(ra::iota(ra::ic<0>, 10, ra::ic<-1>));
  101. tr.test_eq(ra::Small<int, 0>(ra::_0-1), b);
  102. cout << "a " << a.data() << " b " << b.data() << endl;
  103. tr.test_eq(ra::scalar(a.data()), ra::scalar(b.data()));
  104. }
  105. }
  106. tr.section("Iota<T> is beatable for any integral T");
  107. {
  108. Ureal<2> a({4, 4}, 0.);
  109. auto test = [&](auto org)
  110. {
  111. auto i = ra::iota(2, org);
  112. static_assert(std::is_same_v<decltype(i.i), decltype(org)>);
  113. auto b = a(i);
  114. tr.test_eq(2, b.dimv[0].len);
  115. tr.test_eq(4, b.dimv[1].len);
  116. tr.test_eq(4, b.dimv[0].step);
  117. tr.test_eq(1, b.dimv[1].step);
  118. };
  119. test(int(1));
  120. test(int16_t(1));
  121. test(ra::dim_t(1));
  122. }
  123. tr.section("trivial case");
  124. {
  125. ra::Big<int, 3> a({2, 3, 4}, ra::_0*100 + ra::_1*10 + ra::_2);
  126. tr.test_eq(ra::_0*100 + ra::_1*10 + ra::_2, from(a));
  127. }
  128. tr.section("scalar len (var size)");
  129. {
  130. ra::Big<int, 3> a({2, 3, 4}, ra::_0*100 + ra::_1*10 + (2 - ra::_2));
  131. tr.test_eq(a(1, 0, 0), a(ra::len-1, 0, 0));
  132. tr.test_eq(a(0, 2, 0), a(0, ra::len-1, 0));
  133. tr.test_eq(a(0, 0, 3), a(0, 0, ra::len-1));
  134. }
  135. tr.section("scalar len (static size)");
  136. {
  137. ra::Small<int, 4, 3, 2> a = ra::_0 - 10*ra::_1 + 100*ra::_2;
  138. tr.test_eq(a(3, 0, 0), a(ra::len-1, 0, 0));
  139. tr.test_eq(a(0, 2, 0), a(0, ra::len-1, 0));
  140. tr.test_eq(a(0, 0, 1), a(0, 0, ra::len-1));
  141. tr.test_eq(a(3, 2, 1), a(ra::len-1, ra::len-1, ra::len-1));
  142. }
  143. tr.section("iota len (var size)");
  144. {
  145. ra::Big<int, 3> a({2, 3, 4}, ra::_0*100 + ra::_1*10 + (2 - ra::_2));
  146. // expr len is beatable and gives views.
  147. tr.test_eq(1, ra::size(a(ra::iota(ra::len), 0, 0).dimv));
  148. tr.test_eq(a(ra::iota(2), 0, 0), a(ra::iota(ra::len), 0, 0));
  149. tr.test_eq(a(0, ra::iota(3), 0), a(0, ra::iota(ra::len), 0));
  150. tr.test_eq(a(0, 0, ra::iota(4)), a(0, 0, ra::iota(ra::len)));
  151. // expr org is beatable and gives views.
  152. tr.test_eq(1, ra::size(a(ra::iota(ra::len, ra::len*0), 0, 0).dimv));
  153. tr.test_eq(a(ra::iota(2), 0, 0), a(ra::iota(ra::len, ra::len*0), 0, 0));
  154. tr.test_eq(a(0, ra::iota(3), 0), a(0, ra::iota(ra::len, ra::len*0), 0));
  155. tr.test_eq(a(0, 0, ra::iota(4)), a(0, 0, ra::iota(ra::len, ra::len*0)));
  156. // expr step is beatable.
  157. tr.test_eq(1, ra::size(a(0, 0, ra::iota(2, 0, ra::len/2)).dimv));
  158. tr.test_eq(a(0, 0, ra::iota(2, 0, 2)), a(0, 0, ra::iota(2, 0, ra::len/2)));
  159. }
  160. tr.section("iota len (static size) TBD");
  161. {
  162. }
  163. tr.section("beatable multi-axis selectors, var size");
  164. {
  165. static_assert(ra::beatable<ra::dots_t<0>>.rt, "dots_t<0> is beatable");
  166. auto test = [&tr](auto && a)
  167. {
  168. tr.info("a(ra::dots<0>, ...)").test_eq(a(0), a(ra::dots<0>, 0));
  169. tr.info("a(ra::dots<0>, ...)").test_eq(a(1), a(ra::dots<0>, 1));
  170. tr.info("a(ra::dots<1>, 0, ...)").test_eq(a(ra::all, 0), a(ra::dots<1>, 0));
  171. tr.info("a(ra::dots<1>, 1, ...)").test_eq(a(ra::all, 1), a(ra::dots<1>, 1));
  172. tr.info("a(ra::dots<2>, 0)").test_eq(a(ra::all, ra::all, 0), a(ra::dots<2>, 0));
  173. tr.info("a(ra::dots<2>, 1)").test_eq(a(ra::all, ra::all, 1), a(ra::dots<2>, 1));
  174. tr.info("a(ra::dots<2>, len-1)").test_eq(a(ra::all, ra::all, 3), a(ra::dots<2>, ra::len-1));
  175. tr.info("a(ra::dots<>, 1)").test_eq(a(ra::all, ra::all, 1), a(ra::dots<>, 1));
  176. tr.info("a(0)").test_eq(a(0, ra::all, ra::all), a(0));
  177. tr.info("a(1)").test_eq(a(1, ra::all, ra::all), a(1));
  178. tr.info("a(0, ra::dots<2>)").test_eq(a(0, ra::all, ra::all), a(0, ra::dots<2>));
  179. tr.info("a(1, ra::dots<2>)").test_eq(a(1, ra::all, ra::all), a(1, ra::dots<2>));
  180. tr.info("a(len-1, ra::dots<2>)").test_eq(a(1, ra::all, ra::all), a(ra::len-1, ra::dots<2>));
  181. tr.info("a(1, ra::dots<>)").test_eq(a(1, ra::all, ra::all), a(1, ra::dots<>));
  182. tr.info("a(0, ra::dots<>, 1)").test_eq(a(0, ra::all, 1), a(0, ra::dots<>, 1));
  183. tr.info("a(1, ra::dots<>, 0)").test_eq(a(1, ra::all, 0), a(1, ra::dots<>, 0));
  184. // cout << a(ra::dots<>, 1, ra::dots<>) << endl; // ct error
  185. };
  186. tr.section("fixed size");
  187. test(ra::Small<int, 2, 3, 4>(ra::_0*100 + ra::_1*10 + ra::_2));
  188. tr.section("fixed rank");
  189. test(ra::Big<int, 3>({2, 3, 4}, ra::_0*100 + ra::_1*10 + ra::_2));
  190. tr.section("var rank");
  191. test(ra::Big<int>({2, 3, 4}, ra::_0*100 + ra::_1*10 + ra::_2));
  192. }
  193. tr.section("insert, var size");
  194. {
  195. static_assert(ra::beatable<ra::insert_t<1>>.rt, "insert_t<1> is beatable");
  196. ra::Big<int, 3> a({2, 3, 4}, ra::_0*100 + ra::_1*10 + ra::_2);
  197. tr.info("a(ra::insert<0> ...)").test_eq(a(0), a(ra::insert<0>, 0));
  198. ra::Big<int, 4> a1({1, 2, 3, 4}, ra::_1*100 + ra::_2*10 + ra::_3);
  199. tr.info("a(ra::insert<1> ...)").test_eq(a1, a(ra::insert<1>));
  200. ra::Big<int, 4> a2({2, 1, 3, 4}, ra::_0*100 + ra::_2*10 + ra::_3);
  201. tr.info("a(ra::all, ra::insert<1>, ...)").test_eq(a2, a(ra::all, ra::insert<1>));
  202. ra::Big<int, 5> a3({2, 1, 1, 3, 4}, ra::_0*100 + ra::_3*10 + ra::_4);
  203. tr.info("a(ra::all, ra::insert<2>, ...)").test_eq(a3, a(ra::all, ra::insert<2>));
  204. tr.info("a(0, ra::insert<1>, ...)").test_eq(a1(ra::all, 0), a(0, ra::insert<1>));
  205. tr.info("a(ra::insert<1>, 0, ...)").test_eq(a1(ra::all, 0), a(ra::insert<1>, 0));
  206. ra::Big<int, 4> aa1({2, 2, 3, 4}, a(ra::insert<1>));
  207. tr.info("insert with undefined len 0").test_eq(a, aa1(0));
  208. tr.info("insert with undefined len 1").test_eq(a, aa1(1));
  209. }
  210. tr.section("insert, var rank");
  211. {
  212. static_assert(ra::beatable<ra::insert_t<1>>.rt, "insert_t<1> is beatable");
  213. ra::Big<int> a({2, 3, 4}, ra::_0*100 + ra::_1*10 + ra::_2);
  214. tr.info("a(ra::insert<0> ...)").test_eq(a(0), a(ra::insert<0>, 0));
  215. ra::Big<int> a1({1, 2, 3, 4}, ra::_1*100 + ra::_2*10 + ra::_3);
  216. tr.info("a(ra::insert<1> ...)").test_eq(a1, a(ra::insert<1>));
  217. ra::Big<int> a2({2, 1, 3, 4}, ra::_0*100 + ra::_2*10 + ra::_3);
  218. tr.info("a(ra::all, ra::insert<1>, ...)").test_eq(a2, a(ra::all, ra::insert<1>));
  219. ra::Big<int> a3({2, 1, 1, 3, 4}, ra::_0*100 + ra::_3*10 + ra::_4);
  220. tr.info("a(ra::all, ra::insert<2>, ...)").test_eq(a3, a(ra::all, ra::insert<2>));
  221. tr.info("a(0, ra::insert<1>, ...)").test_eq(a1(ra::all, 0), a(0, ra::insert<1>));
  222. tr.info("a(ra::insert<1>, 0, ...)").test_eq(a1(ra::all, 0), a(ra::insert<1>, 0));
  223. }
  224. tr.section("mix insert + dots");
  225. {
  226. static_assert(ra::beatable<ra::insert_t<1>>.rt, "insert_t<1> is beatable");
  227. auto test = [&tr](auto && a, auto && b)
  228. {
  229. tr.info("a(ra::insert<0>, ra::dots<3>)").test_eq(a(ra::insert<0>, ra::dots<3>), a(ra::insert<0>, ra::dots<>));
  230. tr.info("a(ra::insert<0>, ra::dots<1>, ...)").test_eq(a(ra::insert<0>, ra::all, ra::all, ra::all), a(ra::insert<0>, ra::dots<>));
  231. tr.info("a(ra::insert<0>, ra::dots<>)").test_eq(a(ra::insert<0>), a(ra::insert<0>, ra::dots<>));
  232. // add to something else to establish the size of the inserted axis.
  233. tr.info("a(ra::insert<1>, ra::dots<3>)")
  234. .test_eq(a(ra::insert<1>, ra::dots<3>) + ra::iota(2), a(ra::insert<1>, ra::dots<>) + ra::iota(2));
  235. tr.info("a(ra::insert<1>, ra::dots<1>, ...)")
  236. .test_eq(a(ra::insert<1>, ra::all, ra::all, ra::all) + ra::iota(2), a(ra::insert<1>, ra::dots<>) + ra::iota(2));
  237. tr.info("a(ra::insert<1>, ra::dots<>)").test_eq(a(ra::insert<1>) + ra::iota(2),
  238. a(ra::insert<1>, ra::dots<>) + ra::iota(2));
  239. // same on the back.
  240. tr.info("a(ra::dots<3>, ra::insert<1>)")
  241. .test_eq(b + a(ra::dots<3>, ra::insert<1>), b + a(ra::dots<>, ra::insert<1>));
  242. tr.info("a(ra::dots<1>, ..., ra::insert<1>)")
  243. .test_eq(b + a(ra::all, ra::all, ra::all, ra::insert<1>), b + a(ra::dots<>, ra::insert<1>));
  244. };
  245. tr.section("fixed rank");
  246. test(ra::Big<int, 3>({2, 3, 4}, ra::_0*100 + ra::_1*10 + ra::_2),
  247. ra::Big<int, 4>({2, 3, 4, 2}, ra::_0*100 + ra::_1*10 + ra::_2 + (2-ra::_3)));
  248. tr.section("var rank");
  249. test(ra::Big<int>({2, 3, 4}, ra::_0*100 + ra::_1*10 + ra::_2),
  250. ra::Big<int>({2, 3, 4, 2}, ra::_0*100 + ra::_1*10 + ra::_2 + (2-ra::_3)));
  251. }
  252. tr.section("shortcuts");
  253. {
  254. auto check_selection_shortcuts = [&tr](auto && a)
  255. {
  256. tr.info("a()").test_eq(Ureal<2>({4, 4}, ra::_0-ra::_1), a());
  257. tr.info("a(2, :)").test_eq(Ureal<1>({4}, 2-ra::_0), a(2, ra::all));
  258. tr.info("a(2)").test_eq(Ureal<1>({4}, 2-ra::_0), a(2));
  259. tr.info("a(:, 3)").test_eq(Ureal<1>({4}, ra::_0-3), a(ra::all, 3));
  260. tr.info("a(:, :)").test_eq(Ureal<2>({4, 4}, ra::_0-ra::_1), a(ra::all, ra::all));
  261. tr.info("a(:)").test_eq(Ureal<2>({4, 4}, ra::_0-ra::_1), a(ra::all));
  262. tr.info("a(1)").test_eq(Ureal<1>({4}, 1-ra::_0), a(1));
  263. tr.info("a(2, 2)").test_eq(0, a(2, 2));
  264. tr.info("a(0:2:, 0:2:)").test_eq(Ureal<2>({2, 2}, 2*(ra::_0-ra::_1)),
  265. a(ra::iota(2, 0, 2), ra::iota(2, 0, 2)));
  266. tr.info("a(1:2:, 0:2:)").test_eq(Ureal<2>({2, 2}, 2*ra::_0+1-2*ra::_1),
  267. a(ra::iota(2, 1, 2), ra::iota(2, 0, 2)));
  268. tr.info("a(0:2:, :)").test_eq(Ureal<2>({2, 4}, 2*ra::_0-ra::_1),
  269. a(ra::iota(2, 0, 2), ra::all));
  270. tr.info("a(0:2:)").test_eq(a(ra::iota(2, 0, 2), ra::all), a(ra::iota(2, 0, 2)));
  271. };
  272. check_selection_shortcuts(Ureal<2>({4, 4}, ra::_0-ra::_1));
  273. check_selection_shortcuts(Ureal<>({4, 4}, ra::_0-ra::_1));
  274. }
  275. return tr.summary();
  276. }