ra.hh 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Operator overloads for expression templates, and root header.
  3. // (c) Daniel Llorens - 2014-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
#include "big.hh"
#include <cmath>
#include <complex>
#include <numeric>
  12. #ifndef RA_DO_OPT
  13. #define RA_DO_OPT 1 // enabled by default
  14. #endif
  15. #if RA_DO_OPT==1
  16. #define RA_OPT optimize
  17. #else
  18. #define RA_OPT
  19. #endif
  20. #if defined(RA_DO_FMA)
  21. #elif defined(FP_FAST_FMA)
  22. #define RA_DO_FMA FP_FAST_FMA
  23. #else
  24. #define RA_DO_FMA 0
  25. #endif
  26. // Enable ADL with explicit template args. See http://stackoverflow.com/questions/9838862.
  27. template <class A> constexpr void transpose(ra::noarg);
  28. template <int A> constexpr void iter(ra::noarg);
  29. template <class T> constexpr void cast(ra::noarg);
  30. // ---------------------------
  31. // scalar overloads
  32. // ---------------------------
  33. // abs() needs no qualifying for ra:: types (ADL), shouldn't need it on pods either. FIXME maybe let user decide.
  34. // std::max/min are special, see DEF_NAME in ra.hh.
  35. using std::max, std::min, std::abs, std::fma, std::sqrt, std::pow, std::exp, std::swap,
  36. std::isfinite, std::isinf, std::isnan, std::clamp, std::lerp, std::conj, std::expm1;
// Intended: global conj(float) / conj(double) that return their argument unchanged.
// NOTE(review): the trailing backslash after the conj() line continues the macro. As a result the
// FOR_EACH line below is part of FOR_FLOAT's replacement list rather than an invocation. FOR_FLOAT is
// then #undef'd without ever being expanded, so this block declares nothing.
// Consequence: an unqualified conj() on a plain float/double at global scope resolves to std::conj
// (brought in by the using-declaration above). The std::conj overloads for arithmetic types return
// std::complex.
// Removing the backslash is not a drop-in fix. ra:: declares its own conj(float)/conj(double) and later
// does `using ::conj;`, which would then most likely conflict with them. Confirm intent before changing.
#define FOR_FLOAT(T) \
constexpr T conj(T x) { return x; } \
FOR_EACH(FOR_FLOAT, float, double)
#undef FOR_FLOAT
  41. #define FOR_FLOAT(R) \
  42. constexpr std::complex<R> \
  43. fma(std::complex<R> const & a, std::complex<R> const & b, std::complex<R> const & c) \
  44. { \
  45. return std::complex<R>(fma(a.real(), b.real(), fma(-a.imag(), b.imag(), c.real())), \
  46. fma(a.real(), b.imag(), fma(a.imag(), b.real(), c.imag()))); \
  47. } \
  48. constexpr bool isfinite(std::complex<R> z) { return isfinite(z.real()) && isfinite(z.imag()); } \
  49. constexpr bool isnan(std::complex<R> z) { return isnan(z.real()) || isnan(z.imag()); } \
  50. constexpr bool isinf(std::complex<R> z) { return (isinf(z.real()) || isinf(z.imag())) && !isnan(z); }
  51. FOR_EACH(FOR_FLOAT, float, double)
  52. #undef FOR_FLOAT
  53. namespace ra {
  54. // As an array op; special definitions for rank 0.
  55. template <class T> constexpr bool ra_is_real = std::numeric_limits<T>::is_integer || std::is_floating_point_v<T>;
  56. template <class T> requires (ra_is_real<T>) constexpr T amax(T const & x) { return x; }
  57. template <class T> requires (ra_is_real<T>) constexpr T amin(T const & x) { return x; }
  58. template <class T> requires (ra_is_real<T>) constexpr T sqr(T const & x) { return x*x; }
  59. #define RA_FOR_TYPES(T) \
  60. constexpr T arg(T x) { return T(0); } \
  61. constexpr T conj(T x) { return x; } \
  62. constexpr T mul_conj(T x, T y) { return x*y; } \
  63. constexpr T sqrm(T x) { return sqr(x); } \
  64. constexpr T sqrm(T x, T y) { return sqr(x-y); } \
  65. constexpr T dot(T x, T y) { return x*y; } \
  66. constexpr T fma_conj(T a, T b, T c) { return fma(a, b, c); } \
  67. constexpr T norm2(T x) { return std::abs(x); } \
  68. constexpr T norm2(T x, T y) { return std::abs(x-y); } \
  69. constexpr T rel_error(T a, T b) { auto den = (abs(a)+abs(b)); return den==0 ? 0. : 2.*norm2(a, b)/den; } \
  70. constexpr T const & real_part(T const & x) { return x; } \
  71. constexpr T & real_part(T & x) { return x; } \
  72. constexpr T imag_part(T x) { return T(0); }
  73. FOR_EACH(RA_FOR_TYPES, float, double)
  74. #undef RA_FOR_TYPES
  75. // FIXME few still inline should eventually be constexpr.
  76. #define RA_FOR_TYPES(R, C) \
  77. inline R arg(C x) { return std::arg(x); } \
  78. constexpr C sqr(C x) { return x*x; } \
  79. constexpr C dot(C x, C y) { return x*y; } \
  80. constexpr C xi(R x) { return C(0, x); } \
  81. constexpr C xi(C z) { return C(-z.imag(), z.real()); } \
  82. constexpr R real_part(C const & z) { return z.real(); } \
  83. constexpr R imag_part(C const & z) { return z.imag(); } \
  84. inline R & real_part(C & z) { return reinterpret_cast<R *>(&z)[0]; } \
  85. inline R & imag_part(C & z) { return reinterpret_cast<R *>(&z)[1]; } \
  86. constexpr R sqrm(C x) { return sqr(x.real())+sqr(x.imag()); } \
  87. constexpr R sqrm(C x, C y) { return sqr(x.real()-y.real())+sqr(x.imag()-y.imag()); } \
  88. constexpr R norm2(C x) { return hypot(x.real(), x.imag()); } \
  89. constexpr R norm2(C x, C y) { return sqrt(sqrm(x, y)); } \
  90. inline R rel_error(C a, C b) { auto den = (abs(a)+abs(b)); return den==0 ? 0. : 2.*norm2(a, b)/den; } \
  91. /* conj(a) * b + c */ \
  92. constexpr C \
  93. fma_conj(C const & a, C const & b, C const & c) \
  94. { \
  95. return C(fma(a.real(), b.real(), fma(a.imag(), b.imag(), c.real())), \
  96. fma(a.real(), b.imag(), fma(-a.imag(), b.real(), c.imag()))); \
  97. } \
  98. /* conj(a) * b */ \
  99. constexpr C \
  100. mul_conj(C const & a, C const & b) \
  101. { \
  102. return C(a.real()*b.real()+a.imag()*b.imag(), \
  103. a.real()*b.imag()-a.imag()*b.real()); \
  104. }
  105. RA_FOR_TYPES(float, std::complex<float>)
  106. RA_FOR_TYPES(double, std::complex<double>)
  107. #undef RA_FOR_TYPES
  108. template <class T> constexpr bool is_scalar_def<std::complex<T>> = true;
  109. template <int ... Iarg, class A>
  110. constexpr decltype(auto)
  111. transpose(mp::int_list<Iarg ...>, A && a) { return transpose<Iarg ...>(RA_FWD(a)); }
  112. constexpr bool
  113. odd(unsigned int N) { return N & 1; }
  114. // --------------------------------
  115. // optimization pass over expression templates.
  116. // --------------------------------
  117. template <class E> constexpr decltype(auto) optimize(E && e) { return RA_FWD(e); }
  118. // FIXME only reduces iota exprs as operated on in ra.hh (operators), not a tree like WithLen does.
  119. #if RA_DO_OPT_IOTA==1
  120. // TODO maybe don't opt iota(int)*real -> iota(real) since a+a+... != n*a
  121. template <class X> concept iota_op = ra::is_zero_or_scalar<X> && std::is_arithmetic_v<ncvalue_t<X>>;
  122. // TODO something to handle the & variants...
  123. #define ITEM(i) std::get<(i)>(e.t)
  124. // FIXME gets() vs p2781r2
  125. // qualified ra::iota is necessary to avoid ADLing to std::iota (test/headers.cc).
  126. template <is_iota I, iota_op J>
  127. constexpr auto
  128. optimize(Expr<std::plus<>, std::tuple<I, J>> && e)
  129. { return ra::iota(ITEM(0).n, ITEM(0).i+ITEM(1), ITEM(0).s); }
  130. template <iota_op I, is_iota J>
  131. constexpr auto
  132. optimize(Expr<std::plus<>, std::tuple<I, J>> && e)
  133. { return ra::iota(ITEM(1).n, ITEM(0)+ITEM(1).i, ITEM(1).s); }
  134. template <is_iota I, is_iota J>
  135. constexpr auto
  136. optimize(Expr<std::plus<>, std::tuple<I, J>> && e)
  137. { return ra::iota(maybe_len(e), ITEM(0).i+ITEM(1).i, ITEM(0).s+ITEM(1).s); }
  138. template <is_iota I, iota_op J>
  139. constexpr auto
  140. optimize(Expr<std::minus<>, std::tuple<I, J>> && e)
  141. { return ra::iota(ITEM(0).n, ITEM(0).i-ITEM(1), ITEM(0).s); }
  142. template <iota_op I, is_iota J>
  143. constexpr auto
  144. optimize(Expr<std::minus<>, std::tuple<I, J>> && e)
  145. { return ra::iota(ITEM(1).n, ITEM(0)-ITEM(1).i, -ITEM(1).s); }
  146. template <is_iota I, is_iota J>
  147. constexpr auto
  148. optimize(Expr<std::minus<>, std::tuple<I, J>> && e)
  149. { return ra::iota(maybe_len(e), ITEM(0).i-ITEM(1).i, ITEM(0).s-ITEM(1).s); }
  150. template <is_iota I, iota_op J>
  151. constexpr auto
  152. optimize(Expr<std::multiplies<>, std::tuple<I, J>> && e)
  153. { return ra::iota(ITEM(0).n, ITEM(0).i*ITEM(1), ITEM(0).s*ITEM(1)); }
  154. template <iota_op I, is_iota J>
  155. constexpr auto
  156. optimize(Expr<std::multiplies<>, std::tuple<I, J>> && e)
  157. { return ra::iota(ITEM(1).n, ITEM(0)*ITEM(1).i, ITEM(0)*ITEM(1).s); }
  158. template <is_iota I>
  159. constexpr auto
  160. optimize(Expr<std::negate<>, std::tuple<I>> && e)
  161. { return ra::iota(ITEM(0).n, -ITEM(0).i, -ITEM(0).s); }
  162. #endif // RA_DO_OPT_IOTA
  163. #if RA_DO_OPT_SMALLVECTOR==1
  164. // FIXME can't match CellSmall directly, maybe bc N is in std::array { Dim { N, 1 } }.
  165. template <class T, dim_t N, class A> constexpr bool match_small =
  166. std::is_same_v<std::decay_t<A>, typename ra::Small<T, N>::View::iterator<0>>
  167. || std::is_same_v<std::decay_t<A>, typename ra::Small<T, N>::ViewConst::iterator<0>>;
  168. static_assert(match_small<double, 4, ra::Cell<double, ic_t<std::array { Dim { 4, 1 } }>, ic_t<0>>>);
  169. #define RA_OPT_SMALLVECTOR_OP(OP, NAME, T, N) \
  170. template <class A, class B> requires (match_small<T, N, A> && match_small<T, N, B>) \
  171. constexpr auto \
  172. optimize(ra::Expr<NAME, std::tuple<A, B>> && e) \
  173. { \
  174. alignas (alignof(extvector<T, N>)) ra::Small<T, N> val; \
  175. *(extvector<T, N> *)(&val) = *(extvector<T, N> *)((ITEM(0).c.cp)) OP *(extvector<T, N> *)((ITEM(1).c.cp)); \
  176. return val; \
  177. }
  178. #define RA_OPT_SMALLVECTOR_OP_FUNS(T, N) \
  179. static_assert(0==alignof(ra::Small<T, N>) % alignof(extvector<T, N>)); \
  180. RA_OPT_SMALLVECTOR_OP(+, std::plus<>, T, N) \
  181. RA_OPT_SMALLVECTOR_OP(-, std::minus<>, T, N) \
  182. RA_OPT_SMALLVECTOR_OP(/, std::divides<>, T, N) \
  183. RA_OPT_SMALLVECTOR_OP(*, std::multiplies<>, T, N)
  184. #define RA_OPT_SMALLVECTOR_OP_SIZES(T) \
  185. RA_OPT_SMALLVECTOR_OP_FUNS(T, 2) \
  186. RA_OPT_SMALLVECTOR_OP_FUNS(T, 4) \
  187. RA_OPT_SMALLVECTOR_OP_FUNS(T, 8)
  188. FOR_EACH(RA_OPT_SMALLVECTOR_OP_SIZES, float, double)
  189. #undef RA_OPT_SMALLVECTOR_OP_SIZES
  190. #undef RA_OPT_SMALLVECTOR_OP_FUNS
  191. #undef RA_OPT_SMALLVECTOR_OP_OP
  192. #endif // RA_DO_OPT_SMALLVECTOR
  193. #undef ITEM
  194. // --------------------------------
  195. // array versions of operators and functions
  196. // --------------------------------
// We need zero/scalar specializations because the scalar/scalar operators may be templated (e.g. complex<>), so they won't be found when an implicit conversion to scalar is also needed, and e.g. ra::View<complex, 0> * complex would fail.
// The function objects are matched by the optimize() overloads in this file.
  199. #define DEF_NAMED_BINARY_OP(OP, OPNAME) \
  200. template <class A, class B> requires (tomap<A, B>) constexpr auto \
  201. operator OP(A && a, B && b) \
  202. { return RA_OPT(map(OPNAME(), RA_FWD(a), RA_FWD(b))); } \
  203. template <class A, class B> requires (toreduce<A, B>) constexpr auto \
  204. operator OP(A && a, B && b) \
  205. { return VALUE(RA_FWD(a)) OP VALUE(RA_FWD(b)); }
  206. DEF_NAMED_BINARY_OP(+, std::plus<>) DEF_NAMED_BINARY_OP(-, std::minus<>)
  207. DEF_NAMED_BINARY_OP(*, std::multiplies<>) DEF_NAMED_BINARY_OP(/, std::divides<>)
  208. DEF_NAMED_BINARY_OP(==, std::equal_to<>) DEF_NAMED_BINARY_OP(>, std::greater<>)
  209. DEF_NAMED_BINARY_OP(<, std::less<>) DEF_NAMED_BINARY_OP(>=, std::greater_equal<>)
  210. DEF_NAMED_BINARY_OP(<=, std::less_equal<>) DEF_NAMED_BINARY_OP(!=, std::not_equal_to<>)
  211. DEF_NAMED_BINARY_OP(|, std::bit_or<>) DEF_NAMED_BINARY_OP(&, std::bit_and<>)
  212. DEF_NAMED_BINARY_OP(^, std::bit_xor<>) DEF_NAMED_BINARY_OP(<=>, std::compare_three_way)
  213. #undef DEF_NAMED_BINARY_OP
  214. // FIXME address sanitizer complains in bench-optimize.cc if we use std::identity. Maybe false positive
  215. struct unaryplus
  216. {
  217. template <class T> constexpr /* static P1169 in gcc13 */ auto
  218. operator()(T && t) const noexcept { return RA_FWD(t); }
  219. };
  220. #define DEF_NAMED_UNARY_OP(OP, OPNAME) \
  221. template <class A> requires (tomap<A>) constexpr auto \
  222. operator OP(A && a) \
  223. { return map(OPNAME(), RA_FWD(a)); } \
  224. template <class A> requires (toreduce<A>) constexpr auto \
  225. operator OP(A && a) \
  226. { return OP VALUE(RA_FWD(a)); }
  227. DEF_NAMED_UNARY_OP(+, unaryplus)
  228. DEF_NAMED_UNARY_OP(-, std::negate<>)
  229. DEF_NAMED_UNARY_OP(!, std::logical_not<>)
  230. #undef DEF_NAMED_UNARY_OP
// if OP(a) isn't found in ra::, deduction rank(0) -> scalar doesn't work. TODO Cf useret.cc, reexported.cc
// DEF_NAME(OP) declares two ra::OP overloads:
//  - if the arguments satisfy tomap<A ...>: a lazy map() applying OP to corresponding elements;
//  - if the arguments satisfy toreduce<A ...>: OP applied directly to their VALUE()s.
// The inner OP(...) calls are unqualified and dependent, so they are resolved at instantiation, with ADL.
#define DEF_NAME(OP) \
template <class ... A> requires (tomap<A ...>) constexpr auto \
OP(A && ... a) \
{ return map([](auto && ... a) -> decltype(auto) { return OP(RA_FWD(a) ...); }, RA_FWD(a) ...); } \
template <class ... A> requires (toreduce<A ...>) constexpr decltype(auto) \
OP(A && ... a) \
{ return OP(VALUE(RA_FWD(a)) ...); }
// DEF_FWD(QUALIFIED_OP, OP): DEF_NAME(OP), plus an overload for arguments that are neither tomap nor
// toreduce. That overload forwards them to QUALIFIED_OP; the call is qualified, so it does not use ADL.
#define DEF_FWD(QUALIFIED_OP, OP) \
template <class ... A> requires (!tomap<A ...> && !toreduce<A ...>) constexpr decltype(auto) \
OP(A && ... a) \
{ return QUALIFIED_OP(RA_FWD(a) ...); } \
DEF_NAME(OP)
// DEF_USING(QUALIFIED_OP, OP): DEF_NAME(OP), plus `using QUALIFIED_OP;`. The existing overloads are then
// visible in ra:: next to the generated ones, and unqualified calls can still find others by ADL.
#define DEF_USING(QUALIFIED_OP, OP) \
using QUALIFIED_OP; \
DEF_NAME(OP)
// ra:: already declares scalar overloads for each of these names earlier in this file.
FOR_EACH(DEF_NAME, odd, arg, sqr, sqrm, real_part, imag_part, xi, rel_error)
// can't DEF_USING bc std::max will gobble ra:: objects if passed by const & (!)
// FIXME define own global max/min overloads for basic types. std::max seems too much of a special case to be usinged.
#define DEF_GLOBAL(f) DEF_FWD(::f, f)
FOR_EACH(DEF_GLOBAL, max, min)
#undef DEF_GLOBAL
// don't use DEF_FWD for these bc we want to allow ADL, e.g. for exp(dual).
// NOTE(review): each `using ::f;` requires f to be declared at global scope.
// Some of these names are put there by the `using std::...` declarations at the top of this file.
// Others (e.g. log, sin, atan2) presumably come from <cmath> also declaring the C functions globally,
// which the standard permits but does not require. Confirm on each target toolchain.
#define DEF_GLOBAL(f) DEF_USING(::f, f)
FOR_EACH(DEF_GLOBAL, pow, conj, sqrt, exp, expm1, log, log1p, log10, isfinite, isnan, isinf, atan2)
FOR_EACH(DEF_GLOBAL, abs, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, clamp, lerp)
FOR_EACH(DEF_GLOBAL, fma)
#undef DEF_GLOBAL
#undef DEF_USING
#undef DEF_FWD
#undef DEF_NAME
  262. template <class T, class A>
  263. constexpr auto
  264. cast(A && a)
  265. {
  266. return map([](auto && b) -> decltype(auto) { return T(b); }, RA_FWD(a));
  267. }
  268. template <class T, class ... A>
  269. constexpr auto
  270. pack(A && ... a)
  271. {
  272. return map([](auto && ... a) { return T { a ... }; }, RA_FWD(a) ...);
  273. }
  274. // FIXME needs nested array for I, but iter<-1> should work
  275. template <class A, class I>
  276. constexpr auto
  277. at(A && a, I && i)
  278. {
  279. return map([a = std::tuple<A>(RA_FWD(a))] (auto && i) -> decltype(auto) { return std::get<0>(a).at(i); },
  280. RA_FWD(i));
  281. }
  282. // --------------------------------
  283. // selection / shortcutting
  284. // --------------------------------
  285. // ra::start are needed bc rank 0 converts to and from scalar, so ? can't pick the right (-> scalar) conversion.
  286. template <class T, class F> requires (toreduce<T, F>)
  287. constexpr decltype(auto)
  288. where(bool const w, T && t, F && f)
  289. {
  290. return w ? VALUE(t) : VALUE(f);
  291. }
  292. template <class W, class T, class F> requires (tomap<W, T, F>)
  293. constexpr auto
  294. where(W && w, T && t, F && f)
  295. {
  296. return pick(cast<bool>(RA_FWD(w)), RA_FWD(f), RA_FWD(t));
  297. }
  298. // catch all for non-ra types.
  299. template <class T, class F> requires (!(tomap<T, F>) && !(toreduce<T, F>))
  300. constexpr decltype(auto)
  301. where(bool const w, T && t, F && f)
  302. {
  303. return w ? t : f;
  304. }
  305. template <class A, class B> requires (tomap<A, B>)
  306. constexpr auto
  307. operator &&(A && a, B && b)
  308. {
  309. return where(RA_FWD(a), cast<bool>(RA_FWD(b)), false);
  310. }
  311. template <class A, class B> requires (tomap<A, B>)
  312. constexpr auto
  313. operator ||(A && a, B && b)
  314. {
  315. return where(RA_FWD(a), true, cast<bool>(RA_FWD(b)));
  316. }
  317. #define DEF_SHORTCIRCUIT_BINARY_OP(OP) \
  318. template <class A, class B> requires (toreduce<A, B>) \
  319. constexpr auto operator OP(A && a, B && b) \
  320. { \
  321. return VALUE(a) OP VALUE(b); \
  322. }
  323. FOR_EACH(DEF_SHORTCIRCUIT_BINARY_OP, &&, ||)
  324. #undef DEF_SHORTCIRCUIT_BINARY_OP
  325. // --------------------------------
  326. // whole-array ops. TODO First rank reductions? Variable rank reductions?
  327. // FIXME C++23 and_then/or_else/etc
  328. // --------------------------------
  329. constexpr bool
  330. any(auto && a)
  331. {
  332. return early(map([](bool x) { return x ? std::make_optional(true) : std::nullopt; }, RA_FWD(a)), false);
  333. }
  334. constexpr bool
  335. every(auto && a)
  336. {
  337. return early(map([](bool x) { return !x ? std::make_optional(false) : std::nullopt; }, RA_FWD(a)), true);
  338. }
  339. // FIXME variable rank? see J 'index of' (x i. y), etc.
  340. constexpr dim_t
  341. index(auto && a)
  342. {
  343. return early(map([](auto && a, auto && i) { return bool(a) ? std::make_optional(i) : std::nullopt; },
  344. RA_FWD(a), ra::iota(ra::start(a).len(0))),
  345. ra::dim_t(-1));
  346. }
  347. constexpr bool
  348. lexicographical_compare(auto && a, auto && b)
  349. {
  350. return early(map([](auto && a, auto && b) { return a==b ? std::nullopt : std::make_optional(a<b); },
  351. RA_FWD(a), RA_FWD(b)),
  352. false);
  353. }
  354. constexpr auto
  355. amin(auto && a)
  356. {
  357. using std::min, std::numeric_limits;
  358. using T = ncvalue_t<decltype(a)>;
  359. T c = numeric_limits<T>::has_infinity ? numeric_limits<T>::infinity() : numeric_limits<T>::max();
  360. for_each([&c](auto && a) { if (a<c) { c = a; } }, a);
  361. return c;
  362. }
  363. constexpr auto
  364. amax(auto && a)
  365. {
  366. using std::max, std::numeric_limits;
  367. using T = ncvalue_t<decltype(a)>;
  368. T c = numeric_limits<T>::has_infinity ? -numeric_limits<T>::infinity() : numeric_limits<T>::lowest();
  369. for_each([&c](auto && a) { if (c<a) { c = a; } }, a);
  370. return c;
  371. }
  372. // FIXME encapsulate this kind of reference-reduction.
  373. // FIXME expr/ply mechanism doesn't allow partial iteration (adv then continue).
  374. template <class A, class Less = std::less<ncvalue_t<A>>>
  375. constexpr decltype(auto)
  376. refmin(A && a, Less && less = {})
  377. {
  378. RA_CHECK(a.size()>0, "refmin requires nonempty argument.");
  379. decltype(auto) s = ra::start(a);
  380. auto p = &(*s);
  381. for_each([&less, &p](auto & a) { if (less(a, *p)) { p = &a; } }, s);
  382. return *p;
  383. }
  384. template <class A, class Less = std::less<ncvalue_t<A>>>
  385. constexpr decltype(auto)
  386. refmax(A && a, Less && less = {})
  387. {
  388. RA_CHECK(a.size()>0, "refmax requires nonempty argument.");
  389. decltype(auto) s = ra::start(a);
  390. auto p = &(*s);
  391. for_each([&less, &p](auto & a) { if (less(*p, a)) { p = &a; } }, s);
  392. return *p;
  393. }
  394. constexpr auto
  395. sum(auto && a)
  396. {
  397. auto c = concrete_type<ncvalue_t<decltype(a)>>(0);
  398. for_each([&c](auto && a) { c += a; }, a);
  399. return c;
  400. }
  401. constexpr auto
  402. prod(auto && a)
  403. {
  404. auto c = concrete_type<ncvalue_t<decltype(a)>>(1);
  405. for_each([&c](auto && a) { c *= a; }, a);
  406. return c;
  407. }
  408. constexpr auto reduce_sqrm(auto && a) { return sum(sqrm(a)); }
  409. constexpr auto norm2(auto && a) { return std::sqrt(reduce_sqrm(a)); }
  410. #if 1==RA_DO_FMA
  411. constexpr void maybe_fma(auto && a, auto && b, auto & c) { c = fma(a, b, c); };
  412. constexpr void maybe_fma_conj(auto && a, auto && b, auto & c) { c = fma_conj(a, b, c); };
  413. #else
  414. constexpr void maybe_fma(auto && a, auto && b, auto & c) { c += a*b; };
  415. constexpr void maybe_fma_conj(auto && a, auto && b, auto & c) { c += conj(a)*b; };
  416. #endif
  417. constexpr auto
  418. dot(auto && a, auto && b)
  419. {
  420. std::decay_t<decltype(VALUE(a) * VALUE(b))> c(0.);
  421. for_each([&c](auto && a, auto && b) { maybe_fma(a, b, c); }, RA_FWD(a), RA_FWD(b));
  422. return c;
  423. }
  424. constexpr auto
  425. cdot(auto && a, auto && b)
  426. {
  427. std::decay_t<decltype(conj(VALUE(a)) * VALUE(b))> c(0.);
  428. for_each([&c](auto && a, auto && b) { maybe_fma_conj(a, b, c); }, RA_FWD(a), RA_FWD(b));
  429. return c;
  430. }
  431. constexpr auto
  432. normv(auto const & a)
  433. {
  434. auto b = concrete(a);
  435. b /= norm2(b);
  436. return b;
  437. }
  438. // FIXME benchmark w/o allocation and do Small/Big versions if it's worth it (see bench-gemm.cc)
  439. constexpr void
  440. gemm(auto const & a, auto const & b, auto & c)
  441. {
  442. dim_t K=a.len(1);
  443. for (int k=0; k<K; ++k) {
  444. c += from(std::multiplies<>(), a(all, k), b(k)); // FIXME fma
  445. }
  446. }
  447. constexpr auto
  448. gemm(auto const & a, auto const & b)
  449. {
  450. dim_t M=a.len(0), N=b.len(1);
  451. using T = decltype(VALUE(a)*VALUE(b));
  452. using MMTYPE = decltype(from(std::multiplies<>(), a(all, 0), b(0)));
  453. auto c = with_shape<MMTYPE>({M, N}, T());
  454. gemm(a, b, c);
  455. return c;
  456. }
  457. constexpr auto
  458. gevm(auto const & a, auto const & b)
  459. {
  460. dim_t M=b.len(0), N=b.len(1);
  461. using T = decltype(VALUE(a)*VALUE(b));
  462. auto c = with_shape<decltype(a[0]*b(0))>({N}, T());
  463. for (int i=0; i<M; ++i) {
  464. maybe_fma(a[i], b(i), c);
  465. }
  466. return c;
  467. }
  468. // FIXME a must be a view, so it doesn't work with e.g. gemv(conj(a), b).
  469. constexpr auto
  470. gemv(auto const & a, auto const & b)
  471. {
  472. dim_t M=a.len(0), N=a.len(1);
  473. using T = decltype(VALUE(a)*VALUE(b));
  474. auto c = with_shape<decltype(a(all, 0)*b[0])>({M}, T());
  475. for (int j=0; j<N; ++j) {
  476. maybe_fma(a(all, j), b[j], c);
  477. }
  478. return c;
  479. }
  480. // --------------------
  481. // wedge product and cross product. FIXME seriously
  482. // --------------------
  483. namespace mp {
// FindCombination<P, Plist>: looks for a permutation of P in the list Plist.
// `where` is type::value from IndexIf for the first element A of Plist with nonzero PermutationSign<P, A>.
// Presumably it is -1 when there is none, given the where>=0 guard below.
// `sign` is that PermutationSign value, or 0 if not found.
template <class P, class Plist>
struct FindCombination
{
template <class A> using match = bool_c<0 != PermutationSign<P, A>::value>;
using type = IndexIf<Plist, match>;
constexpr static int where = type::value;
constexpr static int sign = (where>=0) ? PermutationSign<P, typename type::type>::value : 0;
};
// Combination antiC complementary to C wrt [0, 1, ... Dim-1], permuted so [C, antiC] has the sign of [0, 1, ... Dim-1].
// EC is the plain complement of C. If [C, EC] has negative PermutationSign relative to iota<D>, the first
// two elements of EC are swapped to flip that sign. This needs len<EC> >= 2, hence the static_assert.
template <class C, int D>
struct AntiCombination
{
using EC = complement<C, D>;
static_assert((len<EC>)>=2, "can't correct this complement");
constexpr static int sign = PermutationSign<append<C, EC>, iota<D>>::value;
// Produce permutation of opposite sign if sign<0.
using type = mp::cons<std::tuple_element_t<(sign<0) ? 1 : 0, EC>,
mp::cons<std::tuple_element_t<(sign<0) ? 0 : 1, EC>,
mp::drop<EC, 2>>>;
};
// MapAntiCombination<std::tuple<C ...>, D>: AntiCombination<C, D>::type for each combination C in the tuple.
template <class C, int D> struct MapAntiCombination;
template <int D, class ... C>
struct MapAntiCombination<std::tuple<C ...>, D>
{
using type = std::tuple<typename AntiCombination<C, D>::type ...>;
};
// ChooseComponents<D, O>: the ordered basis used for O-forms in D dimensions, as a list of O-combinations
// of [0, D). By default: mp::combinations<iota<D>, O>.
template <int D, int O>
struct ChooseComponents_
{
static_assert(D>=O, "Bad dimension or form order.");
using type = mp::combinations<iota<D>, O>;
};
template <int D, int O> using ChooseComponents = typename ChooseComponents_<D, O>::type;
// For D>1 and 2*O>D the basis is instead the anti-combinations of the (D-O)-form basis. Each O-form
// component is thus the sign-corrected complement of the corresponding (D-O)-form component.
template <int D, int O> requires ((D>1) && (2*O>D))
struct ChooseComponents_<D, O>
{
static_assert(D>=O, "Bad dimension or form order.");
using type = typename MapAntiCombination<ChooseComponents<D, D-O>, D>::type;
};
  523. // Works almost to the range of std::size_t.
  524. constexpr std::size_t
  525. n_over_p(std::size_t const n, std::size_t p)
  526. {
  527. if (p>n) {
  528. return 0;
  529. } else if (p>(n-p)) {
  530. p = n-p;
  531. }
  532. std::size_t v = 1;
  533. for (std::size_t i=0; i!=p; ++i) {
  534. v = v*(n-i)/(i+1);
  535. }
  536. return v;
  537. }
// We form the basis for the result (Cr) and split it in pieces for Oa and Ob; there are (D over Oa) ways. Then we see where and with which signs these pieces are in the bases for Oa (Ca) and Ob (Cb), and form the product.
// Wedge<D, Oa, Ob>: wedge product of an Oa-form and an Ob-form in D dimensions, expanded at compile time
// component by component.
// Na, Nb, Nr are the component counts n_over_p(D, order) of the operands and of the result;
// product() static_asserts the argument sizes against them.
template <int D, int Oa, int Ob>
struct Wedge
{
constexpr static int Or = Oa+Ob;
static_assert(Oa<=D && Ob<=D && Or<=D, "bad orders");
constexpr static int Na = n_over_p(D, Oa);
constexpr static int Nb = n_over_p(D, Ob);
constexpr static int Nr = n_over_p(D, Or);
// in lexicographic order. Can be used to sort Ca below with FindPermutation.
using LexOrCa = mp::combinations<mp::iota<D>, Oa>;
// the actual components used, which are in lex. order only in some cases.
using Ca = mp::ChooseComponents<D, Oa>;
using Cb = mp::ChooseComponents<D, Ob>;
using Cr = mp::ChooseComponents<D, Or>;
// optimizations.
// NOTE(review): these flags are not referenced inside Wedge itself; presumably callers elsewhere use them
// to select a specialized product. Confirm before changing them.
constexpr static bool yields_expr = (Na>1) != (Nb>1);
constexpr static bool yields_expr_a1 = yields_expr && Na==1;
constexpr static bool yields_expr_b1 = yields_expr && Nb==1;
constexpr static bool both_scalars = (Na==1 && Nb==1);
constexpr static bool dot_plus = Na>1 && Nb>1 && Or==D && (Oa<Ob || (Oa>Ob && !ra::odd(Oa*Ob)));
constexpr static bool dot_minus = Na>1 && Nb>1 && Or==D && (Oa>Ob && ra::odd(Oa*Ob));
constexpr static bool general_case = (Na>1 && Nb>1) && ((Oa+Ob!=D) || (Oa==Ob));
// Type of (component of a) * (component of b).
template <class Va, class Vb>
using valtype = decltype(std::declval<Va>()[0] * std::declval<Vb>()[0]);
// Sum of the terms of result component Xr, for the Oa-subsets of Xr still listed in Fa.
// For the first subset Fa0, Fb is the rest of Xr (mp::complement_list).
// The term is sign * a[index of Fa0 in Ca] * b[index of Fb in Cb]. Its sign is the product of the
// permutation signs of Fa0 in Ca, of Fb in Cb, and of [Fa0, Fb] relative to Xr.
// Recurses on the tail of Fa; an empty Fa contributes 0.
template <class Xr, class Fa, class Va, class Vb>
constexpr static valtype<Va, Vb>
term(Va const & a, Vb const & b)
{
if constexpr (mp::len<Fa> > 0) {
using Fa0 = mp::first<Fa>;
using Fb = mp::complement_list<Fa0, Xr>;
using Sa = mp::FindCombination<Fa0, Ca>;
using Sb = mp::FindCombination<Fb, Cb>;
constexpr int sign = Sa::sign * Sb::sign * mp::PermutationSign<mp::append<Fa0, Fb>, Xr>::value;
static_assert(sign==+1 || sign==-1, "Bad sign in wedge term.");
return valtype<Va, Vb>(sign)*a[Sa::where]*b[Sb::where] + term<Xr, mp::drop1<Fa>>(a, b);
} else {
return 0.;
}
}
// Fill r[wr], r[wr+1], ... r[Nr-1]. Each r[w] is the term<> sum over all Oa-combinations of basis element Cr[w].
template <class Va, class Vb, class Vr, int wr>
constexpr static void
coeff(Va const & a, Vb const & b, Vr & r)
{
if constexpr (wr<Nr) {
using Xr = mp::ref<Cr, wr>;
using Fa = mp::combinations<Xr, Oa>;
r[wr] = term<Xr, Fa>(a, b);
coeff<Va, Vb, Vr, wr+1>(a, b, r);
}
}
// Entry point: r = wedge(a, b). Va, Vb, Vr must have static sizes Na, Nb, Nr.
template <class Va, class Vb, class Vr>
constexpr static void
product(Va const & a, Vb const & b, Vr & r)
{
static_assert(Va::size()==Na, "Bad Va dim.");
static_assert(Vb::size()==Nb, "Bad Vb dim.");
static_assert(Vr::size()==Nr, "Bad Vr dim.");
coeff<Va, Vb, Vr, 0>(a, b, r);
}
};
// Euclidean space, only component shuffling.
// Hodge<D, O>: Hodge star of O-forms in D dimensions, using the component orders of Wedge<D, O, D-O>.
// Each component of the result is a component of the input, possibly negated; no other arithmetic is done.
template <int D, int O>
struct Hodge
{
using W = Wedge<D, O, D-O>;
using Ca = typename W::Ca;
using Cb = typename W::Cb;
using Cr = typename W::Cr;
using LexOrCa = typename W::LexOrCa;
constexpr static int Na = W::Na;
constexpr static int Nb = W::Nb;
// Writes the component of b that corresponds to a[i], then recurses to i+1; stops at i == Na.
template <int i, class Va, class Vb>
constexpr static void
hodge_aux(Va const & a, Vb & b)
{
static_assert(i<=W::Na, "Bad argument to hodge_aux");
if constexpr (i<W::Na) {
using Cai = mp::ref<Ca, i>;
static_assert(mp::len<Cai> == O, "Bad.");
// sort Cai, because mp::complement only accepts sorted combinations.
// ref<Cb, i> should be complementary to Cai, but I don't want to rely on that.
using SCai = mp::ref<LexOrCa, mp::FindCombination<Cai, LexOrCa>::where>;
using CompCai = mp::complement<SCai, D>;
static_assert(mp::len<CompCai> == D-O, "Bad.");
using fpw = mp::FindCombination<CompCai, Cb>;
// for the sign see e.g. DoCarmo1991 I.Ex 10.
using fps = mp::FindCombination<mp::append<Cai, mp::ref<Cb, fpw::where>>, Cr>;
static_assert(fps::sign!=0, "Bad.");
// NOTE(review): a is accessed through a const &, so decltype(a[i]) is likely a reference type. In that
// case this cast binds the sign to a temporary of the element type; std::decay_t<decltype(a[i])> would
// state the intent more directly.
b[fpw::where] = decltype(a[i])(fps::sign)*a[i];
hodge_aux<i+1>(a, b);
}
}
};
// The order of components is taken from Wedge<D, O, D-O>; this works for whatever order is defined there.
// With lexicographic order, component order is reversed, but signs vary.
// With the order given by ChooseComponents<>, fpw::where==i and fps::sign==+1 in hodge_aux(), always. Then hodge() becomes a free operation, (with one exception) and the next function hodge() can be used.
// hodgex<D, O>(a, b): general Hodge star b = *a, via Hodge<D, O>::hodge_aux. Sizes are checked statically.
template <int D, int O, class Va, class Vb>
constexpr void
hodgex(Va const & a, Vb & b)
{
static_assert(O<=D, "bad orders");
static_assert(Va::size()==mp::Hodge<D, O>::Na, "error");
static_assert(Vb::size()==mp::Hodge<D, O>::Nb, "error");
mp::Hodge<D, O>::template hodge_aux<0>(a, b);
}
  645. } // namespace ra::mp
  646. // This depends on Wedge<>::Ca, Cb, Cr coming from ChooseCombinations. hodgex() should always work, but this is cheaper.
  647. // However if 2*O=D, it is not possible to differentiate the bases by order and hodgex() must be used.
  648. // Likewise, when O(N-O) is odd, Hodge from (2*O>D) to (2*O<D) change sign, since **w= -w in that case, and the basis in the (2*O>D) case is selected to make Hodge(<)->Hodge(>) trivial; but can't do both!
  649. consteval bool trivial_hodge(int D, int O) { return 2*O!=D && ((2*O<D) || !ra::odd(O*(D-O))); }
  650. template <int D, int O, class Va, class Vb>
  651. constexpr void
  652. hodge(Va const & a, Vb & b)
  653. {
  654. if constexpr (trivial_hodge(D, O)) {
  655. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error");
  656. static_assert(Vb::size()==mp::Hodge<D, O>::Nb, "error");
  657. b = a;
  658. } else {
  659. ra::mp::hodgex<D, O>(a, b);
  660. }
  661. }
  662. template <int D, int O, class Va> requires (trivial_hodge(D, O))
  663. constexpr Va const &
  664. hodge(Va const & a)
  665. {
  666. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error");
  667. return a;
  668. }
  669. template <int D, int O, class Va> requires (!trivial_hodge(D, O))
  670. constexpr Va &
  671. hodge(Va & a)
  672. {
  673. Va b(a);
  674. ra::mp::hodgex<D, O>(b, a);
  675. return a;
  676. }
  677. template <int D, int Oa, int Ob, class A, class B> requires (is_scalar<A> && is_scalar<B>)
  678. constexpr auto
  679. wedge(A const & a, B const & b) { return a*b; }
  680. template <class A>
  681. using torank1 = std::conditional_t<is_scalar<A>, Small<std::decay_t<A>, 1>, A>;
  682. template <int D, int Oa, int Ob, class Va, class Vb> requires (!(is_scalar<Va> && is_scalar<Vb>))
  683. decltype(auto)
  684. wedge(Va const & a, Vb const & b)
  685. {
  686. Small<ncvalue_t<Va>, size_s<Va>()> aa = a;
  687. Small<ncvalue_t<Vb>, size_s<Vb>()> bb = b;
  688. using Ua = decltype(aa);
  689. using Ub = decltype(bb);
  690. using Wedge = mp::Wedge<D, Oa, Ob>;
  691. using valtype = typename Wedge::template valtype<Ua, Ub>;
  692. std::conditional_t<Wedge::Nr==1, valtype, Small<valtype, Wedge::Nr>> r;
  693. auto & a1 = reinterpret_cast<torank1<Ua> const &>(aa);
  694. auto & b1 = reinterpret_cast<torank1<Ub> const &>(bb);
  695. auto & r1 = reinterpret_cast<torank1<decltype(r)> &>(r);
  696. mp::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  697. return r;
  698. }
  699. template <int D, int Oa, int Ob, class Va, class Vb, class Vr> requires (!(is_scalar<Va> && is_scalar<Vb>))
  700. void
  701. wedge(Va const & a, Vb const & b, Vr & r)
  702. {
  703. Small<ncvalue_t<Va>, size_s<Va>()> aa = a;
  704. Small<ncvalue_t<Vb>, size_s<Vb>()> bb = b;
  705. using Ua = decltype(aa);
  706. using Ub = decltype(bb);
  707. auto & r1 = reinterpret_cast<torank1<decltype(r)> &>(r);
  708. auto & a1 = reinterpret_cast<torank1<Ua> const &>(aa);
  709. auto & b1 = reinterpret_cast<torank1<Ub> const &>(bb);
  710. mp::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  711. }
  712. template <class A, class B>
  713. constexpr auto
  714. cross(A const & a_, B const & b_)
  715. {
  716. constexpr int n = size_s<A>();
  717. static_assert(n==size_s<B>() && (2==n || 3==n));
  718. Small<std::decay_t<decltype(VALUE(a_))>, n> a = a_;
  719. Small<std::decay_t<decltype(VALUE(b_))>, n> b = b_;
  720. using W = mp::Wedge<n, 1, 1>;
  721. Small<std::decay_t<decltype(VALUE(a_) * VALUE(b_))>, W::Nr> r;
  722. W::product(a, b, r);
  723. if constexpr (1==W::Nr) {
  724. return r[0];
  725. } else {
  726. return r;
  727. }
  728. }
  729. template <class V>
  730. constexpr auto
  731. perp(V const & v)
  732. {
  733. static_assert(2==v.size(), "Dimension error.");
  734. return Small<std::decay_t<decltype(VALUE(v))>, 2> {v[1], -v[0]};
  735. }
  736. template <class V, class U>
  737. constexpr auto
  738. perp(V const & v, U const & n)
  739. {
  740. if constexpr (is_scalar<U>) {
  741. static_assert(2==v.size(), "Dimension error.");
  742. return Small<std::decay_t<decltype(VALUE(v) * n)>, 2> {v[1]*n, -v[0]*n};
  743. } else {
  744. static_assert(3==v.size(), "Dimension error.");
  745. return cross(v, n);
  746. }
  747. }
  748. } // namespace ra
  749. #undef RA_OPT