ra.hh 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Operator overloads for expression templates, and root header.
  3. // (c) Daniel Llorens - 2014-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  9. #include "big.hh"
  10. #include "optimize.hh"
  11. #include "complex.hh"
  // Compile-time switch for expression optimization. RA_DO_OPT may be predefined by the user
  // and defaults to 1. RA_OPT expands to `optimize` when RA_DO_OPT==1 and to nothing otherwise;
  // it is applied to the expression built by the binary operators defined with
  // DEF_NAMED_BINARY_OP in this file.
  12. #ifndef RA_DO_OPT
  13. #define RA_DO_OPT 1 // enabled by default
  14. #endif
  15. #if RA_DO_OPT==1
  16. #define RA_OPT optimize
  17. #else
  18. #define RA_OPT
  19. #endif
  // Declarations only (never defined in this file): function templates named transpose and
  // iter at global scope, taking ra::noarg. Their purpose, per the note below, is to make
  // calls with explicit template arguments such as transpose<...>(x) parse as template calls
  // so that ADL can then find the ra:: overloads.
  20. // Enable ADL with explicit template args. See http://stackoverflow.com/questions/9838862.
  21. template <class A> constexpr void transpose(ra::noarg);
  22. template <int A> constexpr void iter(ra::noarg);
  23. namespace ra {
  // Mark std::complex<T> as a scalar type by specializing is_scalar_def (declared elsewhere).
  24. template <class T> constexpr bool is_scalar_def<std::complex<T>> = true;
  25. template <int ... Iarg, class A>
  26. constexpr decltype(auto)
  27. transpose(mp::int_list<Iarg ...>, A && a)
  28. {
  29. return transpose<Iarg ...>(RA_FWD(a));
  30. }
  // True iff N is odd.
  constexpr bool
  odd(unsigned int N)
  {
      return 1u == (N % 2u);
  }
  32. // ---------------------------
  33. // TODO integrate with beatable<> shortcuts, operator() in the various array types.
  34. // ---------------------------
  35. template <class II, int drop, class Op>
  36. constexpr decltype(auto)
  37. from_partial(Op && op)
  38. {
  39. if constexpr (drop==mp::len<II>) {
  40. return RA_FWD(op);
  41. } else {
  42. return wrank(mp::append<mp::makelist<drop, ic_t<0>>, mp::drop<II, drop>> {},
  43. from_partial<II, drop+1>(RA_FWD(op)));
  44. }
  45. }
  46. // TODO should be able to do better by slicing at each dimension, etc. But verb<>'s innermost op must be rank 0.
  47. template <class A, class ... I>
  48. constexpr decltype(auto)
  49. from(A && a, I && ... i)
  50. {
  51. if constexpr (0==sizeof...(i)) {
  52. return RA_FWD(a)();
  53. } else if constexpr (1==sizeof...(i)) {
  54. // support dynamic rank for 1 arg only (see test in test/from.cc).
  55. return map(RA_FWD(a), RA_FWD(i) ...);
  56. } else {
  57. return map(from_partial<mp::tuple<ic_t<rank_s<I>()> ...>, 1>(RA_FWD(a)), RA_FWD(i) ...);
  58. }
  59. }
  60. // --------------------------------
  61. // Array versions of operators and functions
  62. // --------------------------------
  63. // We need zero/scalar specializations because the scalar/scalar operators may be templated (e.g. complex<>), so they won't be found when an implicit conversion from zero->scalar is also needed. That is, without those specializations, ra::View<complex, 0> * complex will fail.
  64. // The function objects are matched in optimize.hh.
  // DEF_NAMED_BINARY_OP(OP, OPNAME) defines two overloads of `operator OP`:
  //  - when tomap<A, B> holds: returns RA_OPT(map(OPNAME(), a, b)), i.e. an expression built
  //    from the function object OPNAME, optionally passed through optimize (see RA_OPT);
  //  - when toreduce<A, B> holds: applies OP directly to FLAT(a) and FLAT(b).
  // It is instantiated below for the arithmetic, comparison, bitwise and <=> operators, then
  // undefined.
  65. #define DEF_NAMED_BINARY_OP(OP, OPNAME) \
  66. template <class A, class B> requires (tomap<A, B>) constexpr auto \
  67. operator OP(A && a, B && b) \
  68. { return RA_OPT(map(OPNAME(), RA_FWD(a), RA_FWD(b))); } \
  69. template <class A, class B> requires (toreduce<A, B>) constexpr auto \
  70. operator OP(A && a, B && b) \
  71. { return FLAT(RA_FWD(a)) OP FLAT(RA_FWD(b)); }
  72. DEF_NAMED_BINARY_OP(+, std::plus<>) DEF_NAMED_BINARY_OP(-, std::minus<>)
  73. DEF_NAMED_BINARY_OP(*, std::multiplies<>) DEF_NAMED_BINARY_OP(/, std::divides<>)
  74. DEF_NAMED_BINARY_OP(==, std::equal_to<>) DEF_NAMED_BINARY_OP(>, std::greater<>)
  75. DEF_NAMED_BINARY_OP(<, std::less<>) DEF_NAMED_BINARY_OP(>=, std::greater_equal<>)
  76. DEF_NAMED_BINARY_OP(<=, std::less_equal<>) DEF_NAMED_BINARY_OP(!=, std::not_equal_to<>)
  77. DEF_NAMED_BINARY_OP(|, std::bit_or<>) DEF_NAMED_BINARY_OP(&, std::bit_and<>)
  78. DEF_NAMED_BINARY_OP(^, std::bit_xor<>) DEF_NAMED_BINARY_OP(<=>, std::compare_three_way)
  79. #undef DEF_NAMED_BINARY_OP
  80. // FIXME address sanitizer complains in bench-optimize.cc if we use std::identity. Maybe false positive
  // Function object behind unary operator+: forwards its argument and, because the return
  // type is plain `auto`, hands it back by value.
  struct unaryplus
  {
      template <class T>
      constexpr /* static P1169 in gcc13 */ auto
      operator()(T && t) const noexcept
      {
          return RA_FWD(t);
      }
  };
  // DEF_NAMED_UNARY_OP(OP, OPNAME) defines two overloads of unary `operator OP`:
  //  - when tomap<A> holds: returns map(OPNAME(), a);
  //  - when toreduce<A> holds: applies OP directly to FLAT(a).
  // Unlike the binary version, the map result is not passed through RA_OPT.
  // Instantiated for +, - and !, then undefined.
  86. #define DEF_NAMED_UNARY_OP(OP, OPNAME) \
  87. template <class A> requires (tomap<A>) constexpr auto \
  88. operator OP(A && a) \
  89. { return map(OPNAME(), RA_FWD(a)); } \
  90. template <class A> requires (toreduce<A>) constexpr auto \
  91. operator OP(A && a) \
  92. { return OP FLAT(RA_FWD(a)); }
  93. DEF_NAMED_UNARY_OP(+, unaryplus)
  94. DEF_NAMED_UNARY_OP(-, std::negate<>)
  95. DEF_NAMED_UNARY_OP(!, std::logical_not<>)
  96. #undef DEF_NAMED_UNARY_OP
  97. // if OP(a) isn't found in ra::, deduction rank(0) -> scalar doesn't work. TODO Cf useret.cc, reexported.cc
  // Helper macros that lift a named function OP to ra:: arguments.
  //  - DEF_NAME(OP): defines OP(a ...) for tomap<A ...> (maps a lambda that calls OP on the
  //    elements) and for toreduce<A ...> (calls OP on FLAT(a) ...).
  //  - DEF_FWD(QUALIFIED_OP, OP): additionally defines a catch-all OP for arguments that are
  //    neither tomap nor toreduce, which forwards to QUALIFIED_OP; then applies DEF_NAME(OP).
  //  - DEF_USING(QUALIFIED_OP, OP): brings QUALIFIED_OP into this namespace with a
  //    using-declaration; then applies DEF_NAME(OP).
  // DEF_GLOBAL is defined twice: first as DEF_FWD(::f, f) for max/min, then as
  // DEF_USING(::f, f) for the math functions. All helper macros are undefined at the end.
  98. #define DEF_NAME(OP) \
  99. template <class ... A> requires (tomap<A ...>) constexpr auto \
  100. OP(A && ... a) \
  101. { return map([](auto && ... a) -> decltype(auto) { return OP(RA_FWD(a) ...); }, RA_FWD(a) ...); } \
  102. template <class ... A> requires (toreduce<A ...>) constexpr decltype(auto) \
  103. OP(A && ... a) \
  104. { return OP(FLAT(RA_FWD(a)) ...); }
  105. #define DEF_FWD(QUALIFIED_OP, OP) \
  106. template <class ... A> requires (!tomap<A ...> && !toreduce<A ...>) constexpr decltype(auto) \
  107. OP(A && ... a) \
  108. { return QUALIFIED_OP(RA_FWD(a) ...); } \
  109. DEF_NAME(OP)
  110. #define DEF_USING(QUALIFIED_OP, OP) \
  111. using QUALIFIED_OP; \
  112. DEF_NAME(OP)
  113. // FIXME move rel_error etc. out of :: and in here. Maybe do _FWD just for std:: types?
  114. FOR_EACH(DEF_NAME, odd, arg, sqr, sqrm, real_part, imag_part, xI, rel_error)
  115. // can't DEF_USING bc std::max will gobble ra:: objects if passed by const & (!)
  116. // FIXME define own global max/min overloads for basic types. std::max seems too much of a special case to be usinged.
  117. #define DEF_GLOBAL(f) DEF_FWD(::f, f)
  118. FOR_EACH(DEF_GLOBAL, max, min)
  119. #undef DEF_GLOBAL
  120. // don't use DEF_FWD for these bc we want to allow ADL, e.g. for exp(dual).
  121. #define DEF_GLOBAL(f) DEF_USING(::f, f)
  122. FOR_EACH(DEF_GLOBAL, pow, conj, sqrt, exp, expm1, log, log1p, log10, isfinite, isnan, isinf, atan2)
  123. FOR_EACH(DEF_GLOBAL, abs, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, clamp, lerp)
  124. #undef DEF_GLOBAL
  125. #undef DEF_USING
  126. #undef DEF_FWD
  127. #undef DEF_NAME
  // Elementwise conversion: maps the functional cast T(x) over a.
  template <class T, class A>
  constexpr auto
  cast(A && a)
  {
      return map([](auto && x) -> decltype(auto) { return T(x); },
                 RA_FWD(a));
  }
  134. // TODO std::forward_as_tuple?
  // Elementwise aggregate construction: maps `T { x ... }` over the arguments a ...
  template <class T, class ... A>
  constexpr auto
  pack(A && ... a)
  {
      return map([](auto && ... x) { return T { x ... }; },
                 RA_FWD(a) ...);
  }
  141. // FIXME needs a nested array for I
  // Maps a.at(j) over the elements j of i. The array argument is kept inside the mapped
  // lambda in a std::tuple<A>, so it is held by reference or by value according to A.
  template <class A, class I>
  constexpr auto
  at(A && a, I && i)
  {
      return map([held = std::tuple<A>(RA_FWD(a))] (auto && j) -> decltype(auto)
                 {
                     return std::get<0>(held).at(j);
                 },
                 RA_FWD(i));
  }
  149. // --------------------------------
  150. // selection / shortcutting
  151. // --------------------------------
  152. // ra::start are needed bc rank 0 converts to and from scalar, so ? can't pick the right (-> scalar) conversion.
  153. template <class T, class F> requires (toreduce<T, F>)
  154. constexpr decltype(auto)
  155. where(bool const w, T && t, F && f)
  156. {
  157. return w ? FLAT(t) : FLAT(f);
  158. }
  159. template <class W, class T, class F> requires (tomap<W, T, F>)
  160. constexpr auto
  161. where(W && w, T && t, F && f)
  162. {
  163. return pick(cast<bool>(RA_FWD(w)), RA_FWD(f), RA_FWD(t));
  164. }
  165. // catch all for non-ra types.
  166. template <class T, class F> requires (!(tomap<T, F>) && !(toreduce<T, F>))
  167. constexpr decltype(auto)
  168. where(bool const w, T && t, F && f)
  169. {
  170. return w ? t : f;
  171. }
  172. template <class A, class B> requires (tomap<A, B>)
  173. constexpr auto
  174. operator &&(A && a, B && b)
  175. {
  176. return where(RA_FWD(a), cast<bool>(RA_FWD(b)), false);
  177. }
  178. template <class A, class B> requires (tomap<A, B>)
  179. constexpr auto
  180. operator ||(A && a, B && b)
  181. {
  182. return where(RA_FWD(a), true, cast<bool>(RA_FWD(b)));
  183. }
  184. #define DEF_SHORTCIRCUIT_BINARY_OP(OP) \
  185. template <class A, class B> requires (toreduce<A, B>) \
  186. constexpr auto operator OP(A && a, B && b) \
  187. { \
  188. return FLAT(a) OP FLAT(b); \
  189. }
  190. FOR_EACH(DEF_SHORTCIRCUIT_BINARY_OP, &&, ||);
  191. #undef DEF_SHORTCIRCUIT_BINARY_OP
  192. // --------------------------------
  193. // Some whole-array reductions.
  194. // TODO First rank reductions? Variable rank reductions?
  195. // FIXME C++23 and_then/or_else/etc
  196. // --------------------------------
  // Whole-array `or`. Each element (as bool) maps to an engaged optional holding true when
  // it is set and to an empty optional otherwise; the result comes from early() with
  // default false.
  template <class A>
  constexpr bool
  any(A && a)
  {
      return early(map([](bool x) -> std::optional<bool>
                       {
                           if (x) { return true; }
                           return std::nullopt;
                       },
                       RA_FWD(a)),
                   false);
  }
  // Whole-array `and`. Each element (as bool) maps to an engaged optional holding false
  // when it is not set and to an empty optional otherwise; the result comes from early()
  // with default true.
  template <class A>
  constexpr bool
  every(A && a)
  {
      return early(map([](bool x) -> std::optional<bool>
                       {
                           if (x) { return std::nullopt; }
                           return false;
                       },
                       RA_FWD(a)),
                   true);
  }
  209. // FIXME variable rank? see J 'index of' (x i. y), etc.
  210. template <class A>
  211. constexpr auto
  212. index(A && a)
  213. {
  214. return early(map([](auto && a, auto && i) { return bool(a) ? std::make_optional(i) : std::nullopt; },
  215. RA_FWD(a), ra::iota(ra::start(a).len(0))),
  216. ra::dim_t(-1));
  217. }
  218. // [ma108]
  // Lexicographic `a < b`. Equal element pairs yield an empty optional; a differing pair
  // yields an engaged optional holding x<y. The result comes from early() with default false.
  template <class A, class B>
  constexpr bool
  lexicographical_compare(A && a, B && b)
  {
      return early(map([](auto && x, auto && y)
                       {
                           return x==y ? std::nullopt : std::make_optional(x<y);
                       },
                       RA_FWD(a), RA_FWD(b)),
                   false);
  }
  227. template <class A>
  228. constexpr auto
  229. amin(A && a)
  230. {
  231. using std::min;
  232. using T = value_t<A>;
  233. T c = std::numeric_limits<T>::has_infinity ? std::numeric_limits<T>::infinity() : std::numeric_limits<T>::max();
  234. for_each([&c](auto && a) { if (a<c) { c = a; } }, a);
  235. return c;
  236. }
  237. template <class A>
  238. constexpr auto
  239. amax(A && a)
  240. {
  241. using std::max;
  242. using T = value_t<A>;
  243. T c = std::numeric_limits<T>::has_infinity ? -std::numeric_limits<T>::infinity() : std::numeric_limits<T>::lowest();
  244. for_each([&c](auto && a) { if (c<a) { c = a; } }, a);
  245. return c;
  246. }
  247. // FIXME encapsulate this kind of reference-reduction.
  248. // FIXME expr/ply mechanism doesn't allow partial iteration (adv then continue).
  249. template <class A, class Less = std::less<value_t<A>>>
  250. constexpr decltype(auto)
  251. refmin(A && a, Less && less = std::less<value_t<A>>())
  252. {
  253. RA_CHECK(a.size()>0);
  254. decltype(auto) s = ra::start(a);
  255. auto p = &(*s);
  256. for_each([&less, &p](auto & a) { if (less(a, *p)) { p = &a; } }, s);
  257. return *p;
  258. }
  259. template <class A, class Less = std::less<value_t<A>>>
  260. constexpr decltype(auto)
  261. refmax(A && a, Less && less = std::less<value_t<A>>())
  262. {
  263. RA_CHECK(a.size()>0);
  264. decltype(auto) s = ra::start(a);
  265. auto p = &(*s);
  266. for_each([&less, &p](auto & a) { if (less(*p, a)) { p = &a; } }, s);
  267. return *p;
  268. }
  269. template <class A>
  270. constexpr auto
  271. sum(A && a)
  272. {
  273. auto c = concrete_type<value_t<A>>(0);
  274. for_each([&c](auto && a) { c += a; }, a);
  275. return c;
  276. }
  277. template <class A>
  278. constexpr auto
  279. prod(A && a)
  280. {
  281. auto c = concrete_type<value_t<A>>(1);
  282. for_each([&c](auto && a) { c *= a; }, a);
  283. return c;
  284. }
  // Sum over the elements of sqrm(a).
  template <class A>
  constexpr auto
  reduce_sqrm(A && a)
  {
      return sum(sqrm(a));
  }
  // 2-norm of a: std::sqrt of reduce_sqrm(a).
  template <class A>
  constexpr auto
  norm2(A && a)
  {
      return std::sqrt(reduce_sqrm(a));
  }
  // Sum of the elementwise products of a and b. The accumulator has the decayed type of
  // FLAT(a)*FLAT(b) and starts at 0. When FP_FAST_FMA is defined each step is a fused
  // multiply-add through fma(); otherwise it is a plain `+=` of the product.
  template <class A, class B>
  constexpr auto
  dot(A && a, B && b)
  {
      using acc_t = std::decay_t<decltype(FLAT(a) * FLAT(b))>;
      acc_t c(0.);
      for_each([&c](auto && x, auto && y)
               {
  #ifdef FP_FAST_FMA
                   c = fma(x, y, c);
  #else
                   c += x*y;
  #endif
               }, a, b);
      return c;
  }
  // Like dot(), but with the elements of a conjugated: accumulates conj(x)*y. The
  // accumulator has the decayed type of conj(FLAT(a))*FLAT(b) and starts at 0. When
  // FP_FAST_FMA is defined each step goes through fma_conj(); otherwise it is a plain `+=`.
  template <class A, class B>
  constexpr auto
  cdot(A && a, B && b)
  {
      using acc_t = std::decay_t<decltype(conj(FLAT(a)) * FLAT(b))>;
      acc_t c(0.);
      for_each([&c](auto && x, auto && y)
               {
  #ifdef FP_FAST_FMA
                   c = fma_conj(x, y, c);
  #else
                   c += conj(x)*y;
  #endif
               }, a, b);
      return c;
  }
  317. // --------------------
  318. // Other whole-array ops.
  319. // --------------------
  // Returns a concrete copy of a divided by its own norm2(). No check is made for a zero norm.
  template <class A>
  constexpr auto
  normv(A const & a)
  {
      auto unit = concrete(a);
      unit /= norm2(unit);
      return unit;
  }
  328. // FIXME benchmark w/o allocation and do Small/Big versions if it's worth it.
  329. template <class A, class B, class C>
  330. constexpr void
  331. gemm(A const & a, B const & b, C & c)
  332. {
  333. for_each(ra::wrank<1, 1, 2>(ra::wrank<1, 0, 1>([](auto && c, auto && a, auto && b) { c += a*b; })), c, a, b);
  334. }
  335. #define MMTYPE decltype(from(std::multiplies<>(), a(all, 0), b(0)))
  336. // default for row-major x row-major. See bench-gemm.cc for variants.
  337. template <class S, class T>
  338. constexpr auto
  339. gemm(ra::View<S, 2> const & a, ra::View<T, 2> const & b)
  340. {
  341. dim_t M=a.len(0), N=b.len(1), K=a.len(1);
  342. // no with_same_shape bc cannot index 0 for type if A/B are empty
  343. auto c = with_shape<MMTYPE>({M, N}, decltype(std::declval<S>()*std::declval<T>())());
  344. for (int k=0; k<K; ++k) {
  345. c += from(std::multiplies<>(), a(all, k), b(k));
  346. }
  347. return c;
  348. }
  349. // we still want the Small version to be different.
  350. template <class A, class B>
  351. constexpr ra::Small<std::decay_t<decltype(FLAT(std::declval<A>()) * FLAT(std::declval<B>()))>, A::len(0), B::len(1)>
  352. gemm(A const & a, B const & b)
  353. {
  354. dim_t M=a.len(0), N=b.len(1);
  355. // no with_same_shape bc cannot index 0 for type if A/B are empty
  356. auto c = with_shape<MMTYPE>({M, N}, ra::none);
  357. for (int i=0; i<M; ++i) {
  358. for (int j=0; j<N; ++j) {
  359. c(i, j) = dot(a(i), b(all, j));
  360. }
  361. }
  362. return c;
  363. }
  364. #undef MMTYPE
  365. template <class A, class B>
  366. constexpr auto
  367. gevm(A const & a, B const & b)
  368. {
  369. dim_t M=b.len(0), N=b.len(1);
  370. // no with_same_shape bc cannot index 0 for type if A/B are empty
  371. auto c = with_shape<decltype(a[0]*b(0))>({N}, 0);
  372. for (int i=0; i<M; ++i) {
  373. c += a[i]*b(i);
  374. }
  375. return c;
  376. }
  377. // FIXME a must be a view, so it doesn't work with e.g. gemv(conj(a), b).
  378. template <class A, class B>
  379. constexpr auto
  380. gemv(A const & a, B const & b)
  381. {
  382. dim_t M=a.len(0), N=a.len(1);
  383. // no with_same_shape bc cannot index 0 for type if A/B are empty
  384. auto c = with_shape<decltype(a(all, 0)*b[0])>({M}, 0);
  385. for (int j=0; j<N; ++j) {
  386. c += a(all, j) * b[j];
  387. }
  388. return c;
  389. }
  390. // --------------------
  391. // Wedge product and cross product
  392. // --------------------
  393. namespace mp {
  // Looks up the combination P in the list of combinations Plist.
  //  - match<A>: true when PermutationSign<P, A> is nonzero (presumably: when A is a
  //    permutation of P -- confirm against PermutationSign).
  //  - type:  IndexIf<Plist, match>, the first matching entry.
  //  - where: its index in Plist; the `where>=0` test below suggests a negative value
  //    means "not found".
  //  - sign:  PermutationSign of P relative to the matched entry, or 0 when nothing matched.
  394. template <class P, class Plist>
  395. struct FindCombination
  396. {
  397. template <class A> using match = bool_c<0 != PermutationSign<P, A>::value>;
  398. using type = IndexIf<Plist, match>;
  399. constexpr static int where = type::value;
  400. constexpr static int sign = (where>=0) ? PermutationSign<P, typename type::type>::value : 0;
  401. };
  402. // Combination antiC complementary to C wrt [0, 1, ... Dim-1], permuted so [C, antiC] has the sign of [0, 1, ... Dim-1].
  // EC is complement<C, D>; it must have at least two elements (static_assert), because the
  // sign is corrected by ordering its first two elements. `sign` is the permutation sign of
  // [C, EC] relative to iota<D>. `type` is EC with its first two elements swapped when
  // sign<0, and EC unchanged otherwise.
  403. template <class C, int D>
  404. struct AntiCombination
  405. {
  406. using EC = complement<C, D>;
  407. static_assert((len<EC>)>=2, "can't correct this complement");
  408. constexpr static int sign = PermutationSign<append<C, EC>, iota<D>>::value;
  409. // Produce permutation of opposite sign if sign<0.
  410. using type = mp::cons<std::tuple_element_t<(sign<0) ? 1 : 0, EC>,
  411. mp::cons<std::tuple_element_t<(sign<0) ? 0 : 1, EC>,
  412. mp::drop<EC, 2>>>;
  413. };
  // Applies AntiCombination<., D> to every combination in a std::tuple of combinations;
  // `type` is the tuple of the results, in the same order.
  414. template <class C, int D> struct MapAntiCombination;
  415. template <int D, class ... C>
  416. struct MapAntiCombination<std::tuple<C ...>, D>
  417. {
  418. using type = std::tuple<typename AntiCombination<C, D>::type ...>;
  419. };
  // Selects the list of index combinations used as components of an order-O form in
  // dimension D (requires D>=O).
  //  - Primary template: the combinations of iota<D> taken O at a time, as produced by
  //    mp::combinations.
  //  - Specialization for D>1 and 2*O>D: the AntiCombination of each entry of
  //    ChooseComponents_<D, D-O>, i.e. the complements of the order D-O components.
  // ChooseComponents_<D, O> is the shorthand for ChooseComponents<D, O>::type.
  420. template <int D, int O>
  421. struct ChooseComponents
  422. {
  423. static_assert(D>=O, "Bad dimension or form order.");
  424. using type = mp::combinations<iota<D>, O>;
  425. };
  426. template <int D, int O> using ChooseComponents_ = typename ChooseComponents<D, O>::type;
  427. template <int D, int O> requires ((D>1) && (2*O>D))
  428. struct ChooseComponents<D, O>
  429. {
  430. static_assert(D>=O, "Bad dimension or form order.");
  431. using type = typename MapAntiCombination<ChooseComponents_<D, D-O>, D>::type;
  432. };
  433. // Works almost to the range of std::size_t.
  // Binomial coefficient (n over p); 0 when p>n.
  // Uses the symmetry (n over p) == (n over n-p) to iterate over the smaller of the two, and
  // multiplies before dividing at each step so every intermediate value is an exact integer.
  constexpr std::size_t
  n_over_p(std::size_t const n, std::size_t p)
  {
      if (p>n) {
          return 0;
      }
      if (n-p<p) {
          p = n-p;
      }
      std::size_t v = 1;
      std::size_t i = 0;
      while (i<p) {
          v = v*(n-i)/(i+1);
          ++i;
      }
      return v;
  }
  448. // We form the basis for the result (Cr) and split it in pieces for Oa and Ob; there are (D over Oa) ways. Then we see where and with which signs these pieces are in the bases for Oa (Ca) and Ob (Cb), and form the product.
  // Compile-time machinery for the wedge product of a form of order Oa with a form of order
  // Ob in dimension D; the result has order Or=Oa+Ob, which must not exceed D.
  //  - Na, Nb, Nr: number of components of each form, computed as n_over_p(D, order).
  //  - Ca, Cb, Cr: the component lists (index combinations) of the two operands and the
  //    result, taken from ChooseComponents_.
  //  - yields_expr, both_scalars, dot_plus, dot_minus, general_case: compile-time
  //    classification flags derived from Na, Nb, Oa, Ob and D; they are not read inside
  //    this struct.
  //  - valtype<Va, Vb>: decayed type of the product of one element of each operand.
  //  - term<Xr, Fa>(a, b): recursive sum over the list Fa of Oa-sized pieces of the result
  //    component Xr. For the first piece Fa0, Fb is complement_list<Fa0, Xr>; both are
  //    located in Ca / Cb with FindCombination, and sign*a[..]*b[..] is added to the term
  //    for the rest of Fa. An empty Fa contributes 0.
  //  - coeff<Va, Vb, Vr, wr>(a, b, r): writes r[wr] and recurses on wr+1 until wr==Nr.
  //  - product(a, b, r): checks the static sizes of a, b, r against Na, Nb, Nr and fills
  //    all of r, starting coeff at 0.
  449. template <int D, int Oa, int Ob>
  450. struct Wedge
  451. {
  452. constexpr static int Or = Oa+Ob;
  453. static_assert(Oa<=D && Ob<=D && Or<=D, "bad orders");
  454. constexpr static int Na = n_over_p(D, Oa);
  455. constexpr static int Nb = n_over_p(D, Ob);
  456. constexpr static int Nr = n_over_p(D, Or);
  457. // in lexicographic order. Can be used to sort Ca below with FindPermutation.
  458. using LexOrCa = mp::combinations<mp::iota<D>, Oa>;
  459. // the actual components used, which are in lex. order only in some cases.
  460. using Ca = mp::ChooseComponents_<D, Oa>;
  461. using Cb = mp::ChooseComponents_<D, Ob>;
  462. using Cr = mp::ChooseComponents_<D, Or>;
  463. // optimizations.
  464. constexpr static bool yields_expr = (Na>1) != (Nb>1);
  465. constexpr static bool yields_expr_a1 = yields_expr && Na==1;
  466. constexpr static bool yields_expr_b1 = yields_expr && Nb==1;
  467. constexpr static bool both_scalars = (Na==1 && Nb==1);
  468. constexpr static bool dot_plus = Na>1 && Nb>1 && Or==D && (Oa<Ob || (Oa>Ob && !ra::odd(Oa*Ob)));
  469. constexpr static bool dot_minus = Na>1 && Nb>1 && Or==D && (Oa>Ob && ra::odd(Oa*Ob));
  470. constexpr static bool general_case = (Na>1 && Nb>1) && ((Oa+Ob!=D) || (Oa==Ob));
  471. template <class Va, class Vb>
  472. using valtype = std::decay_t<decltype(std::declval<Va>()[0] * std::declval<Vb>()[0])>;
  473. template <class Xr, class Fa, class Va, class Vb>
  474. constexpr static valtype<Va, Vb>
  475. term(Va const & a, Vb const & b)
  476. {
  477. if constexpr (mp::len<Fa> > 0) {
  478. using Fa0 = mp::first<Fa>;
  479. using Fb = mp::complement_list<Fa0, Xr>;
  480. using Sa = mp::FindCombination<Fa0, Ca>;
  481. using Sb = mp::FindCombination<Fb, Cb>;
  482. constexpr int sign = Sa::sign * Sb::sign * mp::PermutationSign<mp::append<Fa0, Fb>, Xr>::value;
  483. static_assert(sign==+1 || sign==-1, "Bad sign in wedge term.");
  484. return valtype<Va, Vb>(sign)*a[Sa::where]*b[Sb::where] + term<Xr, mp::drop1<Fa>>(a, b);
  485. } else {
  486. return 0.;
  487. }
  488. }
  489. template <class Va, class Vb, class Vr, int wr>
  490. constexpr static void
  491. coeff(Va const & a, Vb const & b, Vr & r)
  492. {
  493. if constexpr (wr<Nr) {
  494. using Xr = mp::ref<Cr, wr>;
  495. using Fa = mp::combinations<Xr, Oa>;
  496. r[wr] = term<Xr, Fa>(a, b);
  497. coeff<Va, Vb, Vr, wr+1>(a, b, r);
  498. }
  499. }
  500. template <class Va, class Vb, class Vr>
  501. constexpr static void
  502. product(Va const & a, Vb const & b, Vr & r)
  503. {
  504. static_assert(Va::size()==Na, "Bad Va dim.");
  505. static_assert(Vb::size()==Nb, "Bad Vb dim.");
  506. static_assert(Vr::size()==Nr, "Bad Vr dim.");
  507. coeff<Va, Vb, Vr, 0>(a, b, r);
  508. }
  509. };
  510. // Euclidean space, only component shuffling.
  // Hodge star for forms of order O in dimension D, built on Wedge<D, O, D-O> (W), whose
  // component lists Ca, Cb, Cr and sizes Na, Nb are re-exported here.
  // hodge_aux<i>(a, b) handles component i of the input and recurses on i+1 until i==Na:
  //  - Cai is the i-th combination of Ca; SCai is the entry of LexOrCa that matches it,
  //    i.e. Cai in sorted form;
  //  - CompCai is the complement of SCai in [0, D), and fpw locates it in Cb;
  //  - fps looks up [Cai, Cb[fpw::where]] in Cr, and its sign (which must be nonzero) is the
  //    sign of the output component;
  //  - the output is b[fpw::where] = sign * a[i].
  511. template <int D, int O>
  512. struct Hodge
  513. {
  514. using W = Wedge<D, O, D-O>;
  515. using Ca = typename W::Ca;
  516. using Cb = typename W::Cb;
  517. using Cr = typename W::Cr;
  518. using LexOrCa = typename W::LexOrCa;
  519. constexpr static int Na = W::Na;
  520. constexpr static int Nb = W::Nb;
  521. template <int i, class Va, class Vb>
  522. constexpr static void
  523. hodge_aux(Va const & a, Vb & b)
  524. {
  525. static_assert(i<=W::Na, "Bad argument to hodge_aux");
  526. if constexpr (i<W::Na) {
  527. using Cai = mp::ref<Ca, i>;
  528. static_assert(mp::len<Cai> == O, "Bad.");
  529. // sort Cai, because mp::complement only accepts sorted combinations.
  530. // ref<Cb, i> should be complementary to Cai, but I don't want to rely on that.
  531. using SCai = mp::ref<LexOrCa, mp::FindCombination<Cai, LexOrCa>::where>;
  532. using CompCai = mp::complement<SCai, D>;
  533. static_assert(mp::len<CompCai> == D-O, "Bad.");
  534. using fpw = mp::FindCombination<CompCai, Cb>;
  535. // for the sign see e.g. DoCarmo1991 I.Ex 10.
  536. using fps = mp::FindCombination<mp::append<Cai, mp::ref<Cb, fpw::where>>, Cr>;
  537. static_assert(fps::sign!=0, "Bad.");
  538. b[fpw::where] = decltype(a[i])(fps::sign)*a[i];
  539. hodge_aux<i+1>(a, b);
  540. }
  541. }
  542. };
  543. // The order of components is taken from Wedge<D, O, D-O>; this works for whatever order is defined there.
  544. // With lexicographic order, component order is reversed, but signs vary.
  545. // With the order given by ChooseComponents<>, fpw::where==i and fps::sign==+1 in hodge_aux(), always. Then hodge() becomes a free operation, (with one exception) and the next function hodge() can be used.
  546. template <int D, int O, class Va, class Vb>
  547. constexpr void
  548. hodgex(Va const & a, Vb & b)
  549. {
  550. static_assert(O<=D, "bad orders");
  551. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error");
  552. static_assert(Vb::size()==mp::Hodge<D, O>::Nb, "error");
  553. mp::Hodge<D, O>::template hodge_aux<0>(a, b);
  554. }
  555. } // namespace ra::mp
  556. // This depends on Wedge<>::Ca, Cb, Cr coming from ChooseComponents. hodgex() should always work, but this is cheaper.
  557. // However if 2*O=D, it is not possible to differentiate the bases by order and hodgex() must be used.
  558. // Likewise, when O(N-O) is odd, Hodge from (2*O>D) to (2*O<D) change sign, since **w= -w in that case, and the basis in the (2*O>D) case is selected to make Hodge(<)->Hodge(>) trivial; but can't do both!
  559. consteval bool trivial_hodge(int D, int O) { return 2*O!=D && ((2*O<D) || !ra::odd(O*(D-O))); }
  560. template <int D, int O, class Va, class Vb>
  561. constexpr void
  562. hodge(Va const & a, Vb & b)
  563. {
  564. if constexpr (trivial_hodge(D, O)) {
  565. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error");
  566. static_assert(Vb::size()==mp::Hodge<D, O>::Nb, "error");
  567. b = a;
  568. } else {
  569. ra::mp::hodgex<D, O>(a, b);
  570. }
  571. }
  572. template <int D, int O, class Va> requires (trivial_hodge(D, O))
  573. constexpr Va const &
  574. hodge(Va const & a)
  575. {
  576. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error");
  577. return a;
  578. }
  579. template <int D, int O, class Va> requires (!trivial_hodge(D, O))
  580. constexpr Va &
  581. hodge(Va & a)
  582. {
  583. Va b(a);
  584. ra::mp::hodgex<D, O>(b, a);
  585. return a;
  586. }
  587. // --------------------
  588. // Wedge product
  589. // --------------------
  590. template <int D, int Oa, int Ob, class A, class B> requires (ra::is_scalar<A> && ra::is_scalar<B>)
  591. constexpr auto
  592. wedge(A const & a, B const & b) { return a*b; }
  // Maps a scalar type A to Small<std::decay_t<A>, 1> and leaves any other type unchanged.
  // Used by wedge() to address a scalar through the same indexing interface as a vector.
  593. template <class A>
  594. using torank1 = std::conditional_t<is_scalar<A>, Small<std::decay_t<A>, 1>, A>;
  595. template <int D, int Oa, int Ob, class Va, class Vb> requires (!(is_scalar<Va> && is_scalar<Vb>))
  596. decltype(auto)
  597. wedge(Va const & a, Vb const & b)
  598. {
  599. Small<value_t<Va>, size_s<Va>()> aa = a;
  600. Small<value_t<Vb>, size_s<Vb>()> bb = b;
  601. using Ua = decltype(aa);
  602. using Ub = decltype(bb);
  603. using Wedge = mp::Wedge<D, Oa, Ob>;
  604. using valtype = typename Wedge::template valtype<Ua, Ub>;
  605. std::conditional_t<Wedge::Nr==1, valtype, Small<valtype, Wedge::Nr>> r;
  606. auto & a1 = reinterpret_cast<torank1<Ua> const &>(aa);
  607. auto & b1 = reinterpret_cast<torank1<Ub> const &>(bb);
  608. auto & r1 = reinterpret_cast<torank1<decltype(r)> &>(r);
  609. mp::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  610. return r;
  611. }
  612. template <int D, int Oa, int Ob, class Va, class Vb, class Vr> requires (!(is_scalar<Va> && is_scalar<Vb>))
  613. void
  614. wedge(Va const & a, Vb const & b, Vr & r)
  615. {
  616. Small<value_t<Va>, size_s<Va>()> aa = a;
  617. Small<value_t<Vb>, size_s<Vb>()> bb = b;
  618. using Ua = decltype(aa);
  619. using Ub = decltype(bb);
  620. auto & r1 = reinterpret_cast<torank1<decltype(r)> &>(r);
  621. auto & a1 = reinterpret_cast<torank1<Ua> const &>(aa);
  622. auto & b1 = reinterpret_cast<torank1<Ub> const &>(bb);
  623. mp::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  624. }
  625. template <class A, class B>
  626. constexpr auto
  627. cross(A const & a_, B const & b_)
  628. {
  629. constexpr int n = size_s<A>();
  630. static_assert(n==size_s<B>() && (2==n || 3==n));
  631. Small<std::decay_t<decltype(FLAT(a_))>, n> a = a_;
  632. Small<std::decay_t<decltype(FLAT(b_))>, n> b = b_;
  633. using W = mp::Wedge<n, 1, 1>;
  634. Small<std::decay_t<decltype(FLAT(a_) * FLAT(b_))>, W::Nr> r;
  635. W::product(a, b, r);
  636. if constexpr (1==W::Nr) {
  637. return r[0];
  638. } else {
  639. return r;
  640. }
  641. }
  642. template <class V>
  643. constexpr auto
  644. perp(V const & v)
  645. {
  646. static_assert(2==v.size(), "Dimension error.");
  647. return Small<std::decay_t<decltype(FLAT(v))>, 2> {v[1], -v[0]};
  648. }
  649. template <class V, class U>
  650. constexpr auto
  651. perp(V const & v, U const & n)
  652. {
  653. if constexpr (is_scalar<U>) {
  654. static_assert(2==v.size(), "Dimension error.");
  655. return Small<std::decay_t<decltype(FLAT(v) * n)>, 2> {v[1]*n, -v[0]*n};
  656. } else {
  657. static_assert(3==v.size(), "Dimension error.");
  658. return cross(v, n);
  659. }
  660. }
  661. } // namespace ra
  662. #undef RA_OPT