dual.cc 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra/test - Dual numbers.
  3. // (c) Daniel Llorens - 2013, 2015
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
#include <algorithm>
#include <cassert>
#include <complex>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>
#include "ra/test.hh"
#include "ra/dual.hh"
  14. using std::cout, std::endl, std::flush, ra::TestRecorder;
  15. using real = double;
  16. using complex = std::complex<double>;
  17. using ra::dual, ra::Dual;
  18. using ra::sqr, ra::fma;
  19. #define DEFINE_CASE(N, F, DF) \
  20. struct JOIN(case, N) \
  21. { \
  22. template <class X> static auto f(X x) { return (F); } \
  23. template <class X> static auto df(X x) { return (DF); } \
  24. };
  25. DEFINE_CASE(0, x*cos(x)/sqrt(x),
  26. cos(x)/(2.*sqrt(x))-sqrt(x)*sin(x))
  27. DEFINE_CASE(1, x,
  28. 1.)
  29. DEFINE_CASE(2, 3.*exp(x*x)/x+8.*exp(2.*x)/x,
  30. -3.*exp(x*x)/(x*x)+6.*exp(x*x)+16.*exp(2.*x)/x-8.*exp(2.*x)/(x*x))
  31. DEFINE_CASE(3, cos(pow(exp(x), 4.5)),
  32. -4.5*exp(4.5*x)*sin(exp(4.5*x)))
  33. DEFINE_CASE(4, 1./(x*x),
  34. -2.*x/sqr(x*x))
  35. DEFINE_CASE(5, 1./(2.-x*x),
  36. +2.*x/sqr(2.-x*x))
  37. DEFINE_CASE(6, sinh(x)/cosh(x),
  38. 1./sqr(cosh(x)))
  39. DEFINE_CASE(7, fma(x, x, 3.*x),
  40. 2.*x+3.)
  41. #undef DEFINE_CASE
  42. // repeat case2 using assignment ops.
  43. struct case8
  44. {
  45. template <class X> static auto f(X x)
  46. {
  47. auto y = 3.*exp(x*x);
  48. y /= x;
  49. y += 8.*exp(2.*x)/x;
  50. return y;
  51. }
  52. template <class X> static auto df(X x)
  53. {
  54. auto lo = x;
  55. lo *= lo;
  56. auto dy = -3.*exp(x*x)/lo;
  57. dy += +6.*exp(x*x);
  58. dy += +16.*exp(2.*x)/x;
  59. dy -= 8.*exp(2.*x)/lo;
  60. return dy;
  61. }
  62. };
  63. template <class Case, class X>
  64. void test1(TestRecorder & tr, std::string info, X && x, real const rspec=2e-15)
  65. {
  66. for (unsigned int i=0; i!=x.size(); ++i) {
  67. tr.info(info, " ", i, " f vs Dual").test_rel(Case::f(x[i]), Case::f(dual(x[i], 1.)).re, rspec);
  68. tr.info(info, " ", i, " df vs Dual").test_rel(Case::df(x[i]), Case::f(dual(x[i], 1.)).du, rspec);
  69. }
  70. }
  71. int main()
  72. {
  73. TestRecorder tr(std::cout);
  74. tr.test_eq(0., Dual<real>{3}.du);
  75. tr.test_eq(0., dual(3.).du);
  76. tr.section("tests with real");
  77. {
  78. std::vector<real> x(10);
  79. std::iota(x.begin(), x.end(), 1);
  80. for (real & xi: x) { xi *= .1; }
  81. test1<case0>(tr, "case0", x);
  82. test1<case1>(tr, "case1", x);
  83. test1<case2>(tr, "case2", x);
  84. test1<case3>(tr, "case3", x, 3e-14);
  85. test1<case4>(tr, "case4", x, 1e-15);
  86. test1<case5>(tr, "case5", x, 1e-15);
  87. test1<case6>(tr, "case6", x, 1e-15);
  88. test1<case7>(tr, "case7", x);
  89. test1<case8>(tr, "case8", x);
  90. }
  91. tr.section("demo with complex");
  92. {
  93. Dual<complex> x { complex(3, 1), 1. };
  94. cout << x << endl;
  95. cout << exp(x) << endl;
  96. cout << (x*x) << endl;
  97. }
  98. tr.section("real -> dual<complex> conversion");
  99. {
  100. Dual<complex> x { 1., 1. };
  101. x = 0.;
  102. tr.test_eq(0., x.re);
  103. tr.test_eq(0., x.du);
  104. }
  105. tr.section("tests with complex");
  106. {
  107. std::vector<complex> x(10);
  108. std::iota(x.begin(), x.end(), 1);
  109. for (complex & xi: x) { xi = xi*.1 + complex(0, 1); }
  110. test1<case0>(tr, "case0", x);
  111. test1<case1>(tr, "case1", x);
  112. test1<case2>(tr, "case2", x);
  113. test1<case3>(tr, "case3", x, 5e-14);
  114. test1<case4>(tr, "case4", x, 1e-15);
  115. test1<case5>(tr, "case5", x, 1e-15);
  116. test1<case6>(tr, "case6", x, 1.2e-15);
  117. test1<case7>(tr, "case7", x);
  118. test1<case8>(tr, "case8", x);
  119. }
  120. return tr.summary();
  121. }