atom.hh 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Terminal nodes for expression templates.
  3. // (c) Daniel Llorens - 2011-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  #include <cassert>
  #include <iterator>
  #include <utility>
  #include "bootstrap.hh"
  12. // --------------------
  13. // error handling. See examples/throw.cc for how to customize.
  14. // --------------------
  15. #include <iostream> // might not be needed with a different RA_ASSERT.
  #ifndef RA_ASSERT
  // Default assertion handler; predefine RA_ASSERT before including ra-ra to customize it
  // (see examples/throw.cc). RA_ASSERT(cond, args...):
  // - during constant evaluation, cond goes through assert(). Note that under NDEBUG assert()
  //   expands to nothing, so the check is skipped in that case.
  // - at run time, if cond is false, the call site's source location and args (formatted
  //   with ra::format) are printed to std::cerr and std::abort() is called.
  // The body is wrapped in do { } while (false) so that `RA_ASSERT(...);` is a single
  // statement, e.g. it can be the body of an unbraced if that is followed by else.
  // __VA_OPT__ (C++20) drops the separating comma when no message args are given; this
  // replaces the non-standard GNU `, ##__VA_ARGS__` extension.
  #define RA_ASSERT(cond, ...) \
  do { \
      if (std::is_constant_evaluated()) { \
          assert(cond /* FIXME show args */); \
      } else { \
          if (!(cond)) [[unlikely]] { \
              std::cerr << ra::format("**** ra (", std::source_location::current(), "): " __VA_OPT__(,) __VA_ARGS__, " ****") << std::endl; \
              std::abort(); \
          } \
      } \
  } while (false)
  #endif
  // RA_CHECK(cond, args...) forwards to RA_ASSERT, unless checks were disabled by defining
  // RA_DO_CHECK as 0. In that case it expands to nothing and its arguments are not evaluated.
  #if !defined(RA_DO_CHECK) || RA_DO_CHECK!=0
  #define RA_CHECK( ... ) RA_ASSERT( __VA_ARGS__ )
  #else
  #define RA_CHECK( ... )
  #endif
  // Marker defined after the checking macros are set up.
  // NOTE(review): presumably tested by other headers; confirm its users.
  #define RA_AFTER_CHECK Yes
  35. namespace ra {
  36. constexpr bool inside(dim_t i, dim_t b) { return 0<=i && i<b; }
  37. // --------------------
  38. // terminal types
  39. // --------------------
  // Rank-0 IteratorConcept wrapping a single value c. Can be used on foreign objects, or as an
  // alternative to the rank conjunction.
  // We still want f(scalar(C)) to be f(C) and not map(f, C); this is controlled by tomap/toreduce.
  // C may be a reference type: scalar() below instantiates Scalar<C> with C as deduced from its
  // forwarding reference, so for an lvalue argument c refers to the caller's object.
  template <class C>
  struct Scalar
  {
      C c;
      RA_DEF_ASSIGNOPS_DEFAULT_SET
      consteval static rank_t rank() { return 0; }
      // A rank-0 object has no axes, so asking it for a length is an error.
      constexpr static dim_t len_s(int k) { std::abort(); }
      constexpr static dim_t len(int k) { std::abort(); }
      // The value is the same everywhere: the step is 0, moving is a no-op, and any step
      // can be kept.
      constexpr static dim_t step(int k) { return 0; }
      constexpr static void adv(rank_t k, dim_t d) {}
      constexpr static bool keep_step(dim_t st, int z, int j) { return true; }
      // Ignores the index. Returns c with its declared type C (a copy when C is not a reference).
      constexpr decltype(auto) at(auto && j) const { return c; }
      // When C is an lvalue reference, all overloads yield a mutable reference (const on a
      // reference type is ignored). Otherwise c is only exposed as C const &.
      constexpr C & operator*() requires (std::is_lvalue_reference_v<C>) { return c; } // [ra37]
      constexpr C const & operator*() requires (!std::is_lvalue_reference_v<C>) { return c; }
      constexpr C const & operator*() const { return c; } // [ra39]
      // There is no position to save, restore, or move.
      constexpr static int save() { return 0; }
      constexpr static void load(int) {}
      constexpr static void mov(dim_t d) {}
  };
  // Wrap c as a rank-0 Scalar. Lvalues are held by reference, rvalues by value.
  template <class C> constexpr auto
  scalar(C && c) { return Scalar<C> { RA_FWD(c) }; }
  63. template <class N> constexpr int
  64. maybe_any = []{
  65. if constexpr (is_constant<N>) {
  66. return N::value;
  67. } else {
  68. static_assert(std::is_integral_v<N> || !std::is_same_v<N, bool>);
  69. return ANY;
  70. }
  71. }();
  72. // IteratorConcept for foreign rank 1 objects.
  73. template <std::bidirectional_iterator I, class N>
  74. struct Ptr
  75. {
  76. static_assert(is_constant<N> || 0==rank_s<N>());
  77. constexpr static dim_t nn = maybe_any<N>;
  78. static_assert(nn==ANY || nn>=0 || nn==BAD);
  79. I i;
  80. [[no_unique_address]] N const n = {};
  81. constexpr Ptr(I i, N n): i(i), n(n) {}
  82. RA_DEF_ASSIGNOPS_SELF(Ptr)
  83. RA_DEF_ASSIGNOPS_DEFAULT_SET
  84. consteval static rank_t rank() { return 1; }
  85. constexpr static dim_t len_s(int k) { return nn; } // len(k==0) or step(k>=0)
  86. constexpr static dim_t len(int k) requires (nn!=ANY) { return len_s(k); }
  87. constexpr dim_t len(int k) const requires (nn==ANY) { return n; }
  88. constexpr static dim_t step(int k) { return k==0 ? 1 : 0; }
  89. constexpr void adv(rank_t k, dim_t d) { i += step(k) * d; }
  90. constexpr static bool keep_step(dim_t st, int z, int j) { return st*step(z)==step(j); }
  91. constexpr decltype(auto) at(auto && j) const requires (std::random_access_iterator<I>)
  92. {
  93. RA_CHECK(BAD==nn || inside(j[0], n), "Out of range for len[0]=", n, ": ", j[0], ".");
  94. return i[j[0]];
  95. }
  96. constexpr decltype(auto) operator*() const { return *i; }
  97. constexpr auto save() const { return i; }
  98. constexpr void load(I ii) { i = ii; }
  99. constexpr void mov(dim_t d) { i += d; }
  100. };
  101. template <class X> using iota_arg = std::conditional_t<is_constant<std::decay_t<X>> || is_scalar<std::decay_t<X>>, std::decay_t<X>, X>;
  102. template <class I, class N=dim_c<BAD>>
  103. constexpr auto
  104. ptr(I && i, N && n = N {})
  105. {
  106. // not decay_t bc of builtin arrays.
  107. if constexpr (std::ranges::bidirectional_range<std::remove_reference_t<I>>) {
  108. static_assert(std::is_same_v<dim_c<BAD>, N>, "Object has own length.");
  109. constexpr dim_t s = size_s<I>();
  110. if constexpr (ANY==s) {
  111. return ptr(std::begin(RA_FWD(i)), std::ssize(i));
  112. } else {
  113. return ptr(std::begin(RA_FWD(i)), ic<s>);
  114. }
  115. } else if constexpr (std::bidirectional_iterator<std::decay_t<I>>) {
  116. if constexpr (std::is_integral_v<N>) {
  117. RA_CHECK(n>=0, "Bad ptr length ", n, ".");
  118. }
  119. return Ptr<std::decay_t<I>, iota_arg<N>> { i, RA_FWD(n) };
  120. } else {
  121. static_assert(always_false<I>, "Bad type for ptr().");
  122. }
  123. }
  124. // Sequence and IteratorConcept for same. Iota isn't really a terminal, but its exprs must all have rank 0.
  125. // FIXME w is a custom Reframe mechanism inherited from TensorIndex. Generalize/unify
  126. // FIXME Sequence should be its own type, we can't represent a ct origin bc IteratorConcept interface takes up i.
  127. template <int w, class N_, class O, class S_>
  128. struct Iota
  129. {
  130. using N = std::decay_t<N_>;
  131. using S = std::decay_t<S_>;
  132. static_assert(w>=0);
  133. static_assert(is_constant<S> || 0==rank_s<S>());
  134. static_assert(is_constant<N> || 0==rank_s<N>());
  135. constexpr static dim_t nn = maybe_any<N>;
  136. static_assert(nn==ANY || nn>=0 || nn==BAD);
  137. [[no_unique_address]] N const n = {};
  138. O i = {};
  139. [[no_unique_address]] S const s = {};
  140. constexpr static S gets() requires (is_constant<S>) { return S {}; }
  141. constexpr O gets() const requires (!is_constant<S>) { return s; }
  142. consteval static rank_t rank() { return w+1; }
  143. constexpr static dim_t len_s(int k) { return k==w ? nn : BAD; } // len(0<=k<=w) or step(0<=k)
  144. constexpr static dim_t len(int k) requires (is_constant<N>) { return len_s(k); }
  145. constexpr dim_t len(int k) const requires (!is_constant<N>) { return k==w ? n : BAD; }
  146. constexpr static dim_t step(rank_t k) { return k==w ? 1 : 0; }
  147. constexpr void adv(rank_t k, dim_t d) { i += O(step(k) * d) * O(s); }
  148. constexpr static bool keep_step(dim_t st, int z, int j) { return st*step(z)==step(j); }
  149. constexpr auto at(auto && j) const
  150. {
  151. RA_CHECK(BAD==nn || inside(j[0], n), "Out of range for len[0]=", n, ": ", j[0], ".");
  152. return i + O(j[w])*O(s);
  153. }
  154. constexpr O operator*() const { return i; }
  155. constexpr auto save() const { return i; }
  156. constexpr void load(O ii) { i = ii; }
  157. constexpr void mov(dim_t d) { i += O(d)*O(s); }
  158. };
  159. template <int w=0, class O=dim_t, class N=dim_c<BAD>, class S=dim_c<1>>
  160. constexpr auto
  161. iota(N && n = N {}, O && org = 0,
  162. S && s = [] {
  163. if constexpr (std::is_integral_v<S>) {
  164. return S(1);
  165. } else if constexpr (is_constant<S>) {
  166. static_assert(1==S::value);
  167. return S {};
  168. } else {
  169. static_assert(always_false<S>, "Bad step type for Iota.");
  170. }
  171. }())
  172. {
  173. if constexpr (std::is_integral_v<N>) {
  174. RA_CHECK(n>=0, "Bad iota length ", n, ".");
  175. }
  176. return Iota<w, iota_arg<N>, iota_arg<O>, iota_arg<S>> { RA_FWD(n), RA_FWD(org), RA_FWD(s) };
  177. }
  178. #define DEF_TENSORINDEX(w) constexpr auto JOIN(_, w) = iota<w>();
  179. FOR_EACH(DEF_TENSORINDEX, 0, 1, 2, 3, 4);
  180. #undef DEF_TENSORINDEX
  // is_iota<T>: false by default.
  // NOTE(review): RA_IS_DEF is defined elsewhere (presumably bootstrap.hh). It is assumed to
  // declare the variable template is_iota_def specialized below, plus is_iota on top of it.
  RA_IS_DEF(is_iota, false)
  // Only Iotas with w==0 qualify, and only when their static length nn is not BAD.
  // BAD is excluded from beating to allow B = A(... i ...) to use B's len. FIXME find a way?
  template <class N, class O, class S>
  constexpr bool is_iota_def<Iota<0, N, O, S>> = (BAD != Iota<0, N, O, S>::nn);
  185. template <class I>
  186. constexpr bool
  187. inside(I const & i, dim_t l) requires (is_iota<I>)
  188. {
  189. return (inside(i.i, l) && inside(i.i+(i.n-1)*i.s, l)) || (0==i.n /* don't bother */);
  190. }
  191. // Never ply(), solely to be rewritten.
  192. constexpr struct Len
  193. {
  194. consteval static rank_t rank() { return 0; }
  195. constexpr static dim_t len_s(int k) { std::abort(); }
  196. constexpr static dim_t len(int k) { std::abort(); }
  197. constexpr static dim_t step(int k) { std::abort(); }
  198. constexpr static void adv(rank_t k, dim_t d) { std::abort(); }
  199. constexpr static bool keep_step(dim_t st, int z, int j) { std::abort(); }
  200. constexpr static int save() { std::abort(); }
  201. constexpr static void load(int) { std::abort(); }
  202. constexpr dim_t operator*() const { std::abort(); }
  203. constexpr static void mov(dim_t d) { std::abort(); }
  204. } len;
  // protect exprs with Len from reduction, by marking Len as special.
  // (is_special_def is declared elsewhere, presumably in bootstrap.hh.)
  template <> constexpr bool is_special_def<Len> = true;
  // has_len<T>: false by default.
  // NOTE(review): presumably specialized elsewhere for exprs that contain Len.
  RA_IS_DEF(has_len, false);
  208. // --------------
  209. // coerce potential Iterators
  210. // --------------
  211. template <class T>
  212. constexpr void
  213. start(T && t) { static_assert(always_false<T>, "Type cannot be start()ed."); }
  214. template <class T> requires (is_fov<T>)
  215. constexpr auto
  216. start(T && t) { return ra::ptr(RA_FWD(t)); }
  217. template <class T>
  218. constexpr auto
  219. start(std::initializer_list<T> v) { return ra::ptr(v.begin(), v.size()); }
  220. template <class T> requires (is_scalar<T>)
  221. constexpr auto
  222. start(T && t) { return ra::scalar(RA_FWD(t)); }
  // forward declare for Match; implemented in small.hh.
  // start() for builtin arrays. It is declared here so it can be named before its definition.
  // The definition in small.hh must repeat this declaration exactly (same constraint), or
  // it would declare a different overload.
  template <class T> requires (is_builtin_array<T>)
  constexpr auto
  start(T && t);
  227. // TODO fovs? arbitrary exprs?
  228. template <int cr, class A> constexpr auto
  229. iter(A && a) { return RA_FWD(a).template iter<cr>(); }
  230. // neither CellBig nor CellSmall will retain rvalues [ra4].
  231. template <SliceConcept T>
  232. constexpr auto
  233. start(T && t) { return iter<0>(RA_FWD(t)); }
  // is_ra_scalar<T>: T is some Scalar<...>, namely Scalar of the type of its own member c.
  // NOTE(review): A is the parameter name that RA_IS_DEF's predicate refers to; confirm in bootstrap.hh.
  RA_IS_DEF(is_ra_scalar, (std::same_as<A, Scalar<decltype(std::declval<A>().c)>>))
  // A Scalar is already an IteratorConcept: it is passed through as given. decltype(auto)
  // returns the forwarded reference rather than a copy.
  template <class T> requires (is_ra_scalar<T>)
  constexpr decltype(auto)
  start(T && t) { return RA_FWD(t); }
  // iterators need to be restarted on each use (eg ra::cross()) [ra35].
  // Other iterators are returned by value (auto), so each start() gets its own object.
  template <class T> requires (is_iterator<T> && !is_ra_scalar<T>)
  constexpr auto
  start(T && t) { return RA_FWD(t); }
  242. } // namespace ra