dual.hh 6.0 KB


  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Dual numbers for automatic differentiation.
  3. // (c) Daniel Llorens - 2013-2024
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. // See VanderBergen2012, Berland2006. Generally about automatic differentiation:
  9. // http://en.wikipedia.org/wiki/Automatic_differentiation
// From the Taylor expansion of f(a) or f(a, b), which is exact to first order because ε² = 0:
// f(a+εa') = f(a)+εa'f_a(a)
// f(a+εa', b+εb') = f(a, b)+ε[a'f_a(a, b) + b'f_b(a, b)]
  13. #pragma once
  14. #include <cmath>
  15. #include <iosfwd>
  16. #include "tuples.hh"
  17. namespace ra {
  18. using std::abs, std::sqrt, std::fma;
// Dual number re + ε·du with ε² = 0: `re` is the value part and `du` the
// derivative (tangent) part. The free operators and functions in this namespace
// propagate `du` through arithmetic.
template <class T>
struct Dual
{
    T re, du;
    // True when T has a .real() member whose decayed result type differs from T
    // (e.g. std::complex<double>, whose real() yields double); false otherwise,
    // including for arithmetic T, where .real() does not exist.
    constexpr static bool is_complex = requires { requires !(std::is_same_v<T, std::decay_t<decltype(std::declval<T>().real())>>); };
    // real_type is T::value_type when is_complex, otherwise an empty placeholder
    // type. Every member taking real_type is constrained on is_complex, so those
    // members only exist for complex T.
    template <class S> struct real_part { struct type {}; };
    template <class S> requires (is_complex) struct real_part<S> { using type = typename S::value_type; };
    using real_type = typename real_part<T>::type;
    constexpr Dual(T const & r, T const & d): re(r), du(d) {}
    constexpr Dual(T const & r): re(r), du(0.) {} // conversions are by default constants.
    // Lets a real scalar initialize a complex Dual, again as a constant (du = 0).
    constexpr Dual(real_type const & r) requires (is_complex): re(r), du(0.) {}
    // No member initializers: re and du are default-initialized (indeterminate
    // for arithmetic T).
    constexpr Dual() {}
    // Compound assignment (+=, -=, /=, *=) for T, Dual and (complex T only)
    // real_type right-hand sides, each defined as *this = *this OP r.
    // NOTE(review): relies on JOIN token-pasting OP with = and on FOR_EACH applying
    // ASSIGNOPS to every listed operator; both macros are defined outside this file.
#define ASSIGNOPS(OP) \
    constexpr Dual & operator JOIN(OP, =)(T const & r) { *this = *this OP r; return *this; } \
    constexpr Dual & operator JOIN(OP, =)(Dual const & r) { *this = *this OP r; return *this; } \
    constexpr Dual & operator JOIN(OP, =)(real_type const & r) requires (is_complex) { *this = *this OP r; return *this; }
    FOR_EACH(ASSIGNOPS, +, -, /, *)
#undef ASSIGNOPS
};
  38. // conversions are by default constants.
  39. template <class R> constexpr auto dual(Dual<R> const & r) { return r; }
  40. template <class R> constexpr auto dual(R const & r) { return Dual<R> { r, 0. }; }
  41. template <class R, class D>
  42. constexpr auto
  43. dual(R const & r, D const & d)
  44. {
  45. return Dual<std::common_type_t<R, D>> { r, d };
  46. }
  47. template <class A, class B>
  48. constexpr auto
  49. operator*(Dual<A> const & a, Dual<B> const & b)
  50. {
  51. return dual(a.re*b.re, a.re*b.du + a.du*b.re);
  52. }
  53. template <class A, class B>
  54. constexpr auto
  55. operator*(A const & a, Dual<B> const & b)
  56. {
  57. return dual(a*b.re, a*b.du);
  58. }
  59. template <class A, class B>
  60. constexpr auto
  61. operator*(Dual<A> const & a, B const & b)
  62. {
  63. return dual(a.re*b, a.du*b);
  64. }
  65. template <class A, class B, class C>
  66. constexpr auto
  67. fma(Dual<A> const & a, Dual<B> const & b, Dual<C> const & c)
  68. {
  69. return dual(::fma(a.re, b.re, c.re), ::fma(a.re, b.du, ::fma(a.du, b.re, c.du))); // FIXME shouldn't need ::
  70. }
  71. template <class A, class B>
  72. constexpr auto
  73. operator+(Dual<A> const & a, Dual<B> const & b)
  74. {
  75. return dual(a.re+b.re, a.du+b.du);
  76. }
  77. template <class A, class B>
  78. constexpr auto
  79. operator+(A const & a, Dual<B> const & b)
  80. {
  81. return dual(a+b.re, b.du);
  82. }
  83. template <class A, class B>
  84. constexpr auto
  85. operator+(Dual<A> const & a, B const & b)
  86. {
  87. return dual(a.re+b, a.du);
  88. }
  89. template <class A, class B>
  90. constexpr auto
  91. operator-(Dual<A> const & a, Dual<B> const & b)
  92. {
  93. return dual(a.re-b.re, a.du-b.du);
  94. }
  95. template <class A, class B>
  96. constexpr auto
  97. operator-(Dual<A> const & a, B const & b)
  98. {
  99. return dual(a.re-b, a.du);
  100. }
  101. template <class A, class B>
  102. constexpr auto
  103. operator-(A const & a, Dual<B> const & b)
  104. {
  105. return dual(a-b.re, -b.du);
  106. }
  107. template <class A>
  108. constexpr auto
  109. operator-(Dual<A> const & a)
  110. {
  111. return dual(-a.re, -a.du);
  112. }
  113. template <class A>
  114. constexpr decltype(auto)
  115. operator+(Dual<A> const & a)
  116. {
  117. return a;
  118. }
  119. template <class A>
  120. constexpr auto
  121. inv(Dual<A> const & a)
  122. {
  123. auto i = 1./a.re;
  124. return dual(i, -a.du*(i*i));
  125. }
  126. template <class A, class B>
  127. constexpr auto
  128. operator/(Dual<A> const & a, Dual<B> const & b)
  129. {
  130. return a*inv(b);
  131. }
  132. template <class A, class B>
  133. constexpr auto
  134. operator/(Dual<A> const & a, B const & b)
  135. {
  136. return a*inv(dual(b));
  137. }
  138. template <class A, class B>
  139. constexpr auto
  140. operator/(A const & a, Dual<B> const & b)
  141. {
  142. return dual(a)*inv(b);
  143. }
  144. template <class A>
  145. constexpr auto
  146. cos(Dual<A> const & a)
  147. {
  148. return dual(cos(a.re), -sin(a.re)*a.du);
  149. }
  150. template <class A>
  151. constexpr auto
  152. sin(Dual<A> const & a)
  153. {
  154. return dual(sin(a.re), +cos(a.re)*a.du);
  155. }
  156. template <class A>
  157. constexpr auto
  158. cosh(Dual<A> const & a)
  159. {
  160. return dual(cosh(a.re), +sinh(a.re)*a.du);
  161. }
  162. template <class A>
  163. constexpr auto
  164. sinh(Dual<A> const & a)
  165. {
  166. return dual(sinh(a.re), +cosh(a.re)*a.du);
  167. }
  168. template <class A>
  169. constexpr auto
  170. tan(Dual<A> const & a)
  171. {
  172. auto c = cos(a.du);
  173. return dual(tan(a.re), a.du/(c*c));
  174. }
  175. template <class A>
  176. constexpr auto
  177. exp(Dual<A> const & a)
  178. {
  179. return dual(exp(a.re), +exp(a.re)*a.du);
  180. }
  181. template <class A, class B>
  182. constexpr auto
  183. pow(Dual<A> const & a, B const & b)
  184. {
  185. return dual(pow(a.re, b), +b*pow(a.re, b-1)*a.du);
  186. }
  187. template <class A>
  188. constexpr auto
  189. log(Dual<A> const & a)
  190. {
  191. return dual(log(a.re), +a.du/a.re);
  192. }
  193. template <class A>
  194. constexpr auto
  195. sqrt(Dual<A> const & a)
  196. {
  197. return dual(sqrt(a.re), +a.du/(2.*sqrt(a.re)));
  198. }
  199. template <class A>
  200. constexpr auto
  201. sqr(Dual<A> const & a)
  202. {
  203. return a*a;
  204. }
  205. template <class A>
  206. constexpr auto
  207. abs(Dual<A> const & a)
  208. {
  209. return abs(a.re);
  210. }
  211. template <class A>
  212. constexpr bool
  213. isfinite(Dual<A> const & a)
  214. {
  215. return isfinite(a.re) && isfinite(a.du);
  216. }
  217. template <class A>
  218. constexpr auto
  219. xi(Dual<A> const & a)
  220. {
  221. return dual(xi(a.re), xi(a.du));
  222. }
  223. template <class A>
  224. std::ostream & operator<<(std::ostream & o, Dual<A> const & a)
  225. {
  226. return o << "[" << a.re << " " << a.du << "]";
  227. }
  228. template <class A>
  229. std::istream & operator>>(std::istream & i, Dual<A> & a)
  230. {
  231. char s;
  232. i >> s;
  233. if (s!='[') {
  234. i.setstate(std::ios::failbit);
  235. } else {
  236. i >> a.re >> a.du >> s;
  237. if (s!=']') {
  238. i.setstate(std::ios::failbit);
  239. }
  240. }
  241. return i;
  242. }
  243. } // namespace ra