wedge.hh 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Wedge product and cross product.
  3. // (c) Daniel Llorens - 2008-2022
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  9. #include "small.hh"
  10. namespace ra::mp {
  11. template <class P, class Plist>
  12. struct FindCombination
  13. {
  14. template <class A> using match = bool_c<0 != PermutationSign<P, A>::value>;
  15. using type = IndexIf<Plist, match>;
  16. constexpr static int where = type::value;
  17. constexpr static int sign = (where>=0) ? PermutationSign<P, typename type::type>::value : 0;
  18. };
  19. // A combination antiC complementary to C wrt [0, 1, ... Dim-1], but permuted to make the permutation [C, antiC] positive with respect to [0, 1, ... Dim-1].
  20. template <class C, int D>
  21. struct AntiCombination
  22. {
  23. using EC = complement<C, D>;
  24. static_assert((len<EC>)>=2, "can't correct this complement");
  25. constexpr static int sign = PermutationSign<append<C, EC>, iota<D>>::value;
  26. // Produce permutation of opposite sign if sign<0.
  27. using type = mp::cons<std::tuple_element_t<(sign<0) ? 1 : 0, EC>,
  28. mp::cons<std::tuple_element_t<(sign<0) ? 0 : 1, EC>,
  29. mp::drop<EC, 2>>>;
  30. };
  31. template <class C, int D> struct MapAntiCombination;
  32. template <int D, class ... C>
  33. struct MapAntiCombination<std::tuple<C ...>, D>
  34. {
  35. using type = std::tuple<typename AntiCombination<C, D>::type ...>;
  36. };
  37. template <int D, int O>
  38. struct ChooseComponents
  39. {
  40. static_assert(D>=O, "bad dimension or form order");
  41. using type = mp::combinations<iota<D>, O>;
  42. };
  43. template <int D, int O> using ChooseComponents_ = typename ChooseComponents<D, O>::type;
  44. template <int D, int O>
  45. requires ((D>1) && (2*O>D))
  46. struct ChooseComponents<D, O>
  47. {
  48. static_assert(D>=O, "bad dimension or form order");
  49. using type = typename MapAntiCombination<ChooseComponents_<D, D-O>, D>::type;
  50. };
  51. // Works *almost* to the range of std::size_t.
  52. constexpr std::size_t
  53. n_over_p(std::size_t const n, std::size_t p)
  54. {
  55. if (p>n) {
  56. return 0;
  57. } else if (p>(n-p)) {
  58. p = n-p;
  59. }
  60. std::size_t v = 1;
  61. for (std::size_t i=0; i!=p; ++i) {
  62. v = v*(n-i)/(i+1);
  63. }
  64. return v;
  65. }
  66. // We form the basis for the result (Cr) and split it in pieces for Oa and Ob; there are (D over Oa) ways. Then we see where and with which signs these pieces are in the bases for Oa (Ca) and Ob (Cb), and form the product.
  67. template <int D, int Oa, int Ob>
  68. struct Wedge
  69. {
  70. constexpr static int Or = Oa+Ob;
  71. static_assert(Oa<=D && Ob<=D && Or<=D, "bad orders");
  72. constexpr static int Na = n_over_p(D, Oa);
  73. constexpr static int Nb = n_over_p(D, Ob);
  74. constexpr static int Nr = n_over_p(D, Or);
  75. // in lexicographic order. Can be used to sort Ca below with FindPermutation.
  76. using LexOrCa = mp::combinations<mp::iota<D>, Oa>;
  77. // the actual components used, which are in lex. order only in some cases.
  78. using Ca = mp::ChooseComponents_<D, Oa>;
  79. using Cb = mp::ChooseComponents_<D, Ob>;
  80. using Cr = mp::ChooseComponents_<D, Or>;
  81. // optimizations.
  82. constexpr static bool yields_expr = (Na>1) != (Nb>1);
  83. constexpr static bool yields_expr_a1 = yields_expr && Na==1;
  84. constexpr static bool yields_expr_b1 = yields_expr && Nb==1;
  85. constexpr static bool both_scalars = (Na==1 && Nb==1);
  86. constexpr static bool dot_plus = Na>1 && Nb>1 && Or==D && (Oa<Ob || (Oa>Ob && !ra::odd(Oa*Ob)));
  87. constexpr static bool dot_minus = Na>1 && Nb>1 && Or==D && (Oa>Ob && ra::odd(Oa*Ob));
  88. constexpr static bool general_case = (Na>1 && Nb>1) && ((Oa+Ob!=D) || (Oa==Ob));
  89. template <class Va, class Vb>
  90. using valtype = std::decay_t<decltype(std::declval<Va>()[0] * std::declval<Vb>()[0])>;
  91. template <class Xr, class Fa, class Va, class Vb>
  92. constexpr static valtype<Va, Vb>
  93. term(Va const & a, Vb const & b)
  94. {
  95. if constexpr (!mp::nilp<Fa>) {
  96. using Fa0 = mp::first<Fa>;
  97. using Fb = mp::complement_list<Fa0, Xr>;
  98. using Sa = mp::FindCombination<Fa0, Ca>;
  99. using Sb = mp::FindCombination<Fb, Cb>;
  100. constexpr int sign = Sa::sign * Sb::sign * mp::PermutationSign<mp::append<Fa0, Fb>, Xr>::value;
  101. static_assert(sign==+1 || sign==-1, "Bad sign in wedge term.");
  102. return valtype<Va, Vb>(sign)*a[Sa::where]*b[Sb::where] + term<Xr, mp::drop1<Fa>>(a, b);
  103. } else {
  104. return 0.;
  105. }
  106. }
  107. template <class Va, class Vb, class Vr, int wr>
  108. constexpr static void
  109. coeff(Va const & a, Vb const & b, Vr & r)
  110. {
  111. if constexpr (wr<Nr) {
  112. using Xr = mp::ref<Cr, wr>;
  113. using Fa = mp::combinations<Xr, Oa>;
  114. r[wr] = term<Xr, Fa>(a, b);
  115. coeff<Va, Vb, Vr, wr+1>(a, b, r);
  116. }
  117. }
  118. template <class Va, class Vb, class Vr>
  119. constexpr static void
  120. product(Va const & a, Vb const & b, Vr & r)
  121. {
  122. static_assert(int(Va::size())==Na, "bad Va dim"); // gcc accepts a.size(), etc.
  123. static_assert(int(Vb::size())==Nb, "bad Vb dim");
  124. static_assert(int(Vr::size())==Nr, "bad Vr dim");
  125. coeff<Va, Vb, Vr, 0>(a, b, r);
  126. }
  127. };
  128. // This is for Euclidean space, it only does component shuffling.
  129. template <int D, int O>
  130. struct Hodge
  131. {
  132. using W = Wedge<D, O, D-O>;
  133. using Ca = typename W::Ca;
  134. using Cb = typename W::Cb;
  135. using Cr = typename W::Cr;
  136. using LexOrCa = typename W::LexOrCa;
  137. constexpr static int Na = W::Na;
  138. constexpr static int Nb = W::Nb;
  139. template <int i, class Va, class Vb>
  140. constexpr static void
  141. hodge_aux(Va const & a, Vb & b)
  142. {
  143. static_assert(i<=W::Na, "Bad argument to hodge_aux");
  144. if constexpr (i<W::Na) {
  145. using Cai = mp::ref<Ca, i>;
  146. static_assert(mp::len<Cai> == O, "bad");
  147. // sort Cai, because mp::complement only accepts sorted combinations.
  148. // ref<Cb, i> should be complementary to Cai, but I don't want to rely on that.
  149. using SCai = mp::ref<LexOrCa, mp::FindCombination<Cai, LexOrCa>::where>;
  150. using CompCai = mp::complement<SCai, D>;
  151. static_assert(mp::len<CompCai> == D-O, "bad");
  152. using fpw = mp::FindCombination<CompCai, Cb>;
  153. // for the sign see e.g. DoCarmo1991 I.Ex 10.
  154. using fps = mp::FindCombination<mp::append<Cai, mp::ref<Cb, fpw::where>>, Cr>;
  155. static_assert(fps::sign!=0, "bad");
  156. b[fpw::where] = decltype(a[i])(fps::sign)*a[i];
  157. hodge_aux<i+1>(a, b);
  158. }
  159. }
  160. };
  161. // The order of components is taken from Wedge<D, O, D-O>; this works for whatever order is defined there.
  162. // With lexicographic order, component order is reversed, but signs vary.
  163. // With the order given by ChooseComponents<>, fpw::where==i and fps::sign==+1 in hodge_aux(), always. Then hodge() becomes a free operation, (with one exception) and the next function hodge() can be used.
  164. template <int D, int O, class Va, class Vb>
  165. constexpr void
  166. hodgex(Va const & a, Vb & b)
  167. {
  168. static_assert(O<=D, "bad orders");
  169. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error"); // gcc accepts a.size(), etc.
  170. static_assert(Vb::size()==mp::Hodge<D, O>::Nb, "error");
  171. mp::Hodge<D, O>::template hodge_aux<0>(a, b);
  172. }
  173. } // namespace ra::mp
  174. namespace ra {
  175. // This depends on Wedge<>::Ca, Cb, Cr coming from ChooseCombinations, as enforced in test_wedge_product. hodgex() should always work, but this is cheaper.
  176. // However if 2*O=D, it is not possible to differentiate the bases by order and hodgex() must be used.
  177. // Likewise, when O(N-O) is odd, Hodge from (2*O>D) to (2*O<D) change sign, since **w= -w in that case, and the basis in the (2*O>D) case is selected to make Hodge(<)->Hodge(>) trivial; but can't do both!
  178. #define TRIVIAL(D, O) (2*O!=D && ((2*O<D) || !ra::odd(O*(D-O))))
  179. template <int D, int O, class Va, class Vb>
  180. constexpr void
  181. hodge(Va const & a, Vb & b)
  182. {
  183. if constexpr (TRIVIAL(D, O)) {
  184. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error"); // gcc accepts a.size(), etc
  185. static_assert(Vb::size()==mp::Hodge<D, O>::Nb, "error");
  186. b = a;
  187. } else {
  188. ra::mp::hodgex<D, O>(a, b);
  189. }
  190. }
  191. template <int D, int O, class Va>
  192. requires (TRIVIAL(D, O))
  193. constexpr Va const &
  194. hodge(Va const & a)
  195. {
  196. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error"); // gcc accepts a.size()
  197. return a;
  198. }
  199. template <int D, int O, class Va>
  200. requires (!TRIVIAL(D, O))
  201. constexpr Va &
  202. hodge(Va & a)
  203. {
  204. Va b(a);
  205. ra::mp::hodgex<D, O>(b, a);
  206. return a;
  207. }
  208. #undef TRIVIAL
  209. // --------------------
  210. // Wedge product
  211. // TODO Handle the simplifications dot_plus, yields_scalar, etc. just as vec::wedge does.
  212. // --------------------
  213. template <int D, int Oa, int Ob, class A, class B>
  214. requires (ra::is_scalar<A> && ra::is_scalar<B>)
  215. constexpr auto wedge(A const & a, B const & b) { return a*b; }
  216. template <class A>
  217. struct torank1
  218. {
  219. using type = std::conditional_t<is_scalar<A>, Small<std::decay_t<A>, 1>, A>;
  220. };
  221. template <class Wedge, class Va, class Vb>
  222. struct fromrank1
  223. {
  224. using valtype = typename Wedge::template valtype<Va, Vb>;
  225. using type = std::conditional_t<Wedge::Nr==1, valtype, Small<valtype, Wedge::Nr>>;
  226. };
  227. #define DECL_WEDGE(condition) \
  228. template <int D, int Oa, int Ob, class Va, class Vb> \
  229. requires (!(is_scalar<Va> && is_scalar<Vb>)) \
  230. decltype(auto) \
  231. wedge(Va const & a, Vb const & b)
  232. DECL_WEDGE(general_case)
  233. {
  234. Small<value_t<Va>, size_s<Va>()> aa = a;
  235. Small<value_t<Vb>, size_s<Vb>()> bb = b;
  236. using Ua = decltype(aa);
  237. using Ub = decltype(bb);
  238. typename fromrank1<mp::Wedge<D, Oa, Ob>, Ua, Ub>::type r;
  239. auto & r1 = reinterpret_cast<typename torank1<decltype(r)>::type &>(r);
  240. auto & a1 = reinterpret_cast<typename torank1<Ua>::type const &>(aa);
  241. auto & b1 = reinterpret_cast<typename torank1<Ub>::type const &>(bb);
  242. mp::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  243. return r;
  244. }
  245. #undef DECL_WEDGE
  246. #define DECL_WEDGE(condition) \
  247. template <int D, int Oa, int Ob, class Va, class Vb, class Vr> \
  248. requires (!(is_scalar<Va> && is_scalar<Vb>)) \
  249. void \
  250. wedge(Va const & a, Vb const & b, Vr & r)
  251. DECL_WEDGE(general_case)
  252. {
  253. Small<value_t<Va>, size_s<Va>()> aa = a;
  254. Small<value_t<Vb>, size_s<Vb>()> bb = b;
  255. using Ua = decltype(aa);
  256. using Ub = decltype(bb);
  257. auto & r1 = reinterpret_cast<typename torank1<decltype(r)>::type &>(r);
  258. auto & a1 = reinterpret_cast<typename torank1<Ua>::type const &>(aa);
  259. auto & b1 = reinterpret_cast<typename torank1<Ub>::type const &>(bb);
  260. mp::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  261. }
  262. #undef DECL_WEDGE
  263. template <class A, class B> requires (size_s<A>()==2 && size_s<B>()==2)
  264. constexpr auto
  265. cross(A const & a_, B const & b_)
  266. {
  267. Small<std::decay_t<decltype(FLAT(a_))>, 2> a = a_;
  268. Small<std::decay_t<decltype(FLAT(b_))>, 2> b = b_;
  269. Small<std::decay_t<decltype(FLAT(a_) * FLAT(b_))>, 1> r;
  270. mp::Wedge<2, 1, 1>::product(a, b, r);
  271. return r[0];
  272. }
  273. template <class A, class B> requires (size_s<A>()==3 && size_s<B>()==3)
  274. constexpr auto
  275. cross(A const & a_, B const & b_)
  276. {
  277. Small<std::decay_t<decltype(FLAT(a_))>, 3> a = a_;
  278. Small<std::decay_t<decltype(FLAT(b_))>, 3> b = b_;
  279. Small<std::decay_t<decltype(FLAT(a_) * FLAT(b_))>, 3> r;
  280. mp::Wedge<3, 1, 1>::product(a, b, r);
  281. return r;
  282. }
  283. template <class V>
  284. constexpr auto
  285. perp(V const & v)
  286. {
  287. static_assert(v.size()==2, "dimension error");
  288. return Small<std::decay_t<decltype(FLAT(v))>, 2> {v[1], -v[0]};
  289. }
  290. template <class V, class U>
  291. constexpr auto
  292. perp(V const & v, U const & n)
  293. {
  294. if constexpr (is_scalar<U>) {
  295. static_assert(v.size()==2, "dimension error");
  296. return Small<std::decay_t<decltype(FLAT(v) * n)>, 2> {v[1]*n, -v[0]*n};
  297. } else {
  298. static_assert(v.size()==3, "dimension error");
  299. return cross(v, n);
  300. }
  301. }
  302. } // namespace ra