where.cc 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra/test - Tests for where() and pick().
  3. // (c) Daniel Llorens - 2014-2016
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #include <atomic>
  9. #include "ra/test.hh"
  10. using std::cout, std::endl, ra::TestRecorder;
  11. int main()
  12. {
  13. TestRecorder tr(std::cout);
  14. std::atomic<int> counter { 0 };
  15. auto count = [&counter](auto && x) -> decltype(auto) { ++counter; return x; };
  16. tr.section("pick");
  17. {
  18. ra::Small<double, 3> a0 = { 1, 2, 3 };
  19. ra::Small<double, 3> a1 = { 10, 20, 30 };
  20. ra::Small<int, 3> p = { 0, 1, 0 };
  21. ra::Small<double, 3> a(0.);
  22. counter = 0;
  23. a = pick(p, map(count, a0), map(count, a1));
  24. tr.test_eq(ra::Small<double, 3> { 1, 20, 3 }, a);
  25. tr.info("pick ETs execute only one branch per iteration").test_eq(3, int(counter));
  26. counter = 0;
  27. a = ra::where(p, map(count, a0), map(count, a1));
  28. tr.test_eq(ra::Small<double, 3> { 10, 2, 30 }, a);
  29. tr.info("where() is implemented using pick ET").test_eq(3, int(counter));
  30. }
  31. tr.section("write to pick");
  32. {
  33. ra::Small<double, 2> a0 = { 1, 2 };
  34. ra::Small<double, 2> a1 = { 10, 20 };
  35. ra::Small<int, 2> const p = { 0, 1 };
  36. ra::Small<double, 2> const a = { 7, 9 };
  37. counter = 0;
  38. pick(p, map(count, a0), map(count, a1)) = a;
  39. tr.test_eq(2, int(counter));
  40. tr.test_eq(ra::Small<double, 2> { 7, 2 }, a0);
  41. tr.test_eq(ra::Small<double, 2> { 10, 9 }, a1);
  42. tr.test_eq(ra::Small<double, 2> { 7, 9 }, a);
  43. tr.test_eq(ra::Small<int, 2> { 0, 1 }, p);
  44. }
  45. tr.section("pick works as any other array expression");
  46. {
  47. ra::Small<double, 2> a0 = { 1, 2 };
  48. ra::Small<double, 2> const a1 = { 10, 20 };
  49. ra::Small<int, 2> const p = { 0, 1 };
  50. ra::Small<double, 2> q = 3 + pick(p, a0, a1);
  51. tr.test_eq(ra::Small<int, 2> { 4, 23 }, q);
  52. }
  53. tr.section("pick with undefined len iota");
  54. {
  55. ra::Small<double, 2> a0 = { 1, 2 };
  56. ra::Small<double, 2> a1 = { 10, 20 };
  57. ra::Small<int, 2> const p = { 0, 1 };
  58. counter = 0;
  59. pick(p, map(count, a0), map(count, a1)) += ra::_0+5;
  60. tr.test_eq(2, int(counter));
  61. tr.test_eq(ra::Small<double, 2> { 6, 2 }, a0);
  62. tr.test_eq(ra::Small<double, 2> { 10, 26 }, a1);
  63. tr.test_eq(ra::Small<int, 2> { 0, 1 }, p);
  64. }
  65. tr.section("where, scalar W, array arguments in T/F");
  66. {
  67. std::array<double, 2> bb {1, 2};
  68. std::array<double, 2> cc {99, 99};
  69. auto b = ra::start(bb);
  70. auto c = ra::start(cc);
  71. cc[0] = cc[1] = 99;
  72. // pick_star
  73. c = ra::where(true, b, -b);
  74. tr.test_eq(1, cc[0]);
  75. tr.test_eq(2, cc[1]);
  76. // pick_at
  77. tr.test_eq(1, ra::where(true, b, -b).at(std::array {0}));
  78. // test against a bug where the op in where()'s Expr returned a dangling reference when both its args are rvalue refs. This was visible only at certain -O levels.
  79. cc[0] = cc[1] = 99;
  80. c = ra::where(true, b-3, -b);
  81. tr.test_eq(-2, cc[0]);
  82. tr.test_eq(-1, cc[1]);
  83. }
  84. tr.section("where as rvalue");
  85. {
  86. tr.test_eq(ra::Unique<int, 1> { 1, 2, 2, 1 }, ra::where(ra::Unique<bool, 1> { true, false, false, true }, 1, 2));
  87. tr.test_eq(ra::Unique<int, 1> { 17, 2, 3, 17 }
  88. , ra::where(ra::_0>0 && ra::_0<3, ra::Unique<int, 1> { 1, 2, 3, 4 }, 17));
  89. // [raop00] undef len iota returs value; so where()'s lambda must also return value.
  90. tr.test_eq(ra::Unique<int, 1> { 1, 2, 4, 7 }, ra::where(ra::Unique<bool, 1> { true, false, false, true }, 2*ra::_0+1, 2*ra::_0));
  91. // Using frame matching... TODO directly with ==expr?
  92. ra::Unique<int, 2> a({4, 3}, ra::_0-ra::_1);
  93. ra::Unique<int, 2> b = ra::where(ra::Unique<bool, 1> { true, false, false, true }, 99, a);
  94. tr.test_eq(ra::Unique<int, 2> ({4, 3}, { 99, 99, 99, 1, 0, -1, 2, 1, 0, 99, 99, 99 }), b);
  95. }
  96. tr.section("where nested");
  97. {
  98. {
  99. ra::Small<int, 3> a {-1, 0, 1};
  100. ra::Small<int, 3> b = ra::where(a>=0, ra::where(a<1, 77, 99), 44);
  101. tr.test_eq(ra::Small<int, 3> {44, 77, 99}, b);
  102. }
  103. {
  104. int a = 0;
  105. ra::Small<int, 2, 2> b = ra::where(a>=0, ra::where(a>=1, 99, 77), 44);
  106. tr.test_eq(ra::Small<int, 2, 2> {77, 77, 77, 77}, b);
  107. }
  108. }
  109. tr.section("where, scalar W, array arguments in T/F");
  110. {
  111. double a = 1./7;
  112. ra::Small<double, 2> b {1, 2};
  113. ra::Small<double, 2> c = ra::where(a>0, b, 3.);
  114. tr.test_eq(ra::Small<double, 2> {1, 2}, c);
  115. }
  116. tr.section("where as lvalue, scalar");
  117. {
  118. double a=0, b=0;
  119. bool w = true;
  120. ra::where(w, a, b) = 99;
  121. tr.test_eq(a, 99);
  122. tr.test_eq(b, 0);
  123. ra::where(!w, a, b) = 77;
  124. tr.test_eq(99, a);
  125. tr.test_eq(77, b);
  126. }
  127. tr.section("where, scalar + rank 0 array");
  128. {
  129. ra::Small<double> a { 33. };
  130. double b = 22.;
  131. tr.test_eq(33, double(ra::where(true, a, b)));
  132. tr.test_eq(22, double(ra::where(true, b, a)));
  133. }
  134. tr.section("where as lvalue, xpr [raop01]");
  135. {
  136. ra::Unique<int, 1> a { 0, 0, 0, 0 };
  137. ra::Unique<int, 1> b { 0, 0, 0, 0 };
  138. ra::where(ra::_0>0 && ra::_0<3, a, b) = 7;
  139. tr.test_eq(ra::Unique<int, 1> {0, 7, 7, 0}, a);
  140. tr.test_eq(ra::Unique<int, 1> {7, 0, 0, 7}, b);
  141. ra::where(ra::_0<=0 || ra::_0>=3, a, b) += 2;
  142. tr.test_eq(ra::Unique<int, 1> {2, 7, 7, 2}, a);
  143. tr.test_eq(ra::Unique<int, 1> {7, 2, 2, 7}, b);
  144. // Both must be lvalues; TODO check that either of these is an error.
  145. // ra::where(ra::_0>0 && ra::_0<3, ra::_0, a) = 99;
  146. // ra::where(ra::_0>0 && ra::_0<3, a, ra::_0) = 99;
  147. }
  148. tr.section("where with rvalue iota<n>(), fails to compile with g++ 5.2 -Os, gives wrong result with -O0");
  149. {
  150. tr.test_eq(ra::Small<int, 2> {0, 1}, ra::where(ra::Unique<bool, 1> { true, false }, ra::iota<0>(), ra::iota<0>()));
  151. tr.test_eq(ra::Unique<int, 1> { 0, 2 }, ra::where(ra::Unique<bool, 1> { true, false }, 3*ra::_0, 2*ra::_0));
  152. }
  153. tr.section("&& and || are short-circuiting");
  154. {
  155. using bool4 = ra::Small<bool, 4>;
  156. bool4 a {true, true, false, false}, b {true, false, true, false};
  157. int i = 0;
  158. tr.test_eq(bool4 {true, false, false, false}, a && map([&](auto && b) { ++i; return b; }, b));
  159. tr.info("short circuit test for &&").test_eq(2, i);
  160. i = 0;
  161. tr.test_eq(bool4 {true, true, true, false}, a || map([&](auto && b) { ++i; return b; }, b));
  162. tr.info("short circuit test for &&").test_eq(2, i);
  163. }
  164. // These tests should fail at compile time. No way to check them yet [ra42].
  165. // tr.section("size checks");
  166. // {
  167. // ra::Small<int, 3> a = { 1, 2, 3 };
  168. // ra::Small<int, 3> b = { 4, 5, 6 };
  169. // ra::Small<int, 2> c = 0; // ok if 2 -> 3; the test is for that case.
  170. // ra::where(a>b, a, c) += b;
  171. // tr.test_eq(ra::Small<int, 3> { 1, 2, 3 }, a);
  172. // tr.test_eq(ra::Small<int, 3> { 4, 5, 6 }, b);
  173. // }
  174. return tr.summary();
  175. }