expr.hh 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Expression templates with prefix matching.
  3. // (c) Daniel Llorens - 2011-2024
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  #include <cassert>
  #include <functional>
  #include <iterator>
  #include <type_traits>
  #include "base.hh"
  // --------------------
  // Error handling. See examples/throw.cc for how to customize.
  // RA_DO_CHECK selects what RA_CHECK(cond, ...) does (default 1):
  //   0: nothing at all, not even with RA_ASSERT defined.
  //   otherwise, if the user defined RA_ASSERT: RA_ASSERT(cond, ...).
  //   1: assert(cond); the extra arguments are ignored.
  //   2: assert(cond) during constant evaluation; at runtime, print cond and the extra arguments
  //      to std::cerr and std::abort() if cond is false.
  // Any other value is an error.
  // --------------------
  #ifndef RA_DO_CHECK
  #define RA_DO_CHECK 1 // tell users
  #endif
  #if RA_DO_CHECK==0
  #define RA_CHECK(...) // good luck
  #elif defined(RA_ASSERT)
  #define RA_CHECK(...) RA_ASSERT(__VA_ARGS__)
  #elif RA_DO_CHECK==1
  #define RA_CHECK(cond, ...) { assert(cond); }
  #elif RA_DO_CHECK==2
  #include <iostream>
  #define RA_CHECK(cond, ...)                                             \
  {                                                                       \
      if (std::is_constant_evaluated()) {                                 \
          assert(cond /* FIXME show args */);                             \
      } else if (!(cond)) [[unlikely]] {                                  \
          std::cerr << ra::format("*** ra::", std::source_location::current(), " (" STRINGIZE(cond) ") " __VA_OPT__(,) __VA_ARGS__, " ***") << std::endl; \
          std::abort();                                                   \
      }                                                                   \
  }
  #else
  #error Bad value for RA_DO_CHECK
  #endif
  #define RA_AFTER_CHECK Yes
  43. namespace ra {
  44. constexpr bool inside(dim_t i, dim_t b) { return 0<=i && i<b; }
  // --------------------
  // assign ops for settable iterators. Might be different for e.g. Views.
  // --------------------
  // Elementwise y OP x over (*this, x). y is forwarded (RA_FWD) to forbid misusing a value y as a
  // reference [ra5].
  #define RA_ASSIGNOPS_LINE(OP) \
  for_each([](auto && y, auto && x) { RA_FWD(y) OP x; }, *this, RA_FWD(x))
  // One elementwise assignment operator OP taking an argument of any type.
  #define RA_ASSIGNOPS(OP) \
  constexpr void operator OP(auto && x) { RA_ASSIGNOPS_LINE(OP); }
  // The default set of elementwise assignment operators. But see local ASSIGNOPS elsewhere.
  #define RA_ASSIGNOPS_DEFAULT_SET \
  FOR_EACH(RA_ASSIGNOPS, =, *=, +=, -=, /=)
  // Restate for expression classes since a template doesn't replace the copy assignment op.
  // Assignment from another TYPE is elementwise, like the operators from RA_ASSIGNOPS. It is
  // constexpr like those, so that it stays usable in constant evaluation. Copy/move construction
  // remain memberwise (defaulted).
  #define RA_ASSIGNOPS_SELF(TYPE) \
  constexpr TYPE & operator=(TYPE && x) { RA_ASSIGNOPS_LINE(=); return *this; } \
  constexpr TYPE & operator=(TYPE const & x) { RA_ASSIGNOPS_LINE(=); return *this; } \
  constexpr TYPE(TYPE && x) = default; \
  constexpr TYPE(TYPE const & x) = default;
  62. // --------------------
  63. // terminal types
  64. // --------------------
  65. // Rank-0 IteratorConcept. Can be used on foreign objects, or as alternative to the rank conjunction.
  66. // We still want f(scalar(C)) to be f(C) and not map(f, C), this is controlled by tomap/toreduce.
  67. template <class C>
  68. struct Scalar
  69. {
  70. C c;
  71. RA_ASSIGNOPS_DEFAULT_SET
  72. consteval static rank_t rank() { return 0; }
  73. constexpr static dim_t len_s(int k) { std::abort(); }
  74. constexpr static dim_t len(int k) { std::abort(); }
  75. constexpr static dim_t step(int k) { return 0; }
  76. constexpr static void adv(rank_t k, dim_t d) {}
  77. constexpr static bool keep_step(dim_t st, int z, int j) { return true; }
  78. constexpr decltype(auto) at(auto && j) const { return c; }
  79. constexpr C & operator*() requires (std::is_lvalue_reference_v<C>) { return c; } // [ra37]
  80. constexpr C const & operator*() requires (!std::is_lvalue_reference_v<C>) { return c; }
  81. constexpr C const & operator*() const { return c; } // [ra39]
  82. constexpr static int save() { return 0; }
  83. constexpr static void load(int) {}
  84. constexpr static void mov(dim_t d) {}
  85. };
  86. template <class C> constexpr auto
  87. scalar(C && c) { return Scalar<C> { RA_FWD(c) }; }
  88. template <class N> constexpr int
  89. maybe_any = []{
  90. if constexpr (is_constant<N>) {
  91. return N::value;
  92. } else {
  93. static_assert(std::is_integral_v<N> || !std::is_same_v<N, bool>);
  94. return ANY;
  95. }
  96. }();
  97. // IteratorConcept for foreign rank 1 objects.
  98. template <std::bidirectional_iterator I, class N, class S>
  99. struct Ptr
  100. {
  101. static_assert(is_constant<N> || 0==rank_s<N>());
  102. static_assert(is_constant<S> || 0==rank_s<S>());
  103. constexpr static dim_t nn = maybe_any<N>;
  104. static_assert(nn==ANY || nn>=0 || nn==BAD);
  105. constexpr static bool constant = is_constant<N> && is_constant<S>;
  106. I i;
  107. [[no_unique_address]] N const n = {};
  108. [[no_unique_address]] S const s = {};
  109. constexpr Ptr(I i, N n, S s): i(i), n(n), s(s) {}
  110. RA_ASSIGNOPS_SELF(Ptr)
  111. RA_ASSIGNOPS_DEFAULT_SET
  112. consteval static rank_t rank() { return 1; }
  113. constexpr static dim_t len_s(int k) { return nn; } // len(k==0) or step(k>=0)
  114. constexpr static dim_t len(int k) requires (is_constant<N>) { return len_s(k); }
  115. constexpr dim_t len(int k) const requires (!is_constant<N>) { return n; }
  116. constexpr static dim_t step(int k) { return k==0 ? 1 : 0; }
  117. constexpr void adv(rank_t k, dim_t d) { i += step(k) * d * s; }
  118. constexpr static bool keep_step(dim_t st, int z, int j) { return st*step(z)==step(j); }
  119. constexpr decltype(auto) at(auto && j) const requires (std::random_access_iterator<I>)
  120. {
  121. RA_CHECK(BAD==nn || inside(j[0], n), "Bad index ", j[0], " for len[0]=", n, ".");
  122. return i[j[0]*s];
  123. }
  124. constexpr decltype(auto) operator*() const { return *i; }
  125. constexpr auto save() const { return i; }
  126. constexpr void load(I ii) { i = ii; }
  127. constexpr void mov(dim_t d) { i += d*s; }
  128. };
  129. template <class X> using seq_arg = std::conditional_t<is_constant<std::decay_t<X>> || is_scalar<std::decay_t<X>>, std::decay_t<X>, X>;
  130. template <class S>
  131. constexpr auto
  132. thestep()
  133. {
  134. if constexpr (std::is_integral_v<S>) {
  135. return S(1);
  136. } else if constexpr (is_constant<S>) {
  137. static_assert(1==S::value);
  138. return S {};
  139. } else {
  140. static_assert(always_false<S>, "Bad step type for sequence.");
  141. }
  142. }
  143. template <class I, class N=dim_c<BAD>, class S=dim_c<1>>
  144. constexpr auto
  145. ptr(I && i, N && n = N {}, S && s = thestep<S>())
  146. {
  147. if constexpr (std::ranges::bidirectional_range<std::remove_reference_t<I>>) {
  148. static_assert(std::is_same_v<dim_c<BAD>, N>, "Object has own length.");
  149. static_assert(std::is_same_v<dim_c<1>, S>, "No step with deduced size.");
  150. if constexpr (ANY==size_s<I>()) {
  151. return ptr(std::begin(RA_FWD(i)), std::ssize(i), RA_FWD(s));
  152. } else {
  153. return ptr(std::begin(RA_FWD(i)), ic<size_s<I>()>, RA_FWD(s));
  154. }
  155. } else if constexpr (std::bidirectional_iterator<std::decay_t<I>>) {
  156. if constexpr (std::is_integral_v<N>) {
  157. RA_CHECK(n>=0, "Bad ptr length ", n, ".");
  158. }
  159. return Ptr<std::decay_t<I>, seq_arg<N>, seq_arg<S>> { i, RA_FWD(n), RA_FWD(s) };
  160. } else {
  161. static_assert(always_false<I>, "Bad type for ptr().");
  162. }
  163. }
  164. // Sequence and IteratorConcept for same. Iota isn't really a terminal, but its exprs must all have rank 0.
  165. // FIXME w is a custom Reframe mechanism inherited from TensorIndex. Generalize/unify
  166. // FIXME Sequence should be its own type, we can't represent a ct origin bc IteratorConcept interface takes up i.
  167. template <int w, class I, class N, class S>
  168. struct Iota
  169. {
  170. static_assert(w>=0);
  171. static_assert(is_constant<S> || 0==rank_s<S>());
  172. static_assert(is_constant<N> || 0==rank_s<N>());
  173. constexpr static dim_t nn = maybe_any<N>;
  174. static_assert(nn==ANY || nn>=0 || nn==BAD);
  175. constexpr static bool constant = is_constant<N> && is_constant<S>;
  176. I i = {};
  177. [[no_unique_address]] N const n = {};
  178. [[no_unique_address]] S const s = {};
  179. constexpr static S gets() requires (is_constant<S>) { return S {}; }
  180. constexpr I gets() const requires (!is_constant<S>) { return s; }
  181. consteval static rank_t rank() { return w+1; }
  182. constexpr static dim_t len_s(int k) { return k==w ? nn : BAD; } // len(0<=k<=w) or step(0<=k)
  183. constexpr static dim_t len(int k) requires (is_constant<N>) { return len_s(k); }
  184. constexpr dim_t len(int k) const requires (!is_constant<N>) { return k==w ? n : BAD; }
  185. constexpr static dim_t step(rank_t k) { return k==w ? 1 : 0; }
  186. constexpr void adv(rank_t k, dim_t d) { i += I(step(k) * d) * I(s); }
  187. constexpr static bool keep_step(dim_t st, int z, int j) { return st*step(z)==step(j); }
  188. constexpr auto at(auto && j) const
  189. {
  190. RA_CHECK(BAD==nn || inside(j[0], n), "Bad index ", j[0], " for len[0]=", n, ".");
  191. return i + I(j[w])*I(s);
  192. }
  193. constexpr I operator*() const { return i; }
  194. constexpr I save() const { return i; }
  195. constexpr void load(I ii) { i = ii; }
  196. constexpr void mov(dim_t d) { i += I(d)*I(s); }
  197. };
  198. template <int w=0, class I=dim_t, class N=dim_c<BAD>, class S=dim_c<1>>
  199. constexpr auto
  200. iota(N && n = N {}, I && i = 0, S && s = thestep<S>())
  201. {
  202. if constexpr (std::is_integral_v<N>) {
  203. RA_CHECK(n>=0, "Bad iota length ", n, ".");
  204. }
  205. return Iota<w, seq_arg<I>, seq_arg<N>, seq_arg<S>> { RA_FWD(i), RA_FWD(n), RA_FWD(s) };
  206. }
  207. #define DEF_TENSORINDEX(w) constexpr auto JOIN(_, w) = iota<w>();
  208. FOR_EACH(DEF_TENSORINDEX, 0, 1, 2, 3, 4);
  209. #undef DEF_TENSORINDEX
  210. RA_IS_DEF(is_iota, false)
  211. // BAD is excluded from beating to allow B = A(... i ...) to use B's len. FIXME find a way?
  212. template <class I, class N, class S>
  213. constexpr bool is_iota_def<Iota<0, I, N, S>> = (BAD != Iota<0, I, N, S>::nn);
  214. constexpr bool
  215. inside(is_iota auto const & i, dim_t l)
  216. {
  217. return (inside(i.i, l) && inside(i.i+(i.n-1)*i.s, l)) || (0==i.n /* don't bother */);
  218. }
  219. constexpr struct Len
  220. {
  221. consteval static rank_t rank() { return 0; }
  222. constexpr static dim_t len_s(int k) { std::abort(); }
  223. constexpr static dim_t len(int k) { std::abort(); }
  224. constexpr static dim_t step(int k) { std::abort(); }
  225. constexpr static void adv(rank_t k, dim_t d) { std::abort(); }
  226. constexpr static bool keep_step(dim_t st, int z, int j) { std::abort(); }
  227. constexpr dim_t operator*() const { std::abort(); }
  228. constexpr static int save() { std::abort(); }
  229. constexpr static void load(int) { std::abort(); }
  230. constexpr static void mov(dim_t d) { std::abort(); }
  231. } len;
  232. // protect exprs with Len from reduction.
  233. template <> constexpr bool is_special_def<Len> = true;
  234. RA_IS_DEF(has_len, false);
  235. // --------------
  236. // making Iterators
  237. // --------------
  238. // TODO arbitrary exprs? runtime cr? ra::len in cr?
  239. template <int cr>
  240. constexpr auto
  241. iter(SliceConcept auto && a) { return RA_FWD(a).template iter<cr>(); }
  242. constexpr void
  243. start(auto && t) { static_assert(always_false<decltype(t)>, "Cannot start() type."); }
  244. constexpr auto
  245. start(is_fov auto && t) { return ra::ptr(RA_FWD(t)); }
  246. template <class T>
  247. constexpr auto
  248. start(std::initializer_list<T> v) { return ra::ptr(v.begin(), v.size()); }
  249. constexpr auto
  250. start(is_scalar auto && t) { return ra::scalar(RA_FWD(t)); }
  251. // forward declare for Match; implemented in small.hh.
  252. constexpr auto
  253. start(is_builtin_array auto && t);
  254. // neither CellBig nor CellSmall will retain rvalues [ra4].
  255. constexpr auto
  256. start(SliceConcept auto && t) { return iter<0>(RA_FWD(t)); }
  257. RA_IS_DEF(is_ra_scalar, (std::same_as<A, Scalar<decltype(std::declval<A>().c)>>))
  258. // iterators need to be start()ed on each use [ra35].
  259. template <class T> requires (is_iterator<T> && !is_ra_scalar<T>)
  260. constexpr auto
  261. start(T & t) { return t; }
  262. // FIXME const Iterator would still be unusable after start().
  263. constexpr decltype(auto)
  264. start(is_iterator auto && t) { return RA_FWD(t); }
  265. // --------------------
  266. // prefix match
  267. // --------------------
  268. constexpr rank_t
  269. choose_rank(rank_t ra, rank_t rb) { return BAD==rb ? ra : BAD==ra ? rb : ANY==ra ? ra : ANY==rb ? rb : std::max(ra, rb); }
  270. // pick first if mismatch (see below). FIXME maybe return invalid.
  271. constexpr dim_t
  272. choose_len(dim_t sa, dim_t sb) { return BAD==sa ? sb : BAD==sb ? sa : ANY==sa ? sb : sa; }
  273. template <bool checkp, class T, class K=mp::iota<mp::len<T>>> struct Match;
  274. template <bool checkp, IteratorConcept ... P, int ... I>
  275. struct Match<checkp, std::tuple<P ...>, mp::int_list<I ...>>
  276. {
  277. std::tuple<P ...> t;
  278. constexpr static rank_t rs = [] { rank_t r=BAD; return ((r=choose_rank(r, ra::rank_s<P>())), ...); }();
  279. // 0: fail, 1: rt, 2: pass
  280. consteval static int
  281. check_s()
  282. {
  283. if constexpr (sizeof...(P)<2) {
  284. return 2;
  285. } else if constexpr (ANY==rs) {
  286. return 1; // FIXME can be tightened to 2 if all args are rank 0 save one
  287. } else {
  288. bool tbc = false;
  289. for (int k=0; k<rs; ++k) {
  290. dim_t ls = len_s(k);
  291. if (((k<ra::rank_s<P>() && ls!=choose_len(std::decay_t<P>::len_s(k), ls)) || ...)) {
  292. return 0;
  293. }
  294. int anyk = ((k<ra::rank_s<P>() && (ANY==std::decay_t<P>::len_s(k))) + ...);
  295. int fixk = ((k<ra::rank_s<P>() && (0<=std::decay_t<P>::len_s(k))) + ...);
  296. tbc = tbc || (anyk>0 && anyk+fixk>1);
  297. }
  298. return tbc ? 1 : 2;
  299. }
  300. }
  301. constexpr bool
  302. check() const
  303. {
  304. if constexpr (constexpr int c = check_s(); 2==c) {
  305. return true;
  306. } else if constexpr (0==c) {
  307. return false;
  308. } else if constexpr (1==c) {
  309. for (int k=0; k<rank(); ++k) {
  310. #pragma GCC diagnostic push // gcc 12.2 and 13.2 with RA_DO_CHECK=0 and -fno-sanitize=all.
  311. #pragma GCC diagnostic warning "-Warray-bounds"
  312. dim_t ls = len(k);
  313. #pragma GCC diagnostic pop
  314. if (((k<ra::rank(std::get<I>(t)) && ls!=choose_len(std::get<I>(t).len(k), ls)) || ...)) {
  315. return false;
  316. }
  317. }
  318. }
  319. return true;
  320. }
  321. constexpr
  322. Match(P ... p_): t(p_ ...) // [ra1]
  323. {
  324. // TODO Maybe on ply would make checkp, agree_xxx() unnecessary.
  325. if constexpr (checkp && !(has_len<P> || ...)) {
  326. constexpr int c = check_s();
  327. static_assert(0!=c, "Mismatched shapes."); // FIXME c++26
  328. if constexpr (1==c) {
  329. RA_CHECK(check(), "Mismatched shapes", format_array(ra::shape(p_), {.shape=noshape, .open=" [", .close="]"}) ..., ".");
  330. }
  331. }
  332. }
  333. consteval static rank_t
  334. rank() requires (ANY!=rs)
  335. {
  336. return rs;
  337. }
  338. constexpr rank_t
  339. rank() const requires (ANY==rs)
  340. {
  341. rank_t r = BAD;
  342. ((r = choose_rank(r, ra::rank(std::get<I>(t)))), ...);
  343. assert(ANY!=r); // not at runtime
  344. return r;
  345. }
  346. // first nonnegative size, if none first ANY, if none then BAD
  347. constexpr static dim_t
  348. len_s(int k)
  349. {
  350. auto f = [&k]<class A>(dim_t s) {
  351. constexpr rank_t ar = ra::rank_s<A>();
  352. return (ar<0 || k<ar) ? choose_len(s, A::len_s(k)) : s;
  353. };
  354. dim_t s = BAD; ((s>=0 ? s : s = f.template operator()<std::decay_t<P>>(s)), ...);
  355. return s;
  356. }
  357. constexpr static dim_t
  358. len(int k) requires (requires { P::len(k); } && ...)
  359. {
  360. return len_s(k);
  361. }
  362. constexpr dim_t
  363. len(int k) const requires (!(requires { P::len(k); } && ...))
  364. {
  365. auto f = [&k](dim_t s, auto const & a) {
  366. return k<ra::rank(a) ? choose_len(s, a.len(k)) : s;
  367. };
  368. dim_t s = BAD; ((s>=0 ? s : s = f(s, std::get<I>(t))), ...);
  369. assert(ANY!=s); // not at runtime
  370. return s;
  371. }
  372. // could preserve static, but ply doesn't use it atm.
  373. constexpr auto
  374. step(int i) const
  375. {
  376. return std::make_tuple(std::get<I>(t).step(i) ...);
  377. }
  378. constexpr void
  379. adv(rank_t k, dim_t d)
  380. {
  381. (std::get<I>(t).adv(k, d), ...);
  382. }
  383. constexpr bool
  384. keep_step(dim_t st, int z, int j) const requires (!(requires { P::keep_step(st, z, j); } && ...))
  385. {
  386. return (std::get<I>(t).keep_step(st, z, j) && ...);
  387. }
  388. constexpr static bool
  389. keep_step(dim_t st, int z, int j) requires (requires { P::keep_step(st, z, j); } && ...)
  390. {
  391. return (std::decay_t<P>::keep_step(st, z, j) && ...);
  392. }
  393. constexpr auto save() const { return std::make_tuple(std::get<I>(t).save() ...); }
  394. constexpr void load(auto const & pp) { ((std::get<I>(t).load(std::get<I>(pp))), ...); }
  395. constexpr void mov(auto const & s) { ((std::get<I>(t).mov(std::get<I>(s))), ...); }
  396. };
  397. // ---------------------------
  398. // reframe
  399. // ---------------------------
  400. template <dim_t N, class T> constexpr T samestep = N;
  401. template <dim_t N, class ... T> constexpr std::tuple<T ...> samestep<N, std::tuple<T ...>> = { samestep<N, T> ... };
  402. // Transpose variant for IteratorConcepts. As in transpose(), one names the destination axis for
  403. // each original axis. However, axes may not be repeated. Used in the rank conjunction below.
  404. // Dest is a list of destination axes [l0 l1 ... li ... l(rank(A)-1)].
  405. // The dimensions of the reframed A are numbered as [0 ... k ... max(l)-1].
  406. // If li = k for some i, then axis k of the reframed A moves on axis i of the original iterator A.
  407. // If not, then axis k of the reframed A is 'dead' and doesn't move the iterator.
  408. // TODO invalid for ANY, since Dest is compile time. [ra7]
  409. template <class Dest, IteratorConcept A>
  410. struct Reframe
  411. {
  412. A a;
  413. consteval static rank_t
  414. rank()
  415. {
  416. return 1 + std::apply([](auto ... i) { int r=-1; ((r=std::max(r, int(i))), ...); return r; }, Dest {});
  417. }
  418. constexpr static int orig(int k)
  419. {
  420. return mp::int_list_index<Dest>(k);
  421. }
  422. constexpr static dim_t len_s(int k)
  423. {
  424. int l=orig(k);
  425. return l>=0 ? std::decay_t<A>::len_s(l) : BAD;
  426. }
  427. constexpr static dim_t
  428. len(int k) requires (requires { std::decay_t<A>::len(k); })
  429. {
  430. return len_s(k);
  431. }
  432. constexpr dim_t
  433. len(int k) const requires (!(requires { std::decay_t<A>::len(k); }))
  434. {
  435. int l=orig(k);
  436. return l>=0 ? a.len(l) : BAD;
  437. }
  438. constexpr auto
  439. step(int k) const
  440. {
  441. int l=orig(k);
  442. return l>=0 ? a.step(l) : samestep<0, decltype(a.step(l))>;
  443. }
  444. constexpr void
  445. adv(rank_t k, dim_t d)
  446. {
  447. int l=orig(k);
  448. if (l>=0) { a.adv(l, d); }
  449. }
  450. constexpr static bool
  451. keep_step(dim_t st, int z, int j) requires (requires { std::decay_t<A>::keep_step(st, z, j); })
  452. {
  453. int wz=orig(z), wj=orig(j);
  454. return wz>=0 && wj>=0 && std::decay_t<A>::keep_step(st, wz, wj);
  455. }
  456. constexpr bool
  457. keep_step(dim_t st, int z, int j) const requires (!(requires { std::decay_t<A>::keep_step(st, z, j); }))
  458. {
  459. int wz=orig(z), wj=orig(j);
  460. return wz>=0 && wj>=0 && a.keep_step(st, wz, wj);
  461. }
  462. constexpr decltype(auto)
  463. at(auto const & i) const
  464. {
  465. return a.at(std::apply([&i](auto ... t) { return std::array<dim_t, sizeof...(t)> { i[t] ... }; }, Dest {}));
  466. }
  467. constexpr decltype(auto) operator*() const { return *a; }
  468. constexpr auto save() const { return a.save(); }
  469. constexpr void load(auto const & p) { a.load(p); }
  470. // FIXME only if Dest preserves axis order (?) which is how wrank works
  471. constexpr void mov(auto const & s) { a.mov(s); }
  472. };
  473. // Optimize no-op case. TODO If A is CellBig, etc. beat Dest on it, same for eventual transpose_expr<>.
  474. template <class Dest, class A>
  475. constexpr decltype(auto)
  476. reframe(A && a)
  477. {
  478. if constexpr (std::is_same_v<Dest, mp::iota<Reframe<Dest, A>::rank()>>) {
  479. return RA_FWD(a);
  480. } else {
  481. return Reframe<Dest, A> { RA_FWD(a) };
  482. }
  483. }
  484. // ---------------------------
  485. // verbs and rank conjunction
  486. // ---------------------------
  487. template <class cranks_, class Op_>
  488. struct Verb
  489. {
  490. using cranks = cranks_;
  491. using Op = Op_;
  492. Op op;
  493. };
  494. RA_IS_DEF(is_verb, (std::is_same_v<A, Verb<typename A::cranks, typename A::Op>>))
  495. template <class cranks, class Op>
  496. constexpr auto
  497. wrank(cranks cranks_, Op && op) { return Verb<cranks, Op> { RA_FWD(op) }; }
  498. template <rank_t ... crank, class Op>
  499. constexpr auto
  500. wrank(Op && op) { return Verb<mp::int_list<crank ...>, Op> { RA_FWD(op) }; }
  501. template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
  502. struct Framematch_def;
  503. template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
  504. using Framematch = Framematch_def<std::decay_t<V>, T, R, skip>;
  505. template <class A, class B>
  506. struct max_i
  507. {
  508. constexpr static int value = (A::value == choose_rank(A::value, B::value)) ? 0 : 1;
  509. };
  510. // Get a list (per argument) of lists of live axes. The last frame match is handled by standard prefix matching.
  511. template <class ... crank, class W, class ... Ti, class ... Ri, rank_t skip>
  512. struct Framematch_def<Verb<std::tuple<crank ...>, W>, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
  513. {
  514. static_assert(sizeof...(Ti)==sizeof...(crank) && sizeof...(Ti)==sizeof...(Ri), "Bad arguments.");
  515. // live = number of live axes on this frame, for each argument. // TODO crank negative, inf.
  516. using live = mp::int_list<(rank_s<Ti>() - mp::len<Ri> - crank::value) ...>;
  517. using frameaxes = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri> - crank::value), skip>> ...>;
  518. using FM = Framematch<W, std::tuple<Ti ...>, frameaxes, skip + mp::ref<live, mp::indexof<max_i, live>>::value>;
  519. using R = typename FM::R;
  520. template <class VV> constexpr static decltype(auto) op(VV && v) { return FM::op(RA_FWD(v).op); } // cf [ra31]
  521. };
  522. // Terminal case where V doesn't have rank (is a raw op()).
  523. template <class V, class ... Ti, class ... Ri, rank_t skip>
  524. struct Framematch_def<V, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
  525. {
  526. static_assert(sizeof...(Ti)==sizeof...(Ri), "Bad arguments.");
  527. // TODO -crank::value when the actual verb rank is used (eg to use CellBig<... that_rank> instead of just begin()).
  528. using R = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri>), skip>> ...>;
  529. template <class VV> constexpr static decltype(auto) op(VV && v) { return RA_FWD(v); }
  530. };
  531. // ---------------
  532. // explicit agreement checks
  533. // ---------------
  534. constexpr bool
  535. agree(auto && ... p) { return agree_(ra::start(RA_FWD(p)) ...); }
  536. // 0: fail, 1: rt, 2: pass
  537. constexpr int
  538. agree_s(auto && ... p) { return agree_s_(ra::start(RA_FWD(p)) ...); }
  539. template <class Op, class ... P> requires (is_verb<Op>)
  540. constexpr bool
  541. agree_op(Op && op, P && ... p) { return agree_verb(mp::iota<sizeof...(P)> {}, RA_FWD(op), RA_FWD(p) ...); }
  542. template <class Op, class ... P> requires (!is_verb<Op>)
  543. constexpr bool
  544. agree_op(Op && op, P && ... p) { return agree(RA_FWD(p) ...); }
  545. template <class ... P>
  546. constexpr bool
  547. agree_(P && ... p) { return (Match<false, std::tuple<P ...>> { RA_FWD(p) ... }).check(); }
  548. template <class ... P>
  549. constexpr int
  550. agree_s_(P && ... p) { return Match<false, std::tuple<P ...>>::check_s(); }
  551. template <class V, class ... T, int ... i>
  552. constexpr bool
  553. agree_verb(mp::int_list<i ...>, V && v, T && ... t)
  554. {
  555. using FM = Framematch<V, std::tuple<T ...>>;
  556. return agree_op(FM::op(RA_FWD(v)), reframe<mp::ref<typename FM::R, i>>(ra::start(RA_FWD(t))) ...);
  557. }
  558. // ---------------------------
  559. // operator expression
  560. // ---------------------------
  561. template <class Op, class T, class K=mp::iota<mp::len<T>>> struct Expr;
  562. template <class Op, IteratorConcept ... P, int ... I>
  563. struct Expr<Op, std::tuple<P ...>, mp::int_list<I ...>>: public Match<true, std::tuple<P ...>>
  564. {
  565. using Match_ = Match<true, std::tuple<P ...>>;
  566. using Match_::t, Match_::rs, Match_::rank;
  567. Op op;
  568. constexpr Expr(Op op_, P ... p_): Match_(p_ ...), op(op_) {} // [ra1]
  569. RA_ASSIGNOPS_SELF(Expr)
  570. RA_ASSIGNOPS_DEFAULT_SET
  571. constexpr decltype(auto) at(auto const & j) const { return std::invoke(op, std::get<I>(t).at(j) ...); }
  572. constexpr decltype(auto) operator*() const { return std::invoke(op, *std::get<I>(t) ...); }
  573. // needed for rs==ANY, which don't decay to scalar when used as operator arguments.
  574. constexpr
  575. operator decltype(std::invoke(op, *std::get<I>(t) ...)) () const
  576. {
  577. if constexpr (1!=size_s<Expr>()) {
  578. RA_CHECK(1==size(*this), "Bad conversion to scalar from shape [", ra::noshape, ra::shape(*this), "].");
  579. }
  580. return *(*this);
  581. }
  582. };
  583. template <class Op, IteratorConcept ... P>
  584. constexpr bool is_special_def<Expr<Op, std::tuple<P ...>>> = (is_special<P> || ...);
  585. template <class V, class ... T, int ... i>
  586. constexpr auto
  587. expr_verb(mp::int_list<i ...>, V && v, T && ... t)
  588. {
  589. using FM = Framematch<V, std::tuple<T ...>>;
  590. return expr(FM::op(RA_FWD(v)), reframe<mp::ref<typename FM::R, i>>(RA_FWD(t)) ...);
  591. }
  592. template <class Op, class ... P>
  593. constexpr auto
  594. expr(Op && op, P && ... p)
  595. {
  596. if constexpr (is_verb<Op>) {
  597. return expr_verb(mp::iota<sizeof...(P)> {}, RA_FWD(op), RA_FWD(p) ...);
  598. } else {
  599. return Expr<Op, std::tuple<P ...>> { RA_FWD(op), RA_FWD(p) ... };
  600. }
  601. }
  602. constexpr auto
  603. map(auto && op, auto && ... a) { return expr(RA_FWD(op), start(RA_FWD(a)) ...); }
  604. // ---------------------------
  605. // pick expression
  606. // ---------------------------
  607. template <class T, class J> struct pick_at_type;
  608. template <class ... P, class J> struct pick_at_type<std::tuple<P ...>, J>
  609. {
  610. using type = std::common_reference_t<decltype(std::declval<P>().at(std::declval<J>())) ...>;
  611. };
  612. template <std::size_t I, class T, class J>
  613. constexpr pick_at_type<mp::drop1<std::decay_t<T>>, J>::type
  614. pick_at(std::size_t p0, T && t, J const & j)
  615. {
  616. constexpr std::size_t N = mp::len<std::decay_t<T>> - 1;
  617. if constexpr (I < N) {
  618. return (p0==I) ? std::get<I+1>(t).at(j) : pick_at<I+1>(p0, t, j);
  619. } else {
  620. RA_CHECK(p0 < N, "Bad pick ", p0, " with ", N, " arguments."); std::abort();
  621. }
  622. }
  623. template <class T> struct pick_star_type;
  624. template <class ... P> struct pick_star_type<std::tuple<P ...>>
  625. {
  626. using type = std::common_reference_t<decltype(*std::declval<P>()) ...>;
  627. };
  628. template <std::size_t I, class T>
  629. constexpr pick_star_type<mp::drop1<std::decay_t<T>>>::type
  630. pick_star(std::size_t p0, T && t)
  631. {
  632. constexpr std::size_t N = mp::len<std::decay_t<T>> - 1;
  633. if constexpr (I < N) {
  634. return (p0==I) ? *(std::get<I+1>(t)) : pick_star<I+1>(p0, t);
  635. } else {
  636. RA_CHECK(p0 < N, "Bad pick ", p0, " with ", N, " arguments."); std::abort();
  637. }
  638. }
  639. template <class T, class K=mp::iota<mp::len<T>>> struct Pick;
  640. template <IteratorConcept ... P, int ... I>
  641. struct Pick<std::tuple<P ...>, mp::int_list<I ...>>: public Match<true, std::tuple<P ...>>
  642. {
  643. using Match_ = Match<true, std::tuple<P ...>>;
  644. using Match_::t, Match_::rs, Match_::rank;
  645. static_assert(sizeof...(P)>1);
  646. constexpr Pick(P ... p_): Match_(p_ ...) {} // [ra1]
  647. RA_ASSIGNOPS_SELF(Pick)
  648. RA_ASSIGNOPS_DEFAULT_SET
  649. constexpr decltype(auto) at(auto const & j) const { return pick_at<0>(std::get<0>(t).at(j), t, j); }
  650. constexpr decltype(auto) operator*() const { return pick_star<0>(*std::get<0>(t), t); }
  651. // needed for rs==ANY, which don't decay to scalar when used as operator arguments.
  652. constexpr
  653. operator decltype(pick_star<0>(*std::get<0>(t), t)) () const
  654. {
  655. if constexpr (1!=size_s<Pick>()) {
  656. RA_CHECK(1==size(*this), "Bad conversion to scalar from shape [", ra::noshape, ra::shape(*this), "].");
  657. }
  658. return *(*this);
  659. }
  660. };
  661. template <IteratorConcept ... P>
  662. constexpr bool is_special_def<Pick<std::tuple<P ...>>> = (is_special<P> || ...);
  663. template <class ... P>
  664. Pick(P && ... p) -> Pick<std::tuple<P ...>>;
  665. constexpr auto
  666. pick(auto && ... p) { return Pick { start(RA_FWD(p)) ... }; }
  667. } // namespace ra