NUMinterpol.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. /* NUMinterpol.cpp
  2. *
  3. * Copyright (C) 1992-2008,2011,2012,2014,2015,2017,2018 Paul Boersma
  4. *
  5. * This code is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation; either version 2 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This code is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. * See the GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this work. If not, see <http://www.gnu.org/licenses/>.
  17. */
  18. /*
  19. * pb 2002/03/07 GPL
  20. * pb 2003/06/19 ridders3 replaced with ridders
  21. * pb 2003/07/09 gsl
  22. * pb 2007/01/27 use #defines for value interpolation
  23. * pb 2007/08/20 built a "weird value" check into NUMviterbi (bug report by Adam Jacks)
  24. * pb 2011/03/29 C++
  25. */
  #include "melder.h"
  #include "../dwsys/NUM2.h"
  #include <cmath>
  /*
      SIGN (x, s): the magnitude of x, negated when s < 0 (so a NaN s yields +|x|).
      Not referenced elsewhere in this file.
  */
  #define SIGN(x,s) (fabs (x) * ((s) < 0 ? -1 : 1))
  /*
      NUM_interpolate_simple_cases: the shared prologue of NUM_interpolate_sinc.
      It expands to plain statements (not a do-while block), because it must `return`
      from the enclosing function. It expects `y`, `nx`, `x`, `maxDepth`, `midleft`
      (the floor of x) and `midright` (midleft + 1) to be in scope.
      It returns early when there are no samples, when x lies outside [1, nx], when x is
      exactly on a sample, and for the NEAREST, LINEAR and CUBIC depths. Otherwise it
      leaves maxDepth clipped so that y [midright - maxDepth .. midleft + maxDepth]
      stays inside y [1 .. nx].
  */
  #define NUM_interpolate_simple_cases \
      if (nx < 1) \
          return undefined;   /* no samples at all */ \
      if (x > nx) \
          return y [nx];   /* constant extrapolation beyond the last sample */ \
      if (x < 1) \
          return y [1];   /* constant extrapolation before the first sample */ \
      if (x == midleft) \
          return y [midleft];   /* x lies exactly on a sample */ \
      /* Now 1 < x < nx and x is not an integer; clip the depth to the available samples. */ \
      if (maxDepth > midright - 1) \
          maxDepth = midright - 1; \
      if (maxDepth > nx - midleft) \
          maxDepth = nx - midleft; \
      if (maxDepth <= NUM_VALUE_INTERPOLATE_NEAREST) \
          return y [(integer) floor (x + 0.5)]; \
      if (maxDepth == NUM_VALUE_INTERPOLATE_LINEAR) \
          return y [midleft] + (x - midleft) * (y [midright] - y [midleft]); \
      if (maxDepth == NUM_VALUE_INTERPOLATE_CUBIC) { \
          /* Cubic interpolation with central-difference slopes at midleft and midright. */ \
          double yLeft = y [midleft], yRight = y [midright]; \
          double slopeLeft = 0.5 * (yRight - y [midleft - 1]), slopeRight = 0.5 * (y [midright + 1] - yLeft); \
          double fracLeft = x - midleft, fracRight = midright - x; \
          return yLeft * fracRight + yRight * fracLeft - fracLeft * fracRight * (0.5 * (slopeRight - slopeLeft) + (fracLeft - 0.5) * (slopeLeft + slopeRight - 2 * (yRight - yLeft))); \
      }
  45. #if defined (__POWERPC__)
  46. double NUM_interpolate_sinc (double y [], integer nx, double x, integer maxDepth) {
  47. integer ix, midleft = (integer) floor (x), midright = midleft + 1, left, right;
  48. double result = 0.0, a, halfsina, aa, daa, cosaa, sinaa, cosdaa, sindaa;
  49. NUM_interpolate_simple_cases
  50. left = midright - maxDepth, right = midleft + maxDepth;
  51. a = NUMpi * (x - midleft);
  52. halfsina = 0.5 * sin (a);
  53. aa = a / (x - left + 1); cosaa = cos (aa); sinaa = sin (aa);
  54. daa = NUMpi / (x - left + 1); cosdaa = cos (daa); sindaa = sin (daa);
  55. for (ix = midleft; ix >= left; ix --) {
  56. double d = halfsina / a * (1.0 + cosaa), help;
  57. result += y [ix] * d;
  58. a += NUMpi;
  59. help = cosaa * cosdaa - sinaa * sindaa;
  60. sinaa = cosaa * sindaa + sinaa * cosdaa;
  61. cosaa = help;
  62. halfsina = - halfsina;
  63. }
  64. a = NUMpi * (midright - x);
  65. halfsina = 0.5 * sin (a);
  66. aa = a / (right - x + 1); cosaa = cos (aa); sinaa = sin (aa);
  67. daa = NUMpi / (right - x + 1); cosdaa = cos (daa); sindaa = sin (daa);
  68. for (ix = midright; ix <= right; ix ++) {
  69. double d = halfsina / a * (1.0 + cosaa), help;
  70. result += y [ix] * d;
  71. a += NUMpi;
  72. help = cosaa * cosdaa - sinaa * sindaa;
  73. sinaa = cosaa * sindaa + sinaa * cosdaa;
  74. cosaa = help;
  75. halfsina = - halfsina;
  76. }
  77. return result;
  78. }
  79. #else
  80. double NUM_interpolate_sinc (double y [], integer nx, double x, integer maxDepth) {
  81. integer ix, midleft = (integer) floor (x), midright = midleft + 1, left, right;
  82. double result = 0.0, a, halfsina, aa, daa;
  83. NUM_interpolate_simple_cases
  84. left = midright - maxDepth;
  85. right = midleft + maxDepth;
  86. a = NUMpi * (x - midleft);
  87. halfsina = 0.5 * sin (a);
  88. aa = a / (x - left + 1);
  89. daa = NUMpi / (x - left + 1);
  90. for (ix = midleft; ix >= left; ix --) {
  91. double d = halfsina / a * (1.0 + cos (aa));
  92. result += y [ix] * d;
  93. a += NUMpi;
  94. aa += daa;
  95. halfsina = - halfsina;
  96. }
  97. a = NUMpi * (midright - x);
  98. halfsina = 0.5 * sin (a);
  99. aa = a / (right - x + 1);
  100. daa = NUMpi / (right - x + 1); \
  101. for (ix = midright; ix <= right; ix ++) {
  102. double d = halfsina / a * (1.0 + cos (aa));
  103. result += y [ix] * d;
  104. a += NUMpi;
  105. aa += daa;
  106. halfsina = - halfsina;
  107. }
  108. return result;
  109. }
  110. #endif
  111. /********** Improving extrema **********/
  112. #pragma mark Improving extrema
  113. struct improve_params {
  114. int depth;
  115. double *y;
  116. integer ixmax;
  117. int isMaximum;
  118. };
  119. static double improve_evaluate (double x, void *closure) {
  120. struct improve_params *me = (struct improve_params *) closure;
  121. double y = NUM_interpolate_sinc (my y, my ixmax, x, my depth);
  122. return my isMaximum ? - y : y;
  123. }
  124. double NUMimproveExtremum (double *y, integer nx, integer ixmid, int interpolation, double *ixmid_real, int isMaximum) {
  125. struct improve_params params;
  126. double result;
  127. if (ixmid <= 1) { *ixmid_real = 1; return y [1]; }
  128. if (ixmid >= nx) { *ixmid_real = nx; return y [nx]; }
  129. if (interpolation <= NUM_PEAK_INTERPOLATE_NONE) { *ixmid_real = ixmid; return y [ixmid]; }
  130. if (interpolation == NUM_PEAK_INTERPOLATE_PARABOLIC) {
  131. double dy = 0.5 * (y [ixmid + 1] - y [ixmid - 1]);
  132. double d2y = 2 * y [ixmid] - y [ixmid - 1] - y [ixmid + 1];
  133. *ixmid_real = ixmid + dy / d2y;
  134. return y [ixmid] + 0.5 * dy * dy / d2y;
  135. }
  136. /* Sinc interpolation. */
  137. params. y = y;
  138. params. depth = interpolation == NUM_PEAK_INTERPOLATE_SINC70 ? 70 : 700;
  139. params. ixmax = nx;
  140. params. isMaximum = isMaximum;
  141. /*return isMaximum ?
  142. - NUM_minimize (ixmid - 1, ixmid, ixmid + 1, improve_evaluate, & params, 1e-10, 1e-11, ixmid_real) :
  143. NUM_minimize (ixmid - 1, ixmid, ixmid + 1, improve_evaluate, & params, 1e-10, 1e-11, ixmid_real);*/
  144. *ixmid_real = NUMminimize_brent (improve_evaluate, ixmid - 1, ixmid + 1, & params, 1e-10, & result);
  145. return isMaximum ? - result : result;
  146. }
  147. double NUMimproveMaximum (double *y, integer nx, integer ixmid, int interpolation, double *ixmid_real)
  148. { return NUMimproveExtremum (y, nx, ixmid, interpolation, ixmid_real, 1); }
  149. double NUMimproveMinimum (double *y, integer nx, integer ixmid, int interpolation, double *ixmid_real)
  150. { return NUMimproveExtremum (y, nx, ixmid, interpolation, ixmid_real, 0); }
  151. /********** Viterbi **********/
  152. void NUM_viterbi (
  153. integer numberOfFrames, integer maxnCandidates,
  154. integer (*getNumberOfCandidates) (integer iframe, void *closure),
  155. double (*getLocalCost) (integer iframe, integer icand, void *closure),
  156. double (*getTransitionCost) (integer iframe, integer icand1, integer icand2, void *closure),
  157. void (*putResult) (integer iframe, integer place, void *closure),
  158. void *closure)
  159. {
  160. autoMAT delta = MATraw (numberOfFrames, maxnCandidates);
  161. autoINTMAT psi = INTMATraw (numberOfFrames, maxnCandidates);
  162. autoINTVEC numberOfCandidates = INTVECraw (numberOfFrames);
  163. for (integer iframe = 1; iframe <= numberOfFrames; iframe ++) {
  164. numberOfCandidates [iframe] = getNumberOfCandidates (iframe, closure);
  165. for (integer icand = 1; icand <= numberOfCandidates [iframe]; icand ++)
  166. delta [iframe] [icand] = - getLocalCost (iframe, icand, closure);
  167. }
  168. for (integer iframe = 2; iframe <= numberOfFrames; iframe ++) {
  169. for (integer icand2 = 1; icand2 <= numberOfCandidates [iframe]; icand2 ++) {
  170. double maximum = -1e308;
  171. integer place = 0;
  172. for (integer icand1 = 1; icand1 <= numberOfCandidates [iframe - 1]; icand1 ++) {
  173. double value = delta [iframe - 1] [icand1] + delta [iframe] [icand2]
  174. - getTransitionCost (iframe, icand1, icand2, closure);
  175. if (value > maximum) { maximum = value; place = icand1; }
  176. }
  177. if (place == 0)
  178. Melder_throw (U"Viterbi algorithm cannot compute a track because of weird values.");
  179. delta [iframe] [icand2] = maximum;
  180. psi [iframe] [icand2] = place;
  181. }
  182. }
  183. /*
  184. Find the end of the most probable path.
  185. */
  186. integer place;
  187. double maximum = delta [numberOfFrames] [place = 1];
  188. for (integer icand = 2; icand <= numberOfCandidates [numberOfFrames]; icand ++) {
  189. if (delta [numberOfFrames] [icand] > maximum)
  190. maximum = delta [numberOfFrames] [place = icand];
  191. }
  192. /*
  193. Backtrack.
  194. */
  195. for (integer iframe = numberOfFrames; iframe >= 1; iframe --) {
  196. putResult (iframe, place, closure);
  197. place = psi [iframe] [place];
  198. }
  199. }
  200. /******************/
  201. struct parm2 {
  202. integer ntrack;
  203. integer ncomb;
  204. INTMAT indices;
  205. double (*getLocalCost) (integer iframe, integer icand, integer itrack, void *closure);
  206. double (*getTransitionCost) (integer iframe, integer icand1, integer icand2, integer itrack, void *closure);
  207. void (*putResult) (integer iframe, integer place, integer itrack, void *closure);
  208. void *closure;
  209. };
  210. static integer getNumberOfCandidates_n (integer iframe, void *closure) {
  211. struct parm2 *me = (struct parm2 *) closure;
  212. (void) iframe;
  213. return my ncomb;
  214. }
  215. static double getLocalCost_n (integer iframe, integer jcand, void *closure) {
  216. struct parm2 *me = (struct parm2 *) closure;
  217. double localCost = 0.0;
  218. for (integer itrack = 1; itrack <= my ntrack; itrack ++)
  219. localCost += my getLocalCost (iframe, my indices [jcand] [itrack], itrack, my closure);
  220. return localCost;
  221. }
  222. static double getTransitionCost_n (integer iframe, integer jcand1, integer jcand2, void *closure) {
  223. struct parm2 *me = (struct parm2 *) closure;
  224. double transitionCost = 0.0;
  225. for (integer itrack = 1; itrack <= my ntrack; itrack ++)
  226. transitionCost += my getTransitionCost (iframe,
  227. my indices [jcand1] [itrack], my indices [jcand2] [itrack], itrack, my closure);
  228. return transitionCost;
  229. }
  230. static void putResult_n (integer iframe, integer jplace, void *closure) {
  231. struct parm2 *me = (struct parm2 *) closure;
  232. for (integer itrack = 1; itrack <= my ntrack; itrack ++)
  233. my putResult (iframe, my indices [jplace] [itrack], itrack, my closure);
  234. }
  235. void NUM_viterbi_multi (
  236. integer nframe, integer ncand, integer ntrack,
  237. double (*getLocalCost) (integer iframe, integer icand, integer itrack, void *closure),
  238. double (*getTransitionCost) (integer iframe, integer icand1, integer icand2, integer itrack, void *closure),
  239. void (*putResult) (integer iframe, integer place, integer itrack, void *closure),
  240. void *closure)
  241. {
  242. struct parm2 parm;
  243. if (ntrack > ncand) Melder_throw (U"(NUM_viterbi_multi:) "
  244. U"Number of tracks (", ntrack, U") should not exceed number of candidates (", ncand, U").");
  245. integer ncomb = Melder_iround (NUMcombinations (ncand, ntrack));
  246. if (ncomb > 10'000'000) Melder_throw (U"(NUM_viterbi_multi:) "
  247. U"Unrealistically high number of combinations (", ncomb, U").");
  248. parm. ntrack = ntrack;
  249. parm. ncomb = ncomb;
  250. /*
  251. For ncand == 5 and ntrack == 3, parm.indices is going to contain:
  252. 1 2 3
  253. 1 2 4
  254. 1 2 5
  255. 1 3 4
  256. 1 3 5
  257. 1 4 5
  258. 2 3 4
  259. 2 3 5
  260. 2 4 5
  261. 3 4 5
  262. */
  263. autoINTMAT indices = INTMATzero (ncomb, ntrack);
  264. autoINTVEC icand = INTVECraw (ntrack);
  265. for (integer itrack = 1; itrack <= ntrack; itrack ++)
  266. icand [itrack] = itrack; // start out with "1 2 3"
  267. integer jcomb = 0;
  268. for (;;) {
  269. jcomb ++;
  270. for (integer itrack = 1; itrack <= ntrack; itrack ++)
  271. indices [jcomb] [itrack] = icand [itrack];
  272. integer itrack = ntrack;
  273. for (; itrack >= 1; itrack --) {
  274. if (++ icand [itrack] <= ncand - (ntrack - itrack)) {
  275. for (integer jtrack = itrack + 1; jtrack <= ntrack; jtrack ++)
  276. icand [jtrack] = icand [itrack] + jtrack - itrack;
  277. break;
  278. }
  279. }
  280. if (itrack == 0) break;
  281. }
  282. Melder_assert (jcomb == ncomb);
  283. parm. indices = indices.get();
  284. parm. getLocalCost = getLocalCost;
  285. parm. getTransitionCost = getTransitionCost;
  286. parm. putResult = putResult;
  287. parm. closure = closure;
  288. NUM_viterbi (nframe, ncomb, getNumberOfCandidates_n, getLocalCost_n, getTransitionCost_n, putResult_n, & parm);
  289. }
  290. /* End of file NUMinterpol.cpp */