ply.hh 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Expression traversal.
  3. // (c) Daniel Llorens - 2013-2024
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. // TODO Make traversal order a parameter, some operations (e.g. output, ravel) require specific orders.
  9. // TODO Better traversal. Tiling, etc. (see eval.cc in Blitz++). Unit step case?
  10. // TODO std::execution::xxx-policy TODO Validate output argument strides.
  11. #pragma once
  12. #include "expr.hh"
  13. namespace ra {
  14. template <class A>
  15. constexpr decltype(auto)
  16. VALUE(A && a)
  17. {
  18. if constexpr (is_scalar<A>) { return RA_FWD(a); } // [ra8]
  19. else if constexpr (is_iterator<A>) { return *a; } // no need to start()
  20. else { return *(ra::start(RA_FWD(a))); }
  21. }
  22. template <class A> using value_t = std::remove_volatile_t<std::remove_reference_t<decltype(VALUE(std::declval<A>()))>>;
  23. template <class A> using ncvalue_t = std::remove_const_t<value_t<A>>;
  24. // ---------------------
  25. // replace Len in expr tree.
  26. // ---------------------
  27. template <>
  28. struct WLen<Len>
  29. {
  30. constexpr static decltype(auto)
  31. f(auto ln, auto && e)
  32. {
  33. return Scalar<decltype(ln)>(ln);
  34. }
  35. };
  36. template <class Op, IteratorConcept ... P, int ... I> requires (has_len<P> || ...)
  37. struct WLen<Expr<Op, std::tuple<P ...>, mp::int_list<I ...>>>
  38. {
  39. constexpr static decltype(auto)
  40. f(auto ln, auto && e)
  41. {
  42. return expr(RA_FWD(e).op, wlen(ln, std::get<I>(RA_FWD(e).t)) ...);
  43. }
  44. };
  45. template <IteratorConcept ... P, int ... I> requires (has_len<P> || ...)
  46. struct WLen<Pick<std::tuple<P ...>, mp::int_list<I ...>>>
  47. {
  48. constexpr static decltype(auto)
  49. f(auto ln, auto && e)
  50. {
  51. return pick(wlen(ln, std::get<I>(RA_FWD(e).t)) ...);
  52. }
  53. };
  54. // final iota/ptr types must be either is_constant or is_scalar.
  55. template <int w, class I, class N, class S> requires (has_len<I> || has_len<N> || has_len<S>)
  56. struct WLen<Iota<w, I, N, S>>
  57. {
  58. constexpr static decltype(auto)
  59. f(auto ln, auto && e)
  60. {
  61. return iota<w>(VALUE(wlen(ln, RA_FWD(e).n)), VALUE(wlen(ln, RA_FWD(e).i)), VALUE(wlen(ln, RA_FWD(e).s)));
  62. }
  63. };
  64. template <class I, class N, class S> requires (has_len<N> || has_len<S>)
  65. struct WLen<Ptr<I, N, S>>
  66. {
  67. constexpr static decltype(auto)
  68. f(auto ln, auto && e)
  69. {
  70. return ptr(RA_FWD(e).i, VALUE(wlen(ln, RA_FWD(e).n)), VALUE(wlen(ln, RA_FWD(e).s)));
  71. }
  72. };
  73. // --------------
  74. // ply, run time order/rank.
  75. // --------------
  76. struct Nop {};
  77. // step() must give 0 for k>=their own rank, to allow frame matching.
  78. template <IteratorConcept A, class Early = Nop>
  79. constexpr auto
  80. ply_ravel(A && a, Early && early = Nop {})
  81. {
  82. rank_t rank = ra::rank(a);
  83. // must avoid 0-length vlas [ra40].
  84. if (0>=rank) {
  85. if (0>rank) [[unlikely]] { std::abort(); }
  86. if constexpr (requires {early.def;}) {
  87. return (*a).value_or(early.def);
  88. } else {
  89. *a;
  90. return;
  91. }
  92. }
  93. // inside first. FIXME better heuristic - but first need a way to force row-major
  94. rank_t order[rank];
  95. for (rank_t i=0; i<rank; ++i) {
  96. order[i] = rank-1-i;
  97. }
  98. dim_t sha[rank], ind[rank] = {};
  99. // find outermost compact dim.
  100. rank_t * ocd = order;
  101. dim_t ss = a.len(*ocd);
  102. #pragma GCC diagnostic push // gcc 12.2 and 13.2 with RA_DO_CHECK=0 and -fno-sanitize=all
  103. #pragma GCC diagnostic warning "-Warray-bounds"
  104. for (--rank, ++ocd; rank>0 && a.keep(ss, order[0], *ocd); --rank, ++ocd) {
  105. ss *= a.len(*ocd);
  106. }
  107. for (int k=0; k<rank; ++k) {
  108. // ss takes care of the raveled dimensions ss.
  109. if (0>=(sha[k]=a.len(ocd[k]))) {
  110. if (0>sha[k]) [[unlikely]] { std::abort(); }
  111. if constexpr (requires {early.def;}) {
  112. return early.def;
  113. } else {
  114. return;
  115. }
  116. }
  117. }
  118. auto ss0 = a.step(order[0]);
  119. for (;;) {
  120. auto place = a.save();
  121. for (dim_t s=ss; --s>=0; a.mov(ss0)) {
  122. if constexpr (requires {early.def;}) {
  123. if (auto stop = *a) {
  124. return stop.value();
  125. }
  126. } else {
  127. *a;
  128. }
  129. }
  130. a.load(place); // FIXME wasted if k=0. Cf test/iota.cc
  131. for (int k=0; ; ++k) {
  132. if (k>=rank) {
  133. if constexpr (requires {early.def;}) {
  134. return early.def;
  135. } else {
  136. return;
  137. }
  138. } else if (++ind[k]<sha[k]) {
  139. a.adv(ocd[k], 1);
  140. break;
  141. } else {
  142. ind[k] = 0;
  143. a.adv(ocd[k], 1-sha[k]);
  144. }
  145. }
  146. }
  147. #pragma GCC diagnostic pop
  148. }
  149. // -------------------------
  150. // ply, compile time order/rank.
  151. // -------------------------
  152. template <auto order, int k, int urank, class A, class S, class Early>
  153. constexpr auto
  154. subply(A & a, dim_t s, S const & ss0, Early & early)
  155. {
  156. if constexpr (k < urank) {
  157. auto place = a.save();
  158. for (; --s>=0; a.mov(ss0)) {
  159. if constexpr (requires {early.def;}) {
  160. if (auto stop = *a) {
  161. return stop;
  162. }
  163. } else {
  164. *a;
  165. }
  166. }
  167. a.load(place); // FIXME wasted if k was 0 at the top
  168. } else {
  169. dim_t size = a.len(order[k]); // TODO precompute above
  170. for (dim_t i=0; i<size; ++i) {
  171. if constexpr (requires {early.def;}) {
  172. if (auto stop = subply<order, k-1, urank>(a, s, ss0, early)) {
  173. return stop;
  174. }
  175. } else {
  176. subply<order, k-1, urank>(a, s, ss0, early);
  177. }
  178. a.adv(order[k], 1);
  179. }
  180. a.adv(order[k], -size);
  181. }
  182. if constexpr (requires {early.def;}) {
  183. return static_cast<decltype(*a)>(std::nullopt);
  184. } else {
  185. return;
  186. }
  187. }
  188. template <IteratorConcept A, class Early = Nop>
  189. constexpr decltype(auto)
  190. ply_fixed(A && a, Early && early = Nop {})
  191. {
  192. constexpr rank_t rank = rank_s(a);
  193. static_assert(0<=rank, "ply_fixed needs static rank");
  194. // inside first. FIXME better heuristic - but first need a way to force row-major
  195. constexpr auto order = mp::tuple2array<int, mp::reverse<mp::iota<rank>>>();
  196. if constexpr (0==rank) {
  197. if constexpr (requires {early.def;}) {
  198. return (*a).value_or(early.def);
  199. } else {
  200. *a;
  201. return;
  202. }
  203. } else {
  204. #pragma GCC diagnostic push
  205. #pragma GCC diagnostic warning "-Warray-bounds"
  206. auto ss0 = a.step(order[0]); // gcc 14.1 with RA_DO_CHECK=0 and sanitizer on
  207. #pragma GCC diagnostic pop
  208. if constexpr (requires {early.def;}) {
  209. return (subply<order, rank-1, 1>(a, a.len(order[0]), ss0, early)).value_or(early.def);
  210. } else {
  211. subply<order, rank-1, 1>(a, a.len(order[0]), ss0, early);
  212. }
  213. }
  214. }
  215. // ---------------------------
  216. // default ply
  217. // ---------------------------
  218. template <IteratorConcept A, class Early = Nop>
  219. constexpr decltype(auto)
  220. ply(A && a, Early && early = Nop {})
  221. {
  222. static_assert(0<=rank_s(a) || ANY==rank_s(a));
  223. if constexpr (ANY==size_s(a)) {
  224. return ply_ravel(RA_FWD(a), RA_FWD(early));
  225. } else {
  226. return ply_fixed(RA_FWD(a), RA_FWD(early));
  227. }
  228. }
  229. constexpr void
  230. for_each(auto && op, auto && ... a) { ply(map(RA_FWD(op), RA_FWD(a) ...)); }
  231. template <class T> struct Default { T def; };
  232. template <class T> Default(T &&) -> Default<T>;
  233. constexpr decltype(auto)
  234. early(IteratorConcept auto && a, auto && def) { return ply(RA_FWD(a), Default { RA_FWD(def) }); }
  235. // --------------------
  236. // input/'output' iterator adapter. FIXME maybe random for rank 1?
  237. // --------------------
  238. template <IteratorConcept A>
  239. struct STLIterator
  240. {
  241. using difference_type = dim_t;
  242. using value_type = value_t<A>;
  243. A a;
  244. std::decay_t<decltype(ra::shape(a))> ind; // concrete type
  245. bool over;
  246. STLIterator(A a_): a(a_), ind(ra::shape(a_)), over(0==ra::size(a)) {}
  247. constexpr STLIterator(STLIterator &&) = default;
  248. constexpr STLIterator(STLIterator const &) = delete;
  249. constexpr STLIterator & operator=(STLIterator &&) = default;
  250. constexpr STLIterator & operator=(STLIterator const &) = delete;
  251. constexpr bool operator==(std::default_sentinel_t end) const { return over; }
  252. decltype(auto) operator*() const { return *a; }
  253. constexpr void
  254. next(rank_t k)
  255. {
  256. for (; k>=0; --k) {
  257. if (--ind[k]>0) {
  258. a.adv(k, 1);
  259. return;
  260. } else {
  261. ind[k] = a.len(k);
  262. a.adv(k, 1-a.len(k));
  263. }
  264. }
  265. over = true;
  266. }
  267. template <int k>
  268. constexpr void
  269. next()
  270. {
  271. if constexpr (k>=0) {
  272. if (--ind[k]>0) {
  273. a.adv(k, 1);
  274. } else {
  275. ind[k] = a.len(k);
  276. a.adv(k, 1-a.len(k));
  277. next<k-1>();
  278. }
  279. return;
  280. }
  281. over = true;
  282. }
  283. constexpr STLIterator & operator++() { if constexpr (ANY==rank_s(a)) { next(rank(a)-1); } else { next<rank_s(a)-1>(); } return *this; }
  284. constexpr void operator++(int) { ++(*this); } // see p0541 and p2550. Or just avoid.
  285. };
  286. template <class A> STLIterator(A &&) -> STLIterator<A>;
  287. constexpr auto begin(is_ra auto && a) { return STLIterator(ra::start(RA_FWD(a))); }
  288. constexpr auto end(is_ra auto && a) { return std::default_sentinel; }
  289. constexpr auto range(is_ra auto && a) { return std::ranges::subrange(ra::begin(RA_FWD(a)), std::default_sentinel); }
  290. // unqualified might find .begin() anyway through std::begin etc (!)
  291. constexpr auto begin(is_ra auto && a) requires (requires { a.begin(); }) { static_assert(std::is_lvalue_reference_v<decltype(a)>); return a.begin(); }
  292. constexpr auto end(is_ra auto && a) requires (requires { a.end(); }) { static_assert(std::is_lvalue_reference_v<decltype(a)>); return a.end(); }
  293. constexpr auto range(is_ra auto && a) requires (requires { a.begin(); }) { static_assert(std::is_lvalue_reference_v<decltype(a)>); return std::ranges::subrange(a.begin(), a.end()); }
  294. // ---------------------------
  295. // i/o
  296. // ---------------------------
  297. template <class A>
  298. inline std::ostream &
  299. operator<<(std::ostream & o, FormatArray<A> const & fa)
  300. {
  301. auto a = ra::start(fa.a); // [ra35]
  302. static_assert(BAD!=size_s(a), "Cannot print undefined size expr.");
  303. auto sha = shape(a);
  304. // the following assert fixes a segfault in gcc11.3 test/io.c with -O3 -DRA_DO_CHECK=1.
  305. assert(every(ra::start(sha)>=0));
  306. // always print shape with defaultshape to avoid recursion on shape(shape(...)) = [1].
  307. if (withshape==fa.fmt.shape || (defaultshape==fa.fmt.shape && ANY==size_s(a))) {
  308. o << ra::defaultshape << sha << '\n';
  309. }
  310. rank_t const rank = ra::rank(a);
  311. auto goin = [&](this auto && goin, int k) -> void
  312. {
  313. if (k==rank) {
  314. o << *a;
  315. } else {
  316. o << fa.fmt.open;
  317. for (int i=0; i<sha[k]; ++i) {
  318. goin(k+1);
  319. if (i+1<sha[k]) {
  320. a.adv(k, 1);
  321. o << (k==rank-1 ? fa.fmt.sep0 : fa.fmt.sepn);
  322. std::fill_n(std::ostream_iterator<char const *>(o, ""), std::max(0, rank-2-k), fa.fmt.rep);
  323. if (fa.fmt.align && k<rank-1) {
  324. std::fill_n(std::ostream_iterator<char const *>(o, ""), (k+1)*ra::size(fa.fmt.open), " ");
  325. }
  326. } else {
  327. a.adv(k, 1-sha[k]);
  328. break;
  329. }
  330. }
  331. o << fa.fmt.close;
  332. }
  333. };
  334. goin(0);
  335. return o;
  336. }
  337. template <class C> requires (ANY!=size_s<C>() && !is_scalar<C>)
  338. inline std::istream &
  339. operator>>(std::istream & i, C & c)
  340. {
  341. for (auto & ci: c) { i >> ci; }
  342. return i;
  343. }
  344. template <class T, class A>
  345. inline std::istream &
  346. operator>>(std::istream & i, std::vector<T, A> & c)
  347. {
  348. if (dim_t n; i >> n) {
  349. RA_CHECK(n>=0, "Negative length in input [", n, "].");
  350. std::vector<T, A> cc(n);
  351. swap(c, cc);
  352. for (auto & ci: c) { i >> ci; }
  353. }
  354. return i;
  355. }
  356. template <class C> requires (ANY==size_s<C>() && !std::is_convertible_v<C, std::string_view>)
  357. inline std::istream &
  358. operator>>(std::istream & i, C & c)
  359. {
  360. if (decltype(shape(c)) s; i >> s) {
  361. RA_CHECK(every(start(s)>=0), "Negative length in input [", noshape, s, "].");
  362. C cc(s, ra::none);
  363. swap(c, cc);
  364. for (auto & ci: c) { i >> ci; }
  365. }
  366. return i;
  367. }
  368. } // namespace ra