// ply.hh
// -*- mode: c++; coding: utf-8 -*-
// ra-ra - Expression traversal.
// (c) Daniel Llorens - 2013-2023
// This library is free software; you can redistribute it and/or modify it under
// the terms of the GNU Lesser General Public License as published by the Free
// Software Foundation; either version 3 of the License, or (at your option) any
// later version.
// TODO Make traversal order a parameter; some operations (e.g. output, ravel) require specific orders.
// TODO Better traversal. Tiling, etc. (see eval.cc in Blitz++). Unit step case?
// TODO std::execution::xxx-policy
// TODO Validate output argument strides.
  12. #pragma once
  13. #include "expr.hh"
  14. namespace ra {
  15. template <class A>
  16. constexpr decltype(auto)
  17. VALUE(A && a)
  18. {
  19. if constexpr (is_scalar<A>) {
  20. return RA_FWD(a); // [ra8]
  21. } else if constexpr (is_iterator<A>) {
  22. return *a; // no need to start() for one
  23. } else {
  24. return *(ra::start(RA_FWD(a)));
  25. }
  26. }
  27. template <class A> using value_t = std::remove_volatile_t<std::remove_reference_t<decltype(VALUE(std::declval<A>()))>>;
  28. template <class A> using ncvalue_t = std::remove_const_t<value_t<A>>;
  29. // ---------------------
  30. // replace Len in expr tree.
  31. // ---------------------
  32. template <>
  33. constexpr bool has_len_def<Len> = true;
  34. template <IteratorConcept ... P>
  35. constexpr bool has_len_def<Pick<std::tuple<P ...>>> = (has_len<P> || ...);
  36. template <class Op, IteratorConcept ... P>
  37. constexpr bool has_len_def<Expr<Op, std::tuple<P ...>>> = (has_len<P> || ...);
  38. template <int w, class I, class N, class S>
  39. constexpr bool has_len_def<Iota<w, I, N, S>> = (has_len<I> || has_len<N> || has_len<S>);
  40. template <class I, class N, class S>
  41. constexpr bool has_len_def<Ptr<I, N, S>> = has_len<N> || has_len<S>;
  42. template <class E>
  43. struct WithLen
  44. {
  45. static_assert(!has_len<E>, "Unhandled len.");
  46. constexpr static decltype(auto)
  47. f(auto ln, auto && e)
  48. {
  49. return RA_FWD(e);
  50. }
  51. };
  52. template <>
  53. struct WithLen<Len>
  54. {
  55. constexpr static decltype(auto)
  56. f(auto ln, auto && e)
  57. {
  58. return Scalar<decltype(ln)>(ln);
  59. }
  60. };
  61. template <class Op, IteratorConcept ... P, int ... I> requires (has_len<P> || ...)
  62. struct WithLen<Expr<Op, std::tuple<P ...>, mp::int_list<I ...>>>
  63. {
  64. constexpr static decltype(auto)
  65. f(auto ln, auto && e)
  66. {
  67. return expr(RA_FWD(e).op, WithLen<std::decay_t<P>>::f(ln, std::get<I>(RA_FWD(e).t)) ...);
  68. }
  69. };
  70. template <IteratorConcept ... P, int ... I> requires (has_len<P> || ...)
  71. struct WithLen<Pick<std::tuple<P ...>, mp::int_list<I ...>>>
  72. {
  73. constexpr static decltype(auto)
  74. f(auto ln, auto && e)
  75. {
  76. return pick(WithLen<std::decay_t<P>>::f(ln, std::get<I>(RA_FWD(e).t)) ...);
  77. }
  78. };
  79. template <int w, class I, class N, class S> requires (has_len<I> || has_len<N> || has_len<S>)
  80. struct WithLen<Iota<w, I, N, S>>
  81. {
  82. constexpr static decltype(auto)
  83. f(auto ln, auto && e)
  84. {
  85. // final iota types must be either is_constant or is_scalar.
  86. return iota<w>(VALUE(WithLen<std::decay_t<N>>::f(ln, RA_FWD(e).n)),
  87. VALUE(WithLen<std::decay_t<I>>::f(ln, RA_FWD(e).i)),
  88. VALUE(WithLen<std::decay_t<S>>::f(ln, RA_FWD(e).s)));
  89. }
  90. };
  91. template <class I, class N, class S> requires (has_len<N> || has_len<S>)
  92. struct WithLen<Ptr<I, N, S>>
  93. {
  94. constexpr static decltype(auto)
  95. f(auto ln, auto && e)
  96. {
  97. return ptr(RA_FWD(e).i,
  98. VALUE(WithLen<std::decay_t<N>>::f(ln, RA_FWD(e).n)),
  99. VALUE(WithLen<std::decay_t<S>>::f(ln, RA_FWD(e).s)));
  100. }
  101. };
  102. template <class Ln, class E>
  103. constexpr decltype(auto)
  104. with_len(Ln ln, E && e)
  105. {
  106. static_assert(std::is_integral_v<std::decay_t<Ln>> || is_constant<std::decay_t<Ln>>);
  107. return WithLen<std::decay_t<E>>::f(ln, RA_FWD(e));
  108. }
  109. // --------------
  110. // ply, run time order/rank.
  111. // --------------
  112. struct Nop {};
  113. // step() must give 0 for k>=their own rank, to allow frame matching.
  114. template <IteratorConcept A, class Early = Nop>
  115. constexpr auto
  116. ply_ravel(A && a, Early && early = Nop {})
  117. {
  118. rank_t rank = ra::rank(a);
  119. // must avoid 0-length vlas [ra40].
  120. if (0>=rank) {
  121. if (0>rank) [[unlikely]] { std::abort(); }
  122. if constexpr (requires {early.def;}) {
  123. return (*a).value_or(early.def);
  124. } else {
  125. *a;
  126. return;
  127. }
  128. }
  129. // inside first. FIXME better heuristic - but first need a way to force row-major
  130. rank_t order[rank];
  131. for (rank_t i=0; i<rank; ++i) {
  132. order[i] = rank-1-i;
  133. }
  134. dim_t sha[rank], ind[rank] = {};
  135. // find outermost compact dim.
  136. rank_t * ocd = order;
  137. dim_t ss = a.len(*ocd);
  138. #pragma GCC diagnostic push // gcc 12.2 and 13.2 with RA_DO_CHECK=0 and -fno-sanitize=all
  139. #pragma GCC diagnostic warning "-Warray-bounds"
  140. for (--rank, ++ocd; rank>0 && a.keep_step(ss, order[0], *ocd); --rank, ++ocd) {
  141. ss *= a.len(*ocd);
  142. }
  143. for (int k=0; k<rank; ++k) {
  144. // ss takes care of the raveled dimensions ss.
  145. if (0>=(sha[k]=a.len(ocd[k]))) {
  146. if (0>sha[k]) [[unlikely]] { std::abort(); }
  147. if constexpr (requires {early.def;}) {
  148. return early.def;
  149. } else {
  150. return;
  151. }
  152. }
  153. }
  154. auto ss0 = a.step(order[0]);
  155. for (;;) {
  156. auto place = a.save();
  157. for (dim_t s=ss; --s>=0; a.mov(ss0)) {
  158. if constexpr (requires {early.def;}) {
  159. if (auto stop = *a) {
  160. return stop.value();
  161. }
  162. } else {
  163. *a;
  164. }
  165. }
  166. a.load(place); // FIXME wasted if k=0. Cf test/iota.cc
  167. for (int k=0; ; ++k) {
  168. if (k>=rank) {
  169. if constexpr (requires {early.def;}) {
  170. return early.def;
  171. } else {
  172. return;
  173. }
  174. } else if (++ind[k]<sha[k]) {
  175. a.adv(ocd[k], 1);
  176. break;
  177. } else {
  178. ind[k] = 0;
  179. a.adv(ocd[k], 1-sha[k]);
  180. }
  181. }
  182. }
  183. #pragma GCC diagnostic pop
  184. }
  185. // -------------------------
  186. // ply, compile time order/rank.
  187. // -------------------------
  188. template <auto order, int k, int urank, class A, class S, class Early>
  189. constexpr auto
  190. subply(A & a, dim_t s, S const & ss0, Early & early)
  191. {
  192. if constexpr (k < urank) {
  193. auto place = a.save();
  194. for (; --s>=0; a.mov(ss0)) {
  195. if constexpr (requires {early.def;}) {
  196. if (auto stop = *a) {
  197. return stop;
  198. }
  199. } else {
  200. *a;
  201. }
  202. }
  203. a.load(place); // FIXME wasted if k was 0 at the top
  204. } else {
  205. dim_t size = a.len(order[k]); // TODO precompute above
  206. for (dim_t i=0; i<size; ++i) {
  207. if constexpr (requires {early.def;}) {
  208. if (auto stop = subply<order, k-1, urank>(a, s, ss0, early)) {
  209. return stop;
  210. }
  211. } else {
  212. subply<order, k-1, urank>(a, s, ss0, early);
  213. }
  214. a.adv(order[k], 1);
  215. }
  216. a.adv(order[k], -size);
  217. }
  218. if constexpr (requires {early.def;}) {
  219. return static_cast<decltype(*a)>(std::nullopt);
  220. } else {
  221. return;
  222. }
  223. }
  // RA_STATIC_UNROLL (default 0) enables the compile-time ravel search in ply_fixed().
  // It may pessimize it: see bench-dot [ra43].
  #if !defined(RA_STATIC_UNROLL)
  #define RA_STATIC_UNROLL 0
  #endif
  228. template <IteratorConcept A, class Early = Nop>
  229. constexpr decltype(auto)
  230. ply_fixed(A && a, Early && early = Nop {})
  231. {
  232. constexpr rank_t rank = rank_s<A>();
  233. static_assert(0<=rank, "ply_fixed needs static rank");
  234. // inside first. FIXME better heuristic - but first need a way to force row-major
  235. constexpr /* static P2647 gcc13 */ auto order = mp::tuple2array<int, mp::reverse<mp::iota<rank>>>();
  236. if constexpr (0==rank) {
  237. if constexpr (requires {early.def;}) {
  238. return (*a).value_or(early.def);
  239. } else {
  240. *a;
  241. return;
  242. }
  243. } else {
  244. auto ss0 = a.step(order[0]);
  245. // static keep_step implies all else is static.
  246. if constexpr (RA_STATIC_UNROLL && rank>1 && requires (dim_t st, rank_t z, rank_t j) { A::keep_step(st, z, j); }) {
  247. // find outermost compact dim.
  248. constexpr auto sj = [&order]
  249. {
  250. dim_t ss = A::len_s(order[0]);
  251. int j = 1;
  252. for (; j<rank && A::keep_step(ss, order[0], order[j]); ++j) {
  253. ss *= A::len_s(order[j]);
  254. }
  255. return std::make_tuple(ss, j);
  256. } ();
  257. if constexpr (requires {early.def;}) {
  258. return (subply<order, rank-1, std::get<1>(sj)>(a, std::get<0>(sj), ss0, early)).value_or(early.def);
  259. } else {
  260. subply<order, rank-1, std::get<1>(sj)>(a, std::get<0>(sj), ss0, early);
  261. }
  262. } else {
  263. #pragma GCC diagnostic push // gcc 12.2 and 13.2 with RA_DO_CHECK=0 and -fno-sanitize=all
  264. #pragma GCC diagnostic warning "-Warray-bounds"
  265. // not worth unrolling.
  266. if constexpr (requires {early.def;}) {
  267. return (subply<order, rank-1, 1>(a, a.len(order[0]), ss0, early)).value_or(early.def);
  268. } else {
  269. subply<order, rank-1, 1>(a, a.len(order[0]), ss0, early);
  270. }
  271. #pragma GCC diagnostic pop
  272. }
  273. }
  274. }
  275. // ---------------------------
  276. // default ply
  277. // ---------------------------
  278. template <IteratorConcept A, class Early = Nop>
  279. constexpr decltype(auto)
  280. ply(A && a, Early && early = Nop {})
  281. {
  282. static_assert(!has_len<A>, "len outside subscript context.");
  283. static_assert(0<=rank_s<A>() || ANY==rank_s<A>());
  284. if constexpr (ANY==size_s<A>()) {
  285. return ply_ravel(RA_FWD(a), RA_FWD(early));
  286. } else {
  287. return ply_fixed(RA_FWD(a), RA_FWD(early));
  288. }
  289. }
  290. constexpr void
  291. for_each(auto && op, auto && ... a) { ply(map(RA_FWD(op), RA_FWD(a) ...)); }
  292. template <class T> struct Default { T def; };
  293. template <class T> Default(T &&) -> Default<T>;
  294. constexpr decltype(auto)
  295. early(IteratorConcept auto && a, auto && def) { return ply(RA_FWD(a), Default { RA_FWD(def) }); }
  296. // --------------------
  297. // input/'output' iterator adapter. FIXME maybe random for rank 1?
  298. // --------------------
  299. template <IteratorConcept A>
  300. struct STLIterator
  301. {
  302. using difference_type = dim_t;
  303. using value_type = value_t<A>;
  304. A a;
  305. std::decay_t<decltype(ra::shape(a))> ind; // concrete type
  306. bool over;
  307. STLIterator(A a_): a(a_), ind(ra::shape(a_)), over(0==ra::size(a)) {}
  308. constexpr STLIterator(STLIterator &&) = default;
  309. constexpr STLIterator(STLIterator const &) = delete;
  310. constexpr STLIterator & operator=(STLIterator &&) = default;
  311. constexpr STLIterator & operator=(STLIterator const &) = delete;
  312. constexpr bool operator==(std::default_sentinel_t end) const { return over; }
  313. decltype(auto) operator*() const { return *a; }
  314. constexpr void
  315. next(rank_t k)
  316. {
  317. for (; k>=0; --k) {
  318. if (--ind[k]>0) {
  319. a.adv(k, 1);
  320. return;
  321. } else {
  322. ind[k] = a.len(k);
  323. a.adv(k, 1-a.len(k));
  324. }
  325. }
  326. over = true;
  327. }
  328. template <int k>
  329. constexpr void
  330. next()
  331. {
  332. if constexpr (k>=0) {
  333. if (--ind[k]>0) {
  334. a.adv(k, 1);
  335. } else {
  336. ind[k] = a.len(k);
  337. a.adv(k, 1-a.len(k));
  338. next<k-1>();
  339. }
  340. return;
  341. }
  342. over = true;
  343. }
  344. constexpr STLIterator & operator++() requires (ANY==rank_s<A>()) { next(rank(a)-1); return *this; }
  345. constexpr STLIterator & operator++() requires (ANY!=rank_s<A>()) { next<rank_s<A>()-1>(); return *this; }
  346. constexpr void operator++(int) { ++(*this); } // see p0541 and p2550. Or just avoid.
  347. };
  348. template <class A> STLIterator(A &&) -> STLIterator<A>;
  349. constexpr auto begin(is_ra auto && a) { return STLIterator(ra::start(RA_FWD(a))); }
  350. constexpr auto end(is_ra auto && a) { return std::default_sentinel; }
  351. constexpr auto range(is_ra auto && a) { return std::ranges::subrange(ra::begin(RA_FWD(a)), std::default_sentinel); }
  352. // unqualified might find .begin() anyway through std::begin etc (!)
  353. constexpr auto begin(is_ra auto && a) requires (requires { a.begin(); }) { static_assert(std::is_lvalue_reference_v<decltype(a)>); return a.begin(); }
  354. constexpr auto end(is_ra auto && a) requires (requires { a.end(); }) { static_assert(std::is_lvalue_reference_v<decltype(a)>); return a.end(); }
  355. constexpr auto range(is_ra auto && a) requires (requires { a.begin(); }) { static_assert(std::is_lvalue_reference_v<decltype(a)>); return std::ranges::subrange(a.begin(), a.end()); }
  356. // ---------------------------
  357. // i/o
  358. // ---------------------------
  359. template <class A>
  360. inline std::ostream &
  361. operator<<(std::ostream & o, FormatArray<A> const & fa)
  362. {
  363. static_assert(!has_len<A>, "len outside subscript context.");
  364. static_assert(BAD!=size_s<A>(), "Cannot print undefined size expr.");
  365. auto a = ra::start(fa.a); // [ra35]
  366. auto sha = shape(a);
  367. if (withshape==fa.fmt.shape || (defaultshape==fa.fmt.shape && size_s(a)==ANY)) {
  368. o << sha << '\n';
  369. }
  370. rank_t const rank = ra::rank(a);
  371. auto goin = [&](int k, auto & goin) -> void
  372. {
  373. if (k==rank) {
  374. o << *a;
  375. } else {
  376. o << fa.fmt.open;
  377. for (int i=0; i<sha[k]; ++i) {
  378. goin(k+1, goin);
  379. if (i+1<sha[k]) {
  380. a.adv(k, 1);
  381. o << (k==rank-1 ? fa.fmt.sep0 : fa.fmt.sepn);
  382. std::fill_n(std::ostream_iterator<char const *>(o, ""), std::max(0, rank-2-k), fa.fmt.rep);
  383. if (fa.fmt.align && k<rank-1) {
  384. std::fill_n(std::ostream_iterator<char const *>(o, ""), (k+1)*ra::size(fa.fmt.open), " ");
  385. }
  386. } else {
  387. a.adv(k, 1-sha[k]);
  388. break;
  389. }
  390. }
  391. o << fa.fmt.close;
  392. }
  393. };
  394. goin(0, goin);
  395. return o;
  396. }
  397. template <class C> requires (ANY!=size_s<C>() && !is_scalar<C>)
  398. inline std::istream &
  399. operator>>(std::istream & i, C & c)
  400. {
  401. for (auto & ci: c) { i >> ci; }
  402. return i;
  403. }
  404. template <class T, class A>
  405. inline std::istream &
  406. operator>>(std::istream & i, std::vector<T, A> & c)
  407. {
  408. if (dim_t n; i >> n) {
  409. RA_CHECK(n>=0, "Negative length in input [", n, "].");
  410. std::vector<T, A> cc(n);
  411. swap(c, cc);
  412. for (auto & ci: c) { i >> ci; }
  413. }
  414. return i;
  415. }
  416. template <class C> requires (ANY==size_s<C>() && !std::is_convertible_v<C, std::string_view>)
  417. inline std::istream &
  418. operator>>(std::istream & i, C & c)
  419. {
  420. if (decltype(shape(c)) s; i >> s) {
  421. RA_CHECK(every(start(s)>=0), "Negative length in input [", noshape, s, "].");
  422. C cc(s, ra::none);
  423. swap(c, cc);
  424. for (auto & ci: c) { i >> ci; }
  425. }
  426. return i;
  427. }
  428. } // namespace ra