ra.hh 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Operator overloads for expression templates, and root header.
  3. // (c) Daniel Llorens - 2014-2024
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  9. #include "big.hh"
  10. #include <cmath>
  11. #include <complex>
  // RA_OPT names the function applied to the result of the binary array operators below.
  // Define it before including this header to replace the default, optimize.
  #if !defined(RA_OPT)
  #define RA_OPT optimize
  #endif
  // RA_DO_FMA selects fma() in the maybe_fma* helpers. A user-provided value wins;
  // otherwise follow FP_FAST_FMA when the platform defines it, else use 0.
  #if !defined(RA_DO_FMA)
  #  if defined(FP_FAST_FMA)
  #    define RA_DO_FMA FP_FAST_FMA
  #  else
  #    define RA_DO_FMA 0
  #  endif
  #endif
  21. // Enable ADL with explicit template args. See http://stackoverflow.com/questions/9838862.
  22. template <class A> constexpr void transpose(ra::noarg);
  23. template <int A> constexpr void iter(ra::noarg);
  24. template <class T> constexpr void cast(ra::noarg);
  25. // ---------------------------
  26. // scalar overloads
  27. // ---------------------------
  28. // abs() needs no qualifying for ra:: types (ADL), shouldn't on pods either. FIXME maybe let user decide.
  29. // std::max/min are special, see DEF_NAME.
  30. using std::max, std::min, std::abs, std::fma, std::sqrt, std::pow, std::exp, std::swap,
  31. std::isfinite, std::isinf, std::isnan, std::clamp, std::lerp, std::conj, std::expm1;
  // NOTE(review): the line continuation after the function body joins the FOR_EACH line into the
  // macro definition. As written FOR_FLOAT is therefore never expanded here, and this block
  // declares no global conj(float)/conj(double). conj for these types is defined separately
  // inside namespace ra. Confirm the intent before removing the continuation.
  #define FOR_FLOAT(F) \
      constexpr F conj(F x) { return x; } \
      FOR_EACH(FOR_FLOAT, float, double)
  #undef FOR_FLOAT
  36. #define FOR_FLOAT(R) \
  37. constexpr std::complex<R> \
  38. fma(std::complex<R> const & a, std::complex<R> const & b, std::complex<R> const & c) \
  39. { \
  40. return std::complex<R>(fma(a.real(), b.real(), fma(-a.imag(), b.imag(), c.real())), \
  41. fma(a.real(), b.imag(), fma(a.imag(), b.real(), c.imag()))); \
  42. } \
  43. constexpr R \
  44. fma_sqrm(std::complex<R> const & a, R const & c) \
  45. { \
  46. return fma(a.real(), a.real(), fma(a.imag(), a.imag(), c)); \
  47. } \
  48. constexpr bool isfinite(std::complex<R> z) { return isfinite(z.real()) && isfinite(z.imag()); } \
  49. constexpr bool isnan(std::complex<R> z) { return isnan(z.real()) || isnan(z.imag()); } \
  50. constexpr bool isinf(std::complex<R> z) { return (isinf(z.real()) || isinf(z.imag())) && !isnan(z); }
  51. FOR_EACH(FOR_FLOAT, float, double)
  52. #undef FOR_FLOAT
  53. namespace ra {
  54. // As array op; special definitions for rank 0.
  55. template <class T> constexpr bool ra_is_real = std::numeric_limits<T>::is_integer || std::is_floating_point_v<T>;
  56. template <class T> requires (ra_is_real<T>) constexpr T amax(T const & x) { return x; }
  57. template <class T> requires (ra_is_real<T>) constexpr T amin(T const & x) { return x; }
  58. template <class T> requires (ra_is_real<T>) constexpr T sqr(T const & x) { return x*x; }
  59. #define RA_FOR_TYPES(T) \
  60. constexpr T arg(T x) { return T(0); } \
  61. constexpr T conj(T x) { return x; } \
  62. constexpr T mul_conj(T x, T y) { return x*y; } \
  63. constexpr T sqrm(T x) { return sqr(x); } \
  64. constexpr T sqrm(T x, T y) { return sqr(x-y); } \
  65. constexpr T dot(T x, T y) { return x*y; } \
  66. constexpr T fma_conj(T a, T b, T c) { return fma(a, b, c); } \
  67. constexpr T norm2(T x) { return std::abs(x); } \
  68. constexpr T norm2(T x, T y) { return std::abs(x-y); } \
  69. constexpr T rel_error(T a, T b) { auto den = (abs(a)+abs(b)); return den==0 ? 0. : 2.*norm2(a, b)/den; } \
  70. constexpr T const & real_part(T const & x) { return x; } \
  71. constexpr T & real_part(T & x) { return x; } \
  72. constexpr T imag_part(T x) { return T(0); }
  73. FOR_EACH(RA_FOR_TYPES, float, double)
  74. #undef RA_FOR_TYPES
  75. // FIXME few still inline should eventually be constexpr.
  76. #define RA_FOR_TYPES(R, C) \
  77. inline R arg(C x) { return std::arg(x); } \
  78. constexpr C sqr(C x) { return x*x; } \
  79. constexpr C dot(C x, C y) { return x*y; } \
  80. constexpr C xi(R x) { return C(0, x); } \
  81. constexpr C xi(C z) { return C(-z.imag(), z.real()); } \
  82. constexpr R real_part(C const & z) { return z.real(); } \
  83. constexpr R imag_part(C const & z) { return z.imag(); } \
  84. inline R & real_part(C & z) { return reinterpret_cast<R *>(&z)[0]; } \
  85. inline R & imag_part(C & z) { return reinterpret_cast<R *>(&z)[1]; } \
  86. constexpr R sqrm(C x) { return sqr(x.real())+sqr(x.imag()); } \
  87. constexpr R sqrm(C x, C y) { return sqr(x.real()-y.real())+sqr(x.imag()-y.imag()); } \
  88. constexpr R norm2(C x) { return hypot(x.real(), x.imag()); } \
  89. constexpr R norm2(C x, C y) { return sqrt(sqrm(x, y)); } \
  90. constexpr R rel_error(C a, C b) { auto den = (abs(a)+abs(b)); return den==0 ? 0. : 2.*norm2(a, b)/den; } \
  91. /* conj(a) * b + c */ \
  92. constexpr C \
  93. fma_conj(C const & a, C const & b, C const & c) \
  94. { \
  95. return C(fma(a.real(), b.real(), fma(a.imag(), b.imag(), c.real())), \
  96. fma(a.real(), b.imag(), fma(-a.imag(), b.real(), c.imag()))); \
  97. } \
  98. /* conj(a) * b */ \
  99. constexpr C \
  100. mul_conj(C const & a, C const & b) \
  101. { \
  102. return C(a.real()*b.real()+a.imag()*b.imag(), \
  103. a.real()*b.imag()-a.imag()*b.real()); \
  104. }
  105. RA_FOR_TYPES(float, std::complex<float>)
  106. RA_FOR_TYPES(double, std::complex<double>)
  107. #undef RA_FOR_TYPES
  108. template <class T> constexpr bool is_scalar_def<std::complex<T>> = true;
  109. template <int ... Iarg, class A>
  110. constexpr decltype(auto)
  111. transpose(mp::int_list<Iarg ...>, A && a) { return transpose<Iarg ...>(RA_FWD(a)); }
  // Whether N is odd, i.e. its lowest bit is set.
  constexpr bool odd(unsigned int N)
  {
      return (N & 1u) != 0u;
  }
  114. // --------------------------------
  115. // optimization pass over expression templates.
  116. // --------------------------------
  // Fallback used when no more specific overload matches: forward the expression unchanged.
  template <class E>
  constexpr decltype(auto) optimize(E && e)
  {
      return RA_FWD(e);
  }
  118. // FIXME only reduces iota exprs as operated on in ra.hh (operators), not a tree like wlen() does.
  119. // TODO maybe don't opt iota(int)*real -> iota(real) since a+a+... != n*a
  120. template <class X> concept iota_op = ra::is_zero_or_scalar<X> && std::is_arithmetic_v<ncvalue_t<X>>;
  121. // TODO something to handle the & variants...
  122. #define ITEM(i) std::get<(i)>(e.t)
  123. // FIXME gets() vs p2781r2
  124. // qualified ra::iota is necessary to avoid ADLing to std::iota (test/headers.cc).
  125. template <is_iota I, iota_op J>
  126. constexpr auto
  127. optimize(Expr<std::plus<>, std::tuple<I, J>> && e)
  128. { return ra::iota(ITEM(0).n, ITEM(0).i+ITEM(1), ITEM(0).s); }
  129. template <iota_op I, is_iota J>
  130. constexpr auto
  131. optimize(Expr<std::plus<>, std::tuple<I, J>> && e)
  132. { return ra::iota(ITEM(1).n, ITEM(0)+ITEM(1).i, ITEM(1).s); }
  133. template <is_iota I, is_iota J>
  134. constexpr auto
  135. optimize(Expr<std::plus<>, std::tuple<I, J>> && e)
  136. { return ra::iota(maybe_len(e), ITEM(0).i+ITEM(1).i, ITEM(0).s+ITEM(1).s); }
  137. template <is_iota I, iota_op J>
  138. constexpr auto
  139. optimize(Expr<std::minus<>, std::tuple<I, J>> && e)
  140. { return ra::iota(ITEM(0).n, ITEM(0).i-ITEM(1), ITEM(0).s); }
  141. template <iota_op I, is_iota J>
  142. constexpr auto
  143. optimize(Expr<std::minus<>, std::tuple<I, J>> && e)
  144. { return ra::iota(ITEM(1).n, ITEM(0)-ITEM(1).i, -ITEM(1).s); }
  145. template <is_iota I, is_iota J>
  146. constexpr auto
  147. optimize(Expr<std::minus<>, std::tuple<I, J>> && e)
  148. { return ra::iota(maybe_len(e), ITEM(0).i-ITEM(1).i, ITEM(0).s-ITEM(1).s); }
  149. template <is_iota I, iota_op J>
  150. constexpr auto
  151. optimize(Expr<std::multiplies<>, std::tuple<I, J>> && e)
  152. { return ra::iota(ITEM(0).n, ITEM(0).i*ITEM(1), ITEM(0).s*ITEM(1)); }
  153. template <iota_op I, is_iota J>
  154. constexpr auto
  155. optimize(Expr<std::multiplies<>, std::tuple<I, J>> && e)
  156. { return ra::iota(ITEM(1).n, ITEM(0)*ITEM(1).i, ITEM(0)*ITEM(1).s); }
  157. template <is_iota I>
  158. constexpr auto
  159. optimize(Expr<std::negate<>, std::tuple<I>> && e)
  160. { return ra::iota(ITEM(0).n, -ITEM(0).i, -ITEM(0).s); }
  #if RA_DO_OPT_SMALLVECTOR==1
  // Optional optimize() overloads for +, -, / and * between two rank-1 Small<T, N> iterators,
  // for T in {float, double} and N in {2, 4, 8}. The result is computed in one statement through
  // extvector<T, N> (declared elsewhere) and returned as a Small<T, N> value.

  // FIXME can't match CellSmall directly, maybe bc N is in std::array { Dim { N, 1 } }.
  // True when A (after decay) is the rank-0 iterator of a Small<T, N> view, const or not.
  template <class T, dim_t N, class A> constexpr bool match_small =
      std::is_same_v<std::decay_t<A>, typename ra::Small<T, N>::View::iterator<0>>
      || std::is_same_v<std::decay_t<A>, typename ra::Small<T, N>::ViewConst::iterator<0>>;
  static_assert(match_small<double, 4, ra::Cell<double, ic_t<std::array { Dim { 4, 1 } }>, ic_t<0>>>);

  // One overload of optimize() for the function object NAME, applying OP on extvector<T, N>.
  #define RA_OPT_SMALLVECTOR_OP(OP, NAME, T, N) \
      template <class A, class B> requires (match_small<T, N, A> && match_small<T, N, B>) \
      constexpr auto \
      optimize(ra::Expr<NAME, std::tuple<A, B>> && e) \
      { \
          alignas (alignof(extvector<T, N>)) ra::Small<T, N> val; \
          *(extvector<T, N> *)(&val) = *(extvector<T, N> *)((ITEM(0).c.cp)) OP *(extvector<T, N> *)((ITEM(1).c.cp)); \
          return val; \
      }
  // The four arithmetic overloads for one (T, N); checks first that Small<T, N> is aligned enough.
  #define RA_OPT_SMALLVECTOR_OP_FUNS(T, N) \
      static_assert(0==alignof(ra::Small<T, N>) % alignof(extvector<T, N>)); \
      RA_OPT_SMALLVECTOR_OP(+, std::plus<>, T, N) \
      RA_OPT_SMALLVECTOR_OP(-, std::minus<>, T, N) \
      RA_OPT_SMALLVECTOR_OP(/, std::divides<>, T, N) \
      RA_OPT_SMALLVECTOR_OP(*, std::multiplies<>, T, N)
  // All supported sizes for one element type.
  #define RA_OPT_SMALLVECTOR_OP_SIZES(T) \
      RA_OPT_SMALLVECTOR_OP_FUNS(T, 2) \
      RA_OPT_SMALLVECTOR_OP_FUNS(T, 4) \
      RA_OPT_SMALLVECTOR_OP_FUNS(T, 8)
  FOR_EACH(RA_OPT_SMALLVECTOR_OP_SIZES, float, double)
  // Undefine every helper macro defined in this section so that none leaks to includers.
  #undef RA_OPT_SMALLVECTOR_OP_SIZES
  #undef RA_OPT_SMALLVECTOR_OP_FUNS
  #undef RA_OPT_SMALLVECTOR_OP
  #endif // RA_DO_OPT_SMALLVECTOR
  191. #undef ITEM
  192. // --------------------------------
  193. // array versions of operators and functions
  194. // --------------------------------
  // We need zero/scalar specializations because the scalar/scalar operators may be templated (e.g. complex<>), so they won't be found when an implicit scalar conversion is also needed, and e.g. ra::View<complex, 0> * complex would fail.
  196. // The function objects are matched in optimize.hh.
  197. #define DEF_NAMED_BINARY_OP(OP, OPNAME) \
  198. template <class A, class B> requires (tomap<A, B>) constexpr auto \
  199. operator OP(A && a, B && b) { return RA_OPT(map(OPNAME(), RA_FWD(a), RA_FWD(b))); } \
  200. template <class A, class B> requires (toreduce<A, B>) constexpr auto \
  201. operator OP(A && a, B && b) { return VALUE(RA_FWD(a)) OP VALUE(RA_FWD(b)); }
  202. DEF_NAMED_BINARY_OP(+, std::plus<>) DEF_NAMED_BINARY_OP(-, std::minus<>)
  203. DEF_NAMED_BINARY_OP(*, std::multiplies<>) DEF_NAMED_BINARY_OP(/, std::divides<>)
  204. DEF_NAMED_BINARY_OP(==, std::equal_to<>) DEF_NAMED_BINARY_OP(>, std::greater<>)
  205. DEF_NAMED_BINARY_OP(<, std::less<>) DEF_NAMED_BINARY_OP(>=, std::greater_equal<>)
  206. DEF_NAMED_BINARY_OP(<=, std::less_equal<>) DEF_NAMED_BINARY_OP(!=, std::not_equal_to<>)
  207. DEF_NAMED_BINARY_OP(|, std::bit_or<>) DEF_NAMED_BINARY_OP(&, std::bit_and<>)
  208. DEF_NAMED_BINARY_OP(^, std::bit_xor<>) DEF_NAMED_BINARY_OP(<=>, std::compare_three_way)
  209. #undef DEF_NAMED_BINARY_OP
  210. // FIXME address sanitizer complains in bench-optimize.cc if we use std::identity. Maybe false positive
  211. struct unaryplus
  212. {
  213. template <class T> constexpr static auto operator()(T && t) noexcept { return RA_FWD(t); }
  214. };
  215. #define DEF_NAMED_UNARY_OP(OP, OPNAME) \
  216. template <class A> requires (tomap<A>) constexpr auto \
  217. operator OP(A && a) { return map(OPNAME(), RA_FWD(a)); } \
  218. template <class A> requires (toreduce<A>) constexpr auto \
  219. operator OP(A && a) { return OP VALUE(RA_FWD(a)); }
  220. DEF_NAMED_UNARY_OP(+, unaryplus)
  221. DEF_NAMED_UNARY_OP(-, std::negate<>)
  222. DEF_NAMED_UNARY_OP(!, std::logical_not<>)
  223. #undef DEF_NAMED_UNARY_OP
  224. // if OP(a) isn't found in ra::, deduction rank(0) -> scalar doesn't work. TODO Cf useret.cc, reexported.cc
  225. #define DEF_NAME(OP) \
  226. template <class ... A> requires (tomap<A ...>) constexpr auto \
  227. OP(A && ... a) { return map([](auto && ... a) -> decltype(auto) { return OP(RA_FWD(a) ...); }, RA_FWD(a) ...); } \
  228. template <class ... A> requires (toreduce<A ...>) constexpr decltype(auto) \
  229. OP(A && ... a) { return OP(VALUE(RA_FWD(a)) ...); }
  230. #define DEF_FWD(QUALIFIED_OP, OP) \
  231. template <class ... A> requires (!tomap<A ...> && !toreduce<A ...>) constexpr decltype(auto) \
  232. OP(A && ... a) { return QUALIFIED_OP(RA_FWD(a) ...); } \
  233. DEF_NAME(OP)
  234. #define DEF_USING(QUALIFIED_OP, OP) \
  235. using QUALIFIED_OP; \
  236. DEF_NAME(OP)
  237. FOR_EACH(DEF_NAME, odd, arg, sqr, sqrm, real_part, imag_part, xi, rel_error)
  238. // can't DEF_USING bc std::max will gobble ra:: objects if passed by const & (!)
  239. // FIXME define own global max/min overloads for basic types. std::max seems too much of a special case to be usinged.
  240. #define DEF_GLOBAL(f) DEF_FWD(::f, f)
  241. FOR_EACH(DEF_GLOBAL, max, min)
  242. #undef DEF_GLOBAL
  243. // don't use DEF_FWD for these bc we want to allow ADL, e.g. for exp(dual).
  244. #define DEF_GLOBAL(f) DEF_USING(::f, f)
  245. FOR_EACH(DEF_GLOBAL, pow, conj, sqrt, exp, expm1, log, log1p, log10, isfinite, isnan, isinf, atan2)
  246. FOR_EACH(DEF_GLOBAL, abs, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, clamp, lerp)
  247. FOR_EACH(DEF_GLOBAL, fma)
  248. #undef DEF_GLOBAL
  249. #undef DEF_USING
  250. #undef DEF_FWD
  251. #undef DEF_NAME
  // Lazy elementwise conversion: each element x of a is mapped to T(x).
  template <class T, class A>
  constexpr auto cast(A && a)
  {
      return map([](auto && x) -> decltype(auto) { return T(x); },
                 RA_FWD(a));
  }
  // Lazy elementwise aggregation: matching elements of the arguments are brace-initialized into a T.
  template <class T, class ... A>
  constexpr auto pack(A && ... a)
  {
      return map([](auto && ... x) { return T { x ... }; },
                 RA_FWD(a) ...);
  }
  264. // FIXME needs nested array for I, but iter<-1> should work
  // Lazy gather: each element ix of i is mapped to a.at(ix). The array argument is kept inside
  // the lambda as a one-element std::tuple<A>, which preserves whether A is a reference.
  template <class A, class I>
  constexpr auto at(A && a, I && i)
  {
      return map([held = std::tuple<A>(RA_FWD(a))] (auto && ix) -> decltype(auto) { return std::get<0>(held).at(ix); },
                 RA_FWD(i));
  }
  272. // --------------------------------
  273. // selection / shortcutting
  274. // --------------------------------
  275. // ra::start are needed bc rank 0 converts to and from scalar, so ? can't pick the right (-> scalar) conversion.
  276. template <class T, class F> requires (toreduce<T, F>)
  277. constexpr decltype(auto)
  278. where(bool const w, T && t, F && f)
  279. {
  280. return w ? VALUE(t) : VALUE(f);
  281. }
  282. template <class W, class T, class F> requires (tomap<W, T, F>)
  283. constexpr auto
  284. where(W && w, T && t, F && f)
  285. {
  286. return pick(cast<bool>(RA_FWD(w)), RA_FWD(f), RA_FWD(t));
  287. }
  288. // catch all for non-ra types.
  289. template <class T, class F> requires (!(tomap<T, F>) && !(toreduce<T, F>))
  290. constexpr decltype(auto)
  291. where(bool const w, T && t, F && f)
  292. {
  293. return w ? t : f;
  294. }
  295. template <class A, class B> requires (tomap<A, B>)
  296. constexpr auto
  297. operator &&(A && a, B && b)
  298. {
  299. return where(RA_FWD(a), cast<bool>(RA_FWD(b)), false);
  300. }
  301. template <class A, class B> requires (tomap<A, B>)
  302. constexpr auto
  303. operator ||(A && a, B && b)
  304. {
  305. return where(RA_FWD(a), true, cast<bool>(RA_FWD(b)));
  306. }
  307. #define DEF_SHORTCIRCUIT_BINARY_OP(OP) \
  308. template <class A, class B> requires (toreduce<A, B>) \
  309. constexpr auto operator OP(A && a, B && b) \
  310. { \
  311. return VALUE(a) OP VALUE(b); \
  312. }
  313. FOR_EACH(DEF_SHORTCIRCUIT_BINARY_OP, &&, ||)
  314. #undef DEF_SHORTCIRCUIT_BINARY_OP
  315. // --------------------------------
  316. // whole-array ops. TODO First/variable rank reductions? FIXME C++23 and_then/or_else/etc
  317. // --------------------------------
  // Maps each element to optional(true) when it is true and to nullopt otherwise, and hands that
  // to early() with default false. Presumably early() returns the first engaged value, else the default.
  constexpr bool any(auto && a)
  {
      return early(map([](bool hit) { return hit ? std::make_optional(true) : std::nullopt; },
                       RA_FWD(a)),
                   false);
  }
  // Maps each element to optional(false) when it is false and to nullopt otherwise, and hands that
  // to early() with default true. Presumably early() returns the first engaged value, else the default.
  constexpr bool every(auto && a)
  {
      return early(map([](bool ok) { return !ok ? std::make_optional(false) : std::nullopt; },
                       RA_FWD(a)),
                   true);
  }
  328. // FIXME variable rank? see J 'index of' (x i. y), etc.
  329. constexpr dim_t
  330. index(auto && a)
  331. {
  332. return early(map([](auto && a, auto && i) { return bool(a) ? std::make_optional(i) : std::nullopt; },
  333. RA_FWD(a), ra::iota(ra::start(a).len(0))),
  334. ra::dim_t(-1));
  335. }
  // Elementwise comparison: equal pairs yield nullopt, any other pair yields x<y; early() is given
  // false as the default, which is the result when every pair is equal.
  constexpr bool lexicographical_compare(auto && a, auto && b)
  {
      return early(map([](auto && x, auto && y) { return x==y ? std::nullopt : std::make_optional(x<y); },
                       RA_FWD(a),
                       RA_FWD(b)),
                   false);
  }
  343. constexpr auto
  344. amin(auto && a)
  345. {
  346. using std::min, std::numeric_limits;
  347. using T = ncvalue_t<decltype(a)>;
  348. T c = numeric_limits<T>::has_infinity ? numeric_limits<T>::infinity() : numeric_limits<T>::max();
  349. for_each([&c](auto && a) { if (a<c) { c = a; } }, a);
  350. return c;
  351. }
  352. constexpr auto
  353. amax(auto && a)
  354. {
  355. using std::max, std::numeric_limits;
  356. using T = ncvalue_t<decltype(a)>;
  357. T c = numeric_limits<T>::has_infinity ? -numeric_limits<T>::infinity() : numeric_limits<T>::lowest();
  358. for_each([&c](auto && a) { if (c<a) { c = a; } }, a);
  359. return c;
  360. }
  361. // FIXME encapsulate this kind of reference-reduction.
  362. // FIXME expr/ply mechanism doesn't allow partial iteration (adv then continue).
  363. template <class A, class Less = std::less<ncvalue_t<A>>>
  364. constexpr decltype(auto)
  365. refmin(A && a, Less && less = {})
  366. {
  367. RA_CHECK(a.size()>0, "refmin requires nonempty argument.");
  368. decltype(auto) s = ra::start(a);
  369. auto p = &(*s);
  370. for_each([&less, &p](auto & a) { if (less(a, *p)) { p = &a; } }, s);
  371. return *p;
  372. }
  373. template <class A, class Less = std::less<ncvalue_t<A>>>
  374. constexpr decltype(auto)
  375. refmax(A && a, Less && less = {})
  376. {
  377. RA_CHECK(a.size()>0, "refmax requires nonempty argument.");
  378. decltype(auto) s = ra::start(a);
  379. auto p = &(*s);
  380. for_each([&less, &p](auto & a) { if (less(*p, a)) { p = &a; } }, s);
  381. return *p;
  382. }
  383. constexpr auto
  384. sum(auto && a)
  385. {
  386. auto c = concrete_type<ncvalue_t<decltype(a)>>(0);
  387. for_each([&c](auto && a) { c += a; }, a);
  388. return c;
  389. }
  390. constexpr auto
  391. prod(auto && a)
  392. {
  393. auto c = concrete_type<ncvalue_t<decltype(a)>>(1);
  394. for_each([&c](auto && a) { c *= a; }, a);
  395. return c;
  396. }
  397. constexpr void maybe_fma(auto && a, auto && b, auto & c) { if constexpr (RA_DO_FMA) c = fma(a, b, c); else c += a*b; }
  398. constexpr void maybe_fma_conj(auto && a, auto && b, auto & c) { if constexpr (RA_DO_FMA) c = fma_conj(a, b, c); else c += conj(a)*b; }
  399. constexpr void maybe_fma_sqrm(auto && a, auto & c) { if constexpr (RA_DO_FMA) c = fma_sqrm(a, c); else c += sqrm(a); }
  // Sum over matching elements of a*b. The accumulator has the decayed type of the product of
  // the operands' VALUE() and starts at 0.
  constexpr auto dot(auto && a, auto && b)
  {
      using R = std::decay_t<decltype(VALUE(a) * VALUE(b))>;
      R acc(0.);
      for_each([&acc](auto && x, auto && y) { maybe_fma(x, y, acc); }, RA_FWD(a), RA_FWD(b));
      return acc;
  }
  // Sum over matching elements of conj(a)*b. The accumulator has the decayed type of
  // conj(VALUE(a))*VALUE(b) and starts at 0.
  constexpr auto cdot(auto && a, auto && b)
  {
      using R = std::decay_t<decltype(conj(VALUE(a)) * VALUE(b))>;
      R acc(0.);
      for_each([&acc](auto && x, auto && y) { maybe_fma_conj(x, y, acc); }, RA_FWD(a), RA_FWD(b));
      return acc;
  }
  // Sum of sqrm(x) over the elements x of a. The accumulator has the decayed type of
  // sqrm(VALUE(a)) and starts at 0.
  constexpr auto reduce_sqrm(auto && a)
  {
      using R = std::decay_t<decltype(sqrm(VALUE(a)))>;
      R acc(0.);
      for_each([&acc](auto && x) { maybe_fma_sqrm(x, acc); }, RA_FWD(a));
      return acc;
  }
  // Square root of reduce_sqrm(a).
  constexpr auto norm2(auto && a)
  {
      return std::sqrt(reduce_sqrm(a));
  }
  // Makes a concrete copy of a and divides it in place by its norm2; a itself is not modified.
  constexpr auto normv(auto const & a)
  {
      auto unit = concrete(a);
      return unit /= norm2(unit);
  }
  428. // FIXME benchmark w/o allocation and do Small/Big versions if it's worth it (see bench-gemm.cc)
  429. constexpr void
  430. gemm(auto const & a, auto const & b, auto & c)
  431. {
  432. dim_t K=a.len(1);
  433. for (int k=0; k<K; ++k) {
  434. c += from(std::multiplies<>(), a(all, k), b(k)); // FIXME fma
  435. }
  436. }
  437. constexpr auto
  438. gemm(auto const & a, auto const & b)
  439. {
  440. dim_t M=a.len(0), N=b.len(1);
  441. using T = decltype(VALUE(a)*VALUE(b));
  442. using MMTYPE = decltype(from(std::multiplies<>(), a(all, 0), b(0)));
  443. auto c = with_shape<MMTYPE>({M, N}, T());
  444. gemm(a, b, c);
  445. return c;
  446. }
  447. constexpr auto
  448. gevm(auto const & a, auto const & b)
  449. {
  450. dim_t M=b.len(0), N=b.len(1);
  451. using T = decltype(VALUE(a)*VALUE(b));
  452. auto c = with_shape<decltype(a[0]*b(0))>({N}, T());
  453. for (int i=0; i<M; ++i) {
  454. maybe_fma(a[i], b(i), c);
  455. }
  456. return c;
  457. }
  458. // FIXME a must be a view, so it doesn't work with e.g. gemv(conj(a), b).
  459. constexpr auto
  460. gemv(auto const & a, auto const & b)
  461. {
  462. dim_t M=a.len(0), N=a.len(1);
  463. using T = decltype(VALUE(a)*VALUE(b));
  464. auto c = with_shape<decltype(a(all, 0)*b[0])>({M}, T());
  465. for (int j=0; j<N; ++j) {
  466. maybe_fma(a(all, j), b[j], c);
  467. }
  468. return c;
  469. }
  470. // --------------------
  471. // wedge product and cross product. FIXME seriously
  472. // --------------------
  473. namespace mp {
  474. template <class P, class Plist>
  475. struct FindCombination
  476. {
  477. template <class A> using match = bool_c<0 != PermutationSign<P, A>::value>;
  478. using type = IndexIf<Plist, match>;
  479. constexpr static int where = type::value;
  480. constexpr static int sign = (where>=0) ? PermutationSign<P, typename type::type>::value : 0;
  481. };
  482. // Combination antiC complementary to C wrt [0, 1, ... Dim-1], permuted so [C, antiC] has the sign of [0, 1, ... Dim-1].
  483. template <class C, int D>
  484. struct AntiCombination
  485. {
  486. using EC = complement<C, D>;
  487. static_assert((len<EC>)>=2, "can't correct this complement");
  488. constexpr static int sign = PermutationSign<append<C, EC>, iota<D>>::value;
  489. // Produce permutation of opposite sign if sign<0.
  490. using type = mp::cons<std::tuple_element_t<(sign<0) ? 1 : 0, EC>,
  491. mp::cons<std::tuple_element_t<(sign<0) ? 0 : 1, EC>,
  492. mp::drop<EC, 2>>>;
  493. };
  494. template <class C, int D> struct MapAntiCombination;
  495. template <int D, class ... C>
  496. struct MapAntiCombination<std::tuple<C ...>, D>
  497. {
  498. using type = std::tuple<typename AntiCombination<C, D>::type ...>;
  499. };
  500. template <int D, int O>
  501. struct ChooseComponents_
  502. {
  503. static_assert(D>=O, "Bad dimension or form order.");
  504. using type = mp::combinations<iota<D>, O>;
  505. };
  506. template <int D, int O> using ChooseComponents = typename ChooseComponents_<D, O>::type;
  507. template <int D, int O> requires ((D>1) && (2*O>D))
  508. struct ChooseComponents_<D, O>
  509. {
  510. static_assert(D>=O, "Bad dimension or form order.");
  511. using type = typename MapAntiCombination<ChooseComponents<D, D-O>, D>::type;
  512. };
  513. // Up to (62 x) or (63 28) ~ 2^59 on 64 bit size_t.
  // Binomial coefficient (n over p); 0 when p>n. Uses the smaller of p and n-p as the number of
  // factors and divides at every step, so each intermediate value is itself a binomial coefficient.
  constexpr std::size_t
  binom(std::size_t n, std::size_t p)
  {
      if (n<p) { return 0; }
      std::size_t const steps = (p>(n-p)) ? n-p : p;
      std::size_t acc = 1;
      for (std::size_t j=1; j<=steps; ++j) { acc = acc*(n-(j-1))/j; }
      return acc;
  }
  523. // We form the basis for the result (Cr) and split it in pieces for Oa and Ob; there are (D over Oa) ways. Then we see where and with which signs these pieces are in the bases for Oa (Ca) and Ob (Cb), and form the product.
  524. template <int D, int Oa, int Ob>
  525. struct Wedge
  526. {
  527. constexpr static int Or = Oa+Ob;
  528. static_assert(Oa<=D && Ob<=D && Or<=D, "bad orders");
  529. constexpr static int Na = binom(D, Oa);
  530. constexpr static int Nb = binom(D, Ob);
  531. constexpr static int Nr = binom(D, Or);
  532. // in lexicographic order. Can be used to sort Ca below with FindPermutation.
  533. using LexOrCa = mp::combinations<mp::iota<D>, Oa>;
  534. // the actual components used, which are in lex. order only in some cases.
  535. using Ca = mp::ChooseComponents<D, Oa>;
  536. using Cb = mp::ChooseComponents<D, Ob>;
  537. using Cr = mp::ChooseComponents<D, Or>;
  538. // optimizations.
  539. constexpr static bool yields_expr = (Na>1) != (Nb>1);
  540. constexpr static bool yields_expr_a1 = yields_expr && Na==1;
  541. constexpr static bool yields_expr_b1 = yields_expr && Nb==1;
  542. constexpr static bool both_scalars = (Na==1 && Nb==1);
  543. constexpr static bool dot_plus = Na>1 && Nb>1 && Or==D && (Oa<Ob || (Oa>Ob && !ra::odd(Oa*Ob)));
  544. constexpr static bool dot_minus = Na>1 && Nb>1 && Or==D && (Oa>Ob && ra::odd(Oa*Ob));
  545. constexpr static bool other = (Na>1 && Nb>1) && ((Oa+Ob!=D) || (Oa==Ob));
  546. template <class Va, class Vb>
  547. using valtype = decltype(std::declval<Va>()[0] * std::declval<Vb>()[0]);
  548. template <class Xr, class Fa, class Va, class Vb>
  549. constexpr static valtype<Va, Vb>
  550. term(Va const & a, Vb const & b)
  551. {
  552. if constexpr (mp::len<Fa> > 0) {
  553. using Fa0 = mp::first<Fa>;
  554. using Fb = mp::complement_list<Fa0, Xr>;
  555. using Sa = mp::FindCombination<Fa0, Ca>;
  556. using Sb = mp::FindCombination<Fb, Cb>;
  557. constexpr int sign = Sa::sign * Sb::sign * mp::PermutationSign<mp::append<Fa0, Fb>, Xr>::value;
  558. static_assert(sign==+1 || sign==-1, "Bad sign in wedge term.");
  559. return valtype<Va, Vb>(sign)*a[Sa::where]*b[Sb::where] + term<Xr, mp::drop1<Fa>>(a, b);
  560. } else {
  561. return 0.;
  562. }
  563. }
  564. template <class Va, class Vb, class Vr, int wr>
  565. constexpr static void
  566. coeff(Va const & a, Vb const & b, Vr & r)
  567. {
  568. if constexpr (wr<Nr) {
  569. using Xr = mp::ref<Cr, wr>;
  570. using Fa = mp::combinations<Xr, Oa>;
  571. r[wr] = term<Xr, Fa>(a, b);
  572. coeff<Va, Vb, Vr, wr+1>(a, b, r);
  573. }
  574. }
  575. template <class Va, class Vb, class Vr>
  576. constexpr static void
  577. product(Va const & a, Vb const & b, Vr & r)
  578. {
  579. static_assert(Va::size()==Na, "Bad Va dim.");
  580. static_assert(Vb::size()==Nb, "Bad Vb dim.");
  581. static_assert(Vr::size()==Nr, "Bad Vr dim.");
  582. coeff<Va, Vb, Vr, 0>(a, b, r);
  583. }
  584. };
  585. // Euclidean space, only component shuffling.
  586. template <int D, int O>
  587. struct Hodge
  588. {
  589. using W = Wedge<D, O, D-O>;
  590. using Ca = typename W::Ca;
  591. using Cb = typename W::Cb;
  592. using Cr = typename W::Cr;
  593. using LexOrCa = typename W::LexOrCa;
  594. constexpr static int Na = W::Na;
  595. constexpr static int Nb = W::Nb;
  596. template <int i, class Va, class Vb>
  597. constexpr static void
  598. hodge_aux(Va const & a, Vb & b)
  599. {
  600. static_assert(i<=W::Na, "Bad argument to hodge_aux");
  601. if constexpr (i<W::Na) {
  602. using Cai = mp::ref<Ca, i>;
  603. static_assert(mp::len<Cai> == O, "Bad.");
  604. // sort Cai, because mp::complement only accepts sorted combinations.
  605. // ref<Cb, i> should be complementary to Cai, but I don't want to rely on that.
  606. using SCai = mp::ref<LexOrCa, mp::FindCombination<Cai, LexOrCa>::where>;
  607. using CompCai = mp::complement<SCai, D>;
  608. static_assert(mp::len<CompCai> == D-O, "Bad.");
  609. using fpw = mp::FindCombination<CompCai, Cb>;
  610. // for the sign see e.g. DoCarmo1991 I.Ex 10.
  611. using fps = mp::FindCombination<mp::append<Cai, mp::ref<Cb, fpw::where>>, Cr>;
  612. static_assert(fps::sign!=0, "Bad.");
  613. b[fpw::where] = decltype(a[i])(fps::sign)*a[i];
  614. hodge_aux<i+1>(a, b);
  615. }
  616. }
  617. };
  618. // The order of components is taken from Wedge<D, O, D-O>; this works for whatever order is defined there.
  619. // With lexicographic order, component order is reversed, but signs vary.
  620. // With the order given by ChooseComponents<>, fpw::where==i and fps::sign==+1 in hodge_aux(), always. Then hodge() becomes a free operation, (with one exception) and the next function hodge() can be used.
  621. template <int D, int O, class Va, class Vb>
  622. constexpr void
  623. hodgex(Va const & a, Vb & b)
  624. {
  625. static_assert(O<=D, "bad orders");
  626. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error");
  627. static_assert(Vb::size()==mp::Hodge<D, O>::Nb, "error");
  628. mp::Hodge<D, O>::template hodge_aux<0>(a, b);
  629. }
  630. } // namespace ra::mp
  // This depends on Wedge<>::Ca, Cb, Cr coming from ChooseComponents. hodgex() should always work, but this is cheaper.
  // However if 2*O=D, it is not possible to differentiate the bases by order and hodgex() must be used.
  // Likewise, when O*(D-O) is odd, Hodge from (2*O>D) to (2*O<D) changes sign, since **w= -w in that case, and the basis in the (2*O>D) case is selected to make Hodge(<)->Hodge(>) trivial; but can't do both!
  634. consteval bool trivial_hodge(int D, int O) { return 2*O!=D && ((2*O<D) || !ra::odd(O*(D-O))); }
  635. template <int D, int O, class Va, class Vb>
  636. constexpr void
  637. hodge(Va const & a, Vb & b)
  638. {
  639. if constexpr (trivial_hodge(D, O)) {
  640. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error");
  641. static_assert(Vb::size()==mp::Hodge<D, O>::Nb, "error");
  642. b = a;
  643. } else {
  644. ra::mp::hodgex<D, O>(a, b);
  645. }
  646. }
  647. template <int D, int O, class Va> requires (trivial_hodge(D, O))
  648. constexpr Va const &
  649. hodge(Va const & a)
  650. {
  651. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error");
  652. return a;
  653. }
  654. template <int D, int O, class Va> requires (!trivial_hodge(D, O))
  655. constexpr Va &
  656. hodge(Va & a)
  657. {
  658. Va b(a);
  659. ra::mp::hodgex<D, O>(b, a);
  660. return a;
  661. }
  662. template <int D, int Oa, int Ob, class A, class B> requires (is_scalar<A> && is_scalar<B>)
  663. constexpr auto
  664. wedge(A const & a, B const & b) { return a*b; }
  665. template <class A>
  666. using torank1 = std::conditional_t<is_scalar<A>, Small<std::decay_t<A>, 1>, A>;
  667. template <int D, int Oa, int Ob, class Va, class Vb> requires (!(is_scalar<Va> && is_scalar<Vb>))
  668. decltype(auto)
  669. wedge(Va const & a, Vb const & b)
  670. {
  671. Small<ncvalue_t<Va>, size_s(a)> aa = a;
  672. Small<ncvalue_t<Vb>, size_s(b)> bb = b;
  673. using Ua = decltype(aa);
  674. using Ub = decltype(bb);
  675. using Wedge = mp::Wedge<D, Oa, Ob>;
  676. using valtype = typename Wedge::template valtype<Ua, Ub>;
  677. std::conditional_t<Wedge::Nr==1, valtype, Small<valtype, Wedge::Nr>> r;
  678. auto & a1 = reinterpret_cast<torank1<Ua> const &>(aa);
  679. auto & b1 = reinterpret_cast<torank1<Ub> const &>(bb);
  680. auto & r1 = reinterpret_cast<torank1<decltype(r)> &>(r);
  681. mp::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  682. return r;
  683. }
  684. template <int D, int Oa, int Ob, class Va, class Vb, class Vr> requires (!(is_scalar<Va> && is_scalar<Vb>))
  685. void
  686. wedge(Va const & a, Vb const & b, Vr & r)
  687. {
  688. Small<ncvalue_t<Va>, size_s(a)> aa = a;
  689. Small<ncvalue_t<Vb>, size_s(b)> bb = b;
  690. using Ua = decltype(aa);
  691. using Ub = decltype(bb);
  692. auto & r1 = reinterpret_cast<torank1<decltype(r)> &>(r);
  693. auto & a1 = reinterpret_cast<torank1<Ua> const &>(aa);
  694. auto & b1 = reinterpret_cast<torank1<Ub> const &>(bb);
  695. mp::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  696. }
  697. template <class A, class B>
  698. constexpr auto
  699. cross(A const & a_, B const & b_)
  700. {
  701. constexpr int n = size_s(a_);
  702. static_assert(n==size_s(b_) && (2==n || 3==n));
  703. Small<std::decay_t<decltype(VALUE(a_))>, n> a = a_;
  704. Small<std::decay_t<decltype(VALUE(b_))>, n> b = b_;
  705. using W = mp::Wedge<n, 1, 1>;
  706. Small<std::decay_t<decltype(VALUE(a_) * VALUE(b_))>, W::Nr> r;
  707. W::product(a, b, r);
  708. if constexpr (1==W::Nr) {
  709. return r[0];
  710. } else {
  711. return r;
  712. }
  713. }
  714. template <class V>
  715. constexpr auto
  716. perp(V const & v)
  717. {
  718. static_assert(2==v.size(), "Dimension error.");
  719. return Small<std::decay_t<decltype(VALUE(v))>, 2> {v[1], -v[0]};
  720. }
  721. template <class V, class U>
  722. constexpr auto
  723. perp(V const & v, U const & n)
  724. {
  725. if constexpr (is_scalar<U>) {
  726. static_assert(2==v.size(), "Dimension error.");
  727. return Small<std::decay_t<decltype(VALUE(v) * n)>, 2> {v[1]*n, -v[0]*n};
  728. } else {
  729. static_assert(3==v.size(), "Dimension error.");
  730. return cross(v, n);
  731. }
  732. }
  733. } // namespace ra
  734. #undef RA_OPT