expr.hh 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Expression templates with prefix matching.
  3. // (c) Daniel Llorens - 2011-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  9. #include "atom.hh"
  10. #include <functional>
  11. namespace ra {
  12. // --------------------
  13. // prefix match
  14. // --------------------
  15. constexpr rank_t
  16. choose_rank(rank_t ra, rank_t rb) { return BAD==rb ? ra : BAD==ra ? rb : ANY==ra ? ra : ANY==rb ? rb : std::max(ra, rb); }
  17. // pick first if mismatch (see below). FIXME maybe return invalid.
  18. constexpr dim_t
  19. choose_len(dim_t sa, dim_t sb) { return BAD==sa ? sb : BAD==sb ? sa : ANY==sa ? sb : sa; }
// Match holds a tuple of iterators and combines their ranks and lengths axis by axis,
// using choose_rank() and choose_len(). K defaults to the index list for the tuple T.
template <bool checkp, class T, class K=mp::iota<mp::len<T>>> struct Match;
// checkp selects whether the constructor verifies that the arguments agree in shape.
template <bool checkp, IteratorConcept ... P, int ... I>
struct Match<checkp, std::tuple<P ...>, mp::int_list<I ...>>
{
// the matched arguments
std::tuple<P ...> t;
// rank of largest subexpr
// (fold of choose_rank() over the static ranks of P, starting from BAD; ANY if the rank is only known at runtime)
constexpr static rank_t rs = [] { rank_t r=BAD; return ((r=choose_rank(r, ra::rank_s<P>())), ...); }();
// 0: fail, 1: rt, 2: pass
// Compile time agreement check. 1 means that the decision must be made at runtime by check().
consteval static int
check_s()
{
if constexpr (sizeof...(P)<2) {
// fewer than two arguments cannot disagree
return 2;
} else if constexpr (ANY==rs) {
return 1; // FIXME can be tightened to 2 if all args are rank 0 save one
} else {
// tbc: some axis needs to be confirmed at runtime
bool tbc = false;
for (int k=0; k<rs; ++k) {
dim_t ls = len_s(k);
// mismatch if any argument that has axis k would change the combined static length
if (((k<ra::rank_s<P>() && ls!=choose_len(std::decay_t<P>::len_s(k), ls)) || ...)) {
return 0;
} else {
// anyk: arguments whose static length on axis k is ANY; fixk: arguments with nonnegative static length
int anyk = ((k<ra::rank_s<P>() && (ANY==std::decay_t<P>::len_s(k))) + ...);
int fixk = ((k<ra::rank_s<P>() && (0<=std::decay_t<P>::len_s(k))) + ...);
// a runtime check is needed if an ANY length has to be compared against at least one other length
tbc = tbc || (anyk>0 && anyk+fixk>1);
}
}
return tbc ? 1 : 2;
}
}
// Runtime agreement check. Returns false on mismatch; when checkp is true, the mismatch
// is first reported through RA_CHECK.
constexpr bool
check() const
{
if constexpr (sizeof...(P)<2) {
return true;
} else if constexpr (constexpr int c = check_s(); 0==c) {
// already known to fail at compile time
return false;
} else if constexpr (1==c) {
for (int k=0; k<rank(); ++k) {
dim_t ls = len(k);
// same test as in check_s(), but with runtime ranks and lengths
if (((k<ra::rank(std::get<I>(t)) && ls!=choose_len(std::get<I>(t).len(k), ls)) || ...)) {
RA_CHECK(!checkp, "Shape mismatch on axis ", k, " [", (std::array { std::get<I>(t).len(k) ... }), "].");
return false;
}
}
}
// c is 2, or no axis disagreed
return true;
}
// Stores the arguments and, if checkp is set and no P satisfies has_len, checks agreement:
// first statically (check_s() must be nonzero), then through check().
constexpr
Match(P ... p_): t(p_ ...) // [ra1]
{
// TODO Maybe on ply, would make checkp unnecessary, make agree_xxx() unnecessary.
if constexpr (checkp && !(has_len<P> || ...)) {
static_assert(check_s(), "Shape mismatch.");
RA_CHECK(check());
}
}
// Rank of the match when it is known at compile time.
consteval static rank_t
rank() requires (ANY!=rs)
{
return rs;
}
// Rank of the match when the static rank is ANY: fold of choose_rank() over the runtime ranks of the arguments.
constexpr rank_t
rank() const requires (ANY==rs)
{
rank_t r = BAD;
((r = choose_rank(r, ra::rank(std::get<I>(t)))), ...);
assert(ANY!=r); // not at runtime
return r;
}
// first nonnegative size, if none first ANY, if none then BAD
constexpr static dim_t
len_s(int k)
{
// f combines s with the static length of A on axis k, if A has that axis or has negative static rank
auto f = [&k]<class A>(dim_t s) {
constexpr rank_t ar = ra::rank_s<A>();
return (ar<0 || k<ar) ? choose_len(s, A::len_s(k)) : s;
};
// s is no longer updated once it is nonnegative
dim_t s = BAD; ((s>=0 ? s : s = f.template operator()<std::decay_t<P>>(s)), ...);
return s;
}
// Length on axis k when every P::len(kk) is a valid expression: same as len_s(k).
constexpr static dim_t
len(int k) requires (requires (int kk) { P::len(kk); } && ...)
{
return len_s(k);
}
// Length on axis k otherwise: same scan as len_s(), using the runtime rank and length of each argument.
constexpr dim_t
len(int k) const requires (!(requires (int kk) { P::len(kk); } && ...))
{
auto f = [&k](dim_t s, auto const & a) {
return k<ra::rank(a) ? choose_len(s, a.len(k)) : s;
};
dim_t s = BAD; ((s>=0 ? s : s = f(s, std::get<I>(t))), ...);
assert(ANY!=s); // not at runtime
return s;
}
// Tuple with the result of step(i) for each argument.
constexpr auto
step(int i) const
{
return std::make_tuple(std::get<I>(t).step(i) ...);
}
// Calls adv(k, d) on every argument.
constexpr void
adv(rank_t k, dim_t d)
{
(std::get<I>(t).adv(k, d), ...);
}
// True if keep_step(st, z, j) holds for every argument. Non static version, used
// unless every P::keep_step(st, z, j) is a valid expression.
constexpr bool
keep_step(dim_t st, int z, int j) const
requires (!(requires (dim_t st, rank_t z, rank_t j) { P::keep_step(st, z, j); } && ...))
{
return (std::get<I>(t).keep_step(st, z, j) && ...);
}
// Static version of the above.
constexpr static bool
keep_step(dim_t st, int z, int j)
requires (requires (dim_t st, rank_t z, rank_t j) { P::keep_step(st, z, j); } && ...)
{
return (std::decay_t<P>::keep_step(st, z, j) && ...);
}
// save() returns a tuple with the result of save() for each argument; load() and mov()
// pass element I of their tuple argument to argument I.
constexpr auto save() const { return std::make_tuple(std::get<I>(t).save() ...); }
template <class PP> constexpr void load(PP const & pp) { ((std::get<I>(t).load(std::get<I>(pp))), ...); }
template <class S> constexpr void mov(S const & s) { ((std::get<I>(t).mov(std::get<I>(s))), ...); }
};
  142. // ---------------------------
  143. // reframe
  144. // ---------------------------
  145. // Transpose variant for IteratorConcepts. As in transpose(), one names the destination axis for
  146. // each original axis. However, axes may not be repeated. Used in the rank conjunction below.
// samestep<N, T> is a value of type T initialized from N. For T a std::tuple, each element
// is in turn samestep<N, element type>. Reframe::step() uses samestep<0, ...> as the step of dead axes.
template <dim_t N, class T> constexpr T samestep = N;
template <dim_t N, class ... T> constexpr std::tuple<T ...> samestep<N, std::tuple<T ...>> = { samestep<N, T> ... };
  149. // Dest is a list of destination axes [l0 l1 ... li ... l(rank(A)-1)].
// The dimensions of the reframed A are numbered as [0 ... k ... max(l)], so its rank is 1+max(l).
  151. // If li = k for some i, then axis k of the reframed A moves on axis i of the original iterator A.
  152. // If not, then axis k of the reframed A is 'dead' and doesn't move the iterator.
  153. // TODO invalid for ANY, since Dest is compile time. [ra7]
  154. template <class Dest, IteratorConcept A>
  155. struct Reframe
  156. {
  157. A a;
  158. constexpr static int orig(int k) { return mp::int_list_index<Dest>(k); }
  159. consteval static rank_t rank() { return 1+mp::fold<mp::max, ic_t<-1>, Dest>::value; }
  160. constexpr static dim_t len_s(int k)
  161. {
  162. int l=orig(k);
  163. return l>=0 ? std::decay_t<A>::len_s(l) : BAD;
  164. }
  165. constexpr dim_t
  166. len(int k) const
  167. {
  168. int l=orig(k);
  169. return l>=0 ? a.len(l) : BAD;
  170. }
  171. constexpr auto
  172. step(int k) const
  173. {
  174. int l=orig(k);
  175. return l>=0 ? a.step(l) : samestep<0, decltype(a.step(l))>;
  176. }
  177. constexpr void
  178. adv(rank_t k, dim_t d)
  179. {
  180. int l=orig(k);
  181. if (l>=0) { a.adv(l, d); }
  182. }
  183. constexpr bool
  184. keep_step(dim_t st, int z, int j) const
  185. {
  186. int wz=orig(z), wj=orig(j);
  187. return wz>=0 && wj>=0 && a.keep_step(st, wz, wj);
  188. }
  189. constexpr decltype(auto)
  190. at(auto const & i) const
  191. {
  192. return a.at(mp::map_indices<dim_t, Dest>(i));
  193. }
  194. constexpr decltype(auto) operator*() const { return *a; }
  195. constexpr auto save() const { return a.save(); }
  196. template <class P> constexpr void load(P const & p) { a.load(p); }
  197. // FIXME only if Dest preserves axis order, which is how wrank works, but this limitation should be explicit.
  198. template <class S> constexpr void mov(S const & s) { a.mov(s); }
  199. };
  200. // Optimize no-op case. TODO If A is CellBig, etc. beat Dest on it, same for eventual transpose_expr<>.
  201. template <class Dest, class A>
  202. constexpr decltype(auto)
  203. reframe(A && a)
  204. {
  205. if constexpr (std::is_same_v<Dest, mp::iota<1+mp::fold<mp::max, ic_t<-1>, Dest>::value>>) {
  206. return RA_FWD(a);
  207. } else {
  208. return Reframe<Dest, A> { RA_FWD(a) };
  209. }
  210. }
  211. // ---------------------------
  212. // verbs and rank conjunction
  213. // ---------------------------
// An operation Op tagged with a list of cell ranks (cranks). Produced by wrank() and
// consumed by Framematch_def, which reads cranks and unwraps op.
template <class cranks_, class Op_>
struct Verb
{
using cranks = cranks_;
using Op = Op_;
Op op;
};
// is_verb<A>: A is exactly a Verb built from its own cranks and Op member types.
RA_IS_DEF(is_verb, (std::is_same_v<A, Verb<typename A::cranks, typename A::Op>>))
  222. template <class cranks, class Op>
  223. constexpr auto
  224. wrank(cranks cranks_, Op && op) { return Verb<cranks, Op> { RA_FWD(op) }; }
  225. template <rank_t ... crank, class Op>
  226. constexpr auto
  227. wrank(Op && op) { return Verb<mp::int_list<crank ...>, Op> { RA_FWD(op) }; }
// Framematch_def<V, T, R, skip>: V is a verb or a raw op, T the tuple of argument types,
// R one axis list per argument (all mp::nil by default) and skip the first destination
// axis for the current frame. See the specializations.
template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
struct Framematch_def;
// Entry point: same as Framematch_def, with V decayed so that the specializations can match it.
template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
using Framematch = Framematch_def<std::decay_t<V>, T, R, skip>;
  232. template <class A, class B>
  233. struct max_i
  234. {
  235. constexpr static int value = (A::value == choose_rank(A::value, B::value)) ? 0 : 1;
  236. };
// Get a list (per argument) of lists of live axes. The last frame match is handled by standard prefix matching.
// This specialization handles one Verb level: it extends each argument's axis list Ri
// with the axes of the current frame and then recurses on the wrapped op W.
template <class ... crank, class W, class ... Ti, class ... Ri, rank_t skip>
struct Framematch_def<Verb<std::tuple<crank ...>, W>, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
{
static_assert(sizeof...(Ti)==sizeof...(crank) && sizeof...(Ti)==sizeof...(Ri), "Bad arguments.");
// live = number of live axes on this frame, for each argument. // TODO crank negative, inf.
// (static rank of the argument, minus the axes already listed in Ri, minus the cell rank)
using live = mp::int_list<(rank_s<Ti>() - mp::len<Ri> - crank::value) ...>;
// Ri extended with mp::iota<live count, skip>; presumably that many consecutive axes starting at skip - TODO confirm in mp::
using frameaxes = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri> - crank::value), skip>> ...>;
// recurse on the wrapped op; skip grows by the live count selected through max_i
using FM = Framematch<W, std::tuple<Ti ...>, frameaxes, skip + mp::ref<live, mp::indexof<max_i, live>>::value>;
// final axis lists, one per argument
using R = typename FM::R;
// unwrap v.op and keep unwrapping through the nested Framematch
template <class VV> constexpr static decltype(auto) op(VV && v) { return FM::op(RA_FWD(v).op); } // cf [ra31]
};
// Terminal case where V doesn't have rank (is a raw op()).
// Each argument gets all its remaining axes (static rank minus the axes already in Ri),
// appended to Ri as mp::iota<..., skip>; op() returns the op as given.
template <class V, class ... Ti, class ... Ri, rank_t skip>
struct Framematch_def<V, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
{
static_assert(sizeof...(Ti)==sizeof...(Ri), "Bad arguments.");
// TODO -crank::value when the actual verb rank is used (eg to use CellBig<... that_rank> instead of just begin()).
using R = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri>), skip>> ...>;
template <class VV> constexpr static decltype(auto) op(VV && v) { return RA_FWD(v); }
};
  258. // ---------------
  259. // explicit agreement checks
  260. // ---------------
  261. template <class ... P>
  262. constexpr bool
  263. agree(P && ... p) { return agree_(ra::start(RA_FWD(p)) ...); }
  264. // 0: fail, 1: rt, 2: pass
  265. template <class ... P>
  266. constexpr int
  267. agree_s(P && ... p) { return agree_s_(ra::start(RA_FWD(p)) ...); }
  268. template <class Op, class ... P> requires (is_verb<Op>)
  269. constexpr bool
  270. agree_op(Op && op, P && ... p) { return agree_verb(mp::iota<sizeof...(P)> {}, RA_FWD(op), RA_FWD(p) ...); }
  271. template <class Op, class ... P> requires (!is_verb<Op>)
  272. constexpr bool
  273. agree_op(Op && op, P && ... p) { return agree(RA_FWD(p) ...); }
  274. template <class ... P>
  275. constexpr bool
  276. agree_(P && ... p) { return (Match<false, std::tuple<P ...>> { RA_FWD(p) ... }).check(); }
  277. template <class ... P>
  278. constexpr int
  279. agree_s_(P && ... p) { return Match<false, std::tuple<P ...>>::check_s(); }
  280. template <class V, class ... T, int ... i>
  281. constexpr bool
  282. agree_verb(mp::int_list<i ...>, V && v, T && ... t)
  283. {
  284. using FM = Framematch<V, std::tuple<T ...>>;
  285. return agree_op(FM::op(RA_FWD(v)), reframe<mp::ref<typename FM::R, i>>(ra::start(RA_FWD(t))) ...);
  286. }
  287. // ---------------------------
  288. // operator expression
  289. // ---------------------------
  290. template <class Op, class T, class K=mp::iota<mp::len<T>>> struct Expr;
  291. template <class Op, IteratorConcept ... P, int ... I>
  292. struct Expr<Op, std::tuple<P ...>, mp::int_list<I ...>>: public Match<true, std::tuple<P ...>>
  293. {
  294. using Match_ = Match<true, std::tuple<P ...>>;
  295. using Match_::t, Match_::rs, Match_::rank;
  296. Op op;
  297. constexpr Expr(Op op_, P ... p_): Match_(p_ ...), op(op_) {} // [ra1]
  298. RA_DEF_ASSIGNOPS_SELF(Expr)
  299. RA_DEF_ASSIGNOPS_DEFAULT_SET
  300. constexpr decltype(auto) at(auto const & j) const { return std::invoke(op, std::get<I>(t).at(j) ...); }
  301. constexpr decltype(auto) operator*() const { return std::invoke(op, *std::get<I>(t) ...); }
  302. // needed for rs==ANY, which don't decay to scalar when used as operator arguments.
  303. constexpr
  304. operator decltype(std::invoke(op, *std::get<I>(t) ...)) () const
  305. {
  306. if constexpr (0!=rs && (1!=rs || 1!=size_s<Expr>())) { // for coord types; so ct only
  307. static_assert(rs==ANY);
  308. RA_CHECK(0==rank(), "Bad scalar conversion from shape [", ra::noshape, ra::shape(*this), "].");
  309. }
  310. return *(*this);
  311. }
  312. };
// An Expr is special if any of its arguments is special.
template <class Op, IteratorConcept ... P>
constexpr bool is_special_def<Expr<Op, std::tuple<P ...>>> = (is_special<P> || ...);
  315. template <class V, class ... T, int ... i>
  316. constexpr auto
  317. expr_verb(mp::int_list<i ...>, V && v, T && ... t)
  318. {
  319. using FM = Framematch<V, std::tuple<T ...>>;
  320. return expr(FM::op(RA_FWD(v)), reframe<mp::ref<typename FM::R, i>>(RA_FWD(t)) ...);
  321. }
  322. template <class Op, class ... P>
  323. constexpr auto
  324. expr(Op && op, P && ... p)
  325. {
  326. if constexpr (is_verb<Op>) {
  327. return expr_verb(mp::iota<sizeof...(P)> {}, RA_FWD(op), RA_FWD(p) ...);
  328. } else {
  329. return Expr<Op, std::tuple<P ...>> { RA_FWD(op), RA_FWD(p) ... };
  330. }
  331. }
  332. template <class Op, class ... A>
  333. constexpr auto
  334. map(Op && op, A && ... a)
  335. {
  336. return expr(RA_FWD(op), start(RA_FWD(a)) ...);
  337. }
  338. // ---------------------------
  339. // pick
  340. // ---------------------------
// Return type of pick_at(): the common reference type of P.at(J) over all P in the tuple.
template <class T, class J> struct pick_at_type;
template <class ... P, class J> struct pick_at_type<std::tuple<P ...>, J>
{
using type = std::common_reference_t<decltype(std::declval<P>().at(std::declval<J>())) ...>;
};
  346. template <std::size_t I, class T, class J>
  347. constexpr pick_at_type<mp::drop1<std::decay_t<T>>, J>::type
  348. pick_at(std::size_t p0, T && t, J const & j)
  349. {
  350. constexpr std::size_t N = mp::len<std::decay_t<T>> - 1;
  351. if constexpr (I < N) {
  352. return (p0==I) ? std::get<I+1>(t).at(j) : pick_at<I+1>(p0, t, j);
  353. } else {
  354. RA_CHECK(p0 < N, "Bad pick ", p0, " with ", N, " arguments."); std::abort();
  355. }
  356. }
// Return type of pick_star(): the common reference type of *P over all P in the tuple.
template <class T> struct pick_star_type;
template <class ... P> struct pick_star_type<std::tuple<P ...>>
{
using type = std::common_reference_t<decltype(*std::declval<P>()) ...>;
};
  362. template <std::size_t I, class T>
  363. constexpr pick_star_type<mp::drop1<std::decay_t<T>>>::type
  364. pick_star(std::size_t p0, T && t)
  365. {
  366. constexpr std::size_t N = mp::len<std::decay_t<T>> - 1;
  367. if constexpr (I < N) {
  368. return (p0==I) ? *(std::get<I+1>(t)) : pick_star<I+1>(p0, t);
  369. } else {
  370. RA_CHECK(p0 < N, "Bad pick ", p0, " with ", N, " arguments."); std::abort();
  371. }
  372. }
  373. template <class T, class K=mp::iota<mp::len<T>>> struct Pick;
  374. template <IteratorConcept ... P, int ... I>
  375. struct Pick<std::tuple<P ...>, mp::int_list<I ...>>: public Match<true, std::tuple<P ...>>
  376. {
  377. using Match_ = Match<true, std::tuple<P ...>>;
  378. using Match_::t, Match_::rs, Match_::rank;
  379. static_assert(sizeof...(P)>1);
  380. constexpr Pick(P ... p_): Match_(p_ ...) {} // [ra1]
  381. RA_DEF_ASSIGNOPS_SELF(Pick)
  382. RA_DEF_ASSIGNOPS_DEFAULT_SET
  383. constexpr decltype(auto) at(auto const & j) const { return pick_at<0>(std::get<0>(t).at(j), t, j); }
  384. constexpr decltype(auto) operator*() const { return pick_star<0>(*std::get<0>(t), t); }
  385. // needed for xpr with rs==ANY, which don't decay to scalar when used as operator arguments.
  386. constexpr
  387. operator decltype(pick_star<0>(*std::get<0>(t), t)) () const
  388. {
  389. if constexpr (0!=rs && (1!=rs || 1!=size_s<Pick>())) { // for coord types; so ct only
  390. static_assert(rs==ANY);
  391. RA_CHECK(0==rank(), "Bad scalar conversion from shape [", ra::noshape, ra::shape(*this), "].");
  392. }
  393. return *(*this);
  394. }
  395. };
// A Pick is special if any of its arguments is special.
template <IteratorConcept ... P>
constexpr bool is_special_def<Pick<std::tuple<P ...>>> = (is_special<P> || ...);
// Deduction guide: Pick { p ... } keeps the argument types as deduced from the forwarding references.
template <class ... P> Pick(P && ... p) -> Pick<std::tuple<P ...>>;
  399. template <class ... P> constexpr auto
  400. pick(P && ... p) { return Pick { start(RA_FWD(p)) ... }; }
  401. } // namespace ra