HMM.cpp 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639
  1. /* HMM.cpp
  2. *
  3. * Copyright (C) 2010-2017 David Weenink, 2015,2017 Paul Boersma
  4. *
  5. * This code is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation; either version 2 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This code is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  13. * General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this work. If not, see <http://www.gnu.org/licenses/>.
  17. */
  18. /*
  19. djmw 20110304 Thing_new
  20. */
  21. #include "Distributions_and_Strings.h"
  22. #include "HMM.h"
  23. #include "Index.h"
  24. #include "NUM2.h"
  25. #include "Strings_extensions.h"
  26. #include "oo_DESTROY.h"
  27. #include "HMM_def.h"
  28. #include "oo_COPY.h"
  29. #include "HMM_def.h"
  30. #include "oo_EQUAL.h"
  31. #include "HMM_def.h"
  32. #include "oo_CAN_WRITE_AS_ENCODING.h"
  33. #include "HMM_def.h"
  34. #include "oo_WRITE_TEXT.h"
  35. #include "HMM_def.h"
  36. #include "oo_WRITE_BINARY.h"
  37. #include "HMM_def.h"
  38. #include "oo_READ_TEXT.h"
  39. #include "HMM_def.h"
  40. #include "oo_READ_BINARY.h"
  41. #include "HMM_def.h"
  42. #include "oo_DESCRIPTION.h"
  43. #include "HMM_def.h"
  // Class registrations for every type implemented in this file.
  // NOTE(review): the macro arguments are presumably (class, parent class, version) -- confirm against Thing.h.
  44. Thing_implement (HMM, Daata, 0);
  45. Thing_implement (HMMState, Daata, 0);
  46. Thing_implement (HMMStateList, Ordered, 0);
  47. Thing_implement (HMMObservation, Daata, 0);
  48. Thing_implement (HMMObservationList, Ordered, 0);
  49. Thing_implement (HMMBaumWelch, Daata, 0);
  50. Thing_implement (HMMViterbi, Daata, 0);
  51. Thing_implement (HMMObservationSequence, Table, 0);
  52. Thing_implement (HMMObservationSequenceBag, Collection, 0);
  53. Thing_implement (HMMStateSequence, Strings, 0);
  54. /*
  55. Whenever a routine returns ln(p), the result for p=0 is -INFINITY.
  56. On IEEE floating point hardware this number behaves reasonably.
  57. This means that when the variable q equals INFINITY, q + a -> INFINITY,
  58. where a is a finite number.
  59. */
  // Forward declarations of the helper routines used in this file.
  60. // helpers
  61. int NUMget_line_intersection_with_circle (double xc, double yc, double r, double a, double b, double *x1, double *y1, double *x2, double *y2);
  62. autoHMMObservation HMMObservation_create (conststring32 label, integer numberOfComponents, integer dimension, integer storage);
  63. integer HMM_HMMObservationSequence_getLongestSequence (HMM me, HMMObservationSequence thee, integer symbolNumber);
  64. integer StringsIndex_getLongestSequence (StringsIndex me, integer index, integer *pos);
  65. integer Strings_getLongestSequence (Strings me, char32 *string, integer *pos);
  66. autoHMMState HMMState_create (conststring32 label);
  67. autoHMMBaumWelch HMMBaumWelch_create (integer nstates, integer nsymbols, integer capacity);
  68. void HMMBaumWelch_getGamma (HMMBaumWelch me);
  69. autoHMMBaumWelch HMM_forward (HMM me, integer *obs, integer nt);
  70. void HMMBaumWelch_reInit (HMMBaumWelch me);
  71. void HMM_HMMBaumWelch_getXi (HMM me, HMMBaumWelch thee, integer *obs);
  72. void HMM_HMMBaumWelch_reestimate (HMM me, HMMBaumWelch thee);
  73. void HMM_HMMBaumWelch_addEstimate (HMM me, HMMBaumWelch thee, integer *obs);
  74. void HMM_HMMBaumWelch_forward (HMM me, HMMBaumWelch thee, integer *obs);
  75. void HMM_HMMBaumWelch_backward (HMM me, HMMBaumWelch thee, integer *obs);
  76. void HMM_HMMViterbi_decode (HMM me, HMMViterbi thee, integer *obs);
  77. double HMM_getProbabilityOfObservations (HMM me, integer *obs, integer numberOfTimes);
  78. autoTableOfReal StringsIndex_to_TableOfReal_transitions (StringsIndex me, int probabilities);
  79. autoStringsIndex HMM_HMMStateSequence_to_StringsIndex (HMM me, HMMStateSequence thee);
  80. autoHMMViterbi HMMViterbi_create (integer nstates, integer ntimes);
  81. autoHMMViterbi HMM_to_HMMViterbi (HMM me, integer *obs, integer ntimes);
  82. // evaluate the numbers given to probabilities
  83. static autoVEC NUMwstring_to_probs (conststring32 s, integer nwanted) {
  84. autoVEC numbers = VEC_createFromString (s);
  85. if (numbers.size != nwanted)
  86. Melder_throw (U"You supplied ", numbers.size, U", while ", nwanted, U" numbers needed.");
  87. longdouble sum = 0.0;
  88. for (integer i = 1; i <= numbers.size; i ++) {
  89. if (numbers [i] < 0.0)
  90. Melder_throw (U"Numbers have to be positive.");
  91. sum += numbers [i];
  92. }
  93. if (sum <= 0.0)
  94. Melder_throw (U"All probabilities cannot be zero.");
  95. for (integer i = 1; i <= numbers.size; i ++)
  96. numbers [i] /= sum;
  97. return numbers;
  98. }
  99. int NUMget_line_intersection_with_circle (double xc, double yc, double r, double a, double b, double *x1, double *y1, double *x2, double *y2) {
  100. double ca = a * a + 1.0, bmyc = (b - yc);
  101. double cb = 2.0 * (a * bmyc - xc);
  102. double cc = bmyc * bmyc + xc * xc - r * r;
  103. integer nroots = NUMsolveQuadraticEquation (ca, cb, cc, x1, x2);
  104. if (nroots == 1) {
  105. *y1 = a * *x1 + b;
  106. *x2 = *x1;
  107. *y2 = *y1;
  108. } else if (nroots == 2) {
  109. if (*x1 > *x2) {
  110. double tmp = *x1;
  111. *x1 = *x2;
  112. *x2 = tmp;
  113. }
  114. *y1 = *x1 * a + b;
  115. *y2 = *x2 * a + b;
  116. }
  117. return nroots;
  118. }
  119. // D(l_1,l_2)=1/n( log p(O_2|l_1) - log p(O_2|l_2)
  120. static double HMM_HMM_getCrossEntropy_asym (HMM me, HMM thee, integer observationLength) {
  121. autoHMMObservationSequence os = HMM_to_HMMObservationSequence (thee, 0, observationLength);
  122. double ce = HMM_HMMObservationSequence_getCrossEntropy (me, os.get());
  123. if (isundef (ce)) {
  124. return ce;
  125. }
  126. double ce2 = HMM_HMMObservationSequence_getCrossEntropy (thee, os.get());
  127. if (isundef (ce2)) {
  128. return ce2;
  129. }
  130. return ce - ce2;
  131. }
  132. /**************** HMMObservation ******************************/
  133. static void HMMObservation_init (HMMObservation me, conststring32 label, integer numberOfComponents, integer dimension, integer storage) {
  134. my label = Melder_dup (label);
  135. my gm = GaussianMixture_create (numberOfComponents, dimension, storage);
  136. }
  137. autoHMMObservation HMMObservation_create (conststring32 label, integer numberOfComponents, integer dimension, integer storage) {
  138. try {
  139. autoHMMObservation me = Thing_new (HMMObservation);
  140. HMMObservation_init (me.get(), label, numberOfComponents, dimension, storage);
  141. return me;
  142. } catch (MelderError) {
  143. Melder_throw (U"HMMObservation not created.");
  144. }
  145. }
  146. integer Strings_getLongestSequence (Strings me, char32 *string, integer *pos) {
  147. integer length = 0, longest = 0, lpos = 0;
  148. for (integer i = 1; i <= my numberOfStrings; i ++) {
  149. if (Melder_equ (my strings [i].get(), string)) {
  150. if (length == 0) {
  151. lpos = i;
  152. }
  153. length ++;
  154. } else {
  155. if (length > 0) {
  156. if (length > longest) {
  157. longest = length; *pos = lpos;
  158. }
  159. length = 0;
  160. }
  161. }
  162. }
  163. return length;
  164. }
  165. integer StringsIndex_getLongestSequence (StringsIndex me, integer index, integer *pos) {
  166. integer length = 0, longest = 0, lpos = 0;
  167. for (integer i = 1; i <= my numberOfItems; i ++) {
  168. if (my classIndex [i] == index) {
  169. if (length == 0)
  170. lpos = i;
  171. length ++;
  172. } else {
  173. if (length > 0) {
  174. if (length > longest) {
  175. longest = length;
  176. *pos = lpos;
  177. }
  178. length = 0;
  179. }
  180. }
  181. }
  182. return length;
  183. }
  184. /**************** HMMState ******************************/
  185. static void HMMState_init (HMMState me, conststring32 label) {
  186. my label = Melder_dup (label);
  187. }
  188. autoHMMState HMMState_create (conststring32 label) {
  189. try {
  190. autoHMMState me = Thing_new (HMMState);
  191. HMMState_init (me.get(), label);
  192. return me;
  193. } catch (MelderError) {
  194. Melder_throw (U"HMMState not created.");
  195. }
  196. }
  197. void HMMState_setLabel (HMMState me, char32 *label) {
  198. my label = Melder_dup (label);
  199. }
  200. /**************** HMMBaumWelch ******************************/
  201. void structHMMBaumWelch :: v_destroy () noexcept {
  202. for (integer it = 1; it <= capacity; it ++) {
  203. NUMmatrix_free (xi [it], 1, 1);
  204. }
  205. NUMvector_free (xi, 1);
  206. NUMvector_free (scale, 1);
  207. NUMmatrix_free (beta, 1, 1);
  208. NUMmatrix_free (alpha, 1, 1);
  209. NUMmatrix_free (gamma, 1, 1);
  210. NUMmatrix_free (aij_num, 0, 1);
  211. NUMmatrix_free (aij_denom, 0, 1);
  212. NUMmatrix_free (bik_num, 1, 1);
  213. NUMmatrix_free (bik_denom, 1, 1);
  214. }
  215. autoHMMBaumWelch HMMBaumWelch_create (integer nstates, integer nsymbols, integer capacity) {
  216. try {
  217. autoHMMBaumWelch me = Thing_new (HMMBaumWelch);
  218. my numberOfTimes = my capacity = capacity;
  219. my numberOfStates = nstates;
  220. my numberOfSymbols = nsymbols;
  221. my alpha = NUMmatrix<double> (1, nstates, 1, capacity);
  222. my beta = NUMmatrix<double> (1, nstates, 1, capacity);
  223. my scale = NUMvector<double> (1, capacity);
  224. my xi = NUMvector<double **> (1, capacity);
  225. my aij_num = NUMmatrix<double> (0, nstates, 1, nstates + 1);
  226. my aij_denom = NUMmatrix<double> (0, nstates, 1, nstates + 1);
  227. my bik_num = NUMmatrix<double> (1, nstates, 1, nsymbols);
  228. my bik_denom = NUMmatrix<double> (1, nstates, 1, nsymbols);
  229. my gamma = NUMmatrix<double> (1, nstates, 1, capacity);
  230. for (integer it = 1; it <= capacity; it ++)
  231. my xi [it] = NUMmatrix<double> (1, nstates, 1, nstates);
  232. return me;
  233. } catch (MelderError) {
  234. Melder_throw (U"HMMBaumWelch not created.");
  235. }
  236. }
  237. void HMMBaumWelch_getGamma (HMMBaumWelch me) {
  238. for (integer it = 1; it <= my numberOfTimes; it ++) {
  239. double sum = 0.0;
  240. for (integer is = 1; is <= my numberOfStates; is ++) {
  241. my gamma [is] [it] = my alpha [is] [it] * my beta [is] [it];
  242. sum += my gamma [is] [it];
  243. }
  244. for (integer is = 1; is <= my numberOfStates; is ++)
  245. my gamma [is] [it] /= sum;
  246. }
  247. }
  248. /**************** HMMViterbi ******************************/
  249. autoHMMViterbi HMMViterbi_create (integer nstates, integer ntimes) {
  250. try {
  251. autoHMMViterbi me = Thing_new (HMMViterbi);
  252. my numberOfTimes = ntimes;
  253. my numberOfStates = nstates;
  254. my viterbi = NUMmatrix <double> (1, nstates, 1, ntimes);
  255. my bp = NUMmatrix <integer> (1, nstates, 1 , ntimes);
  256. my path = NUMvector <integer> (1, ntimes);
  257. return me;
  258. } catch (MelderError) {
  259. Melder_throw (U"HMMViterbi not created.");
  260. }
  261. }
  262. /******************* HMMObservationSequence & HMMStateSequence ***/
  263. autoHMMObservationSequence HMMObservationSequence_create (integer numberOfItems, integer dataLength) {
  264. try {
  265. autoHMMObservationSequence me = Thing_new (HMMObservationSequence);
  266. Table_initWithoutColumnNames (me.get(), numberOfItems, dataLength + 1);
  267. return me;
  268. } catch (MelderError) {
  269. Melder_throw (U"HMMObservationSequence not created.");
  270. }
  271. }
  272. integer HMMObservationSequence_getNumberOfObservations (HMMObservationSequence me) {
  273. return my rows.size;
  274. }
  275. void HMMObservationSequence_removeObservation (HMMObservationSequence me, integer index) {
  276. Table_removeRow (me, index);
  277. }
  278. autoStrings HMMObservationSequence_to_Strings (HMMObservationSequence me) {
  279. try {
  280. integer numberOfStrings = my rows.size;
  281. autoStrings thee = Thing_new (Strings);
  282. thy strings = autostring32vector (numberOfStrings);
  283. for (integer i = 1; i <= numberOfStrings; i ++)
  284. thy strings [i] = Melder_dup (Table_getStringValue_Assert (me, i, 1));
  285. thy numberOfStrings = numberOfStrings;
  286. return thee;
  287. } catch (MelderError) {
  288. Melder_throw (me, U": no Strings created.");
  289. }
  290. }
  291. autoHMMObservationSequence Strings_to_HMMObservationSequence (Strings me) {
  292. try {
  293. autoHMMObservationSequence thee = HMMObservationSequence_create (my numberOfStrings, 0);
  294. for (integer i = 1; i <= my numberOfStrings; i ++)
  295. Table_setStringValue (thee.get(), i, 1, my strings [i].get());
  296. return thee;
  297. } catch (MelderError) {
  298. Melder_throw (me, U": no HMMObservationSequence created.");
  299. }
  300. }
  301. autoStringsIndex HMMObservationSequence_to_StringsIndex (HMMObservationSequence me) {
  302. try {
  303. autoStrings s = HMMObservationSequence_to_Strings (me);
  304. autoStringsIndex thee = Strings_to_StringsIndex (s.get());
  305. return thee;
  306. } catch (MelderError) {
  307. Melder_throw (me, U": no StringsIndex created.");
  308. }
  309. }
  310. integer HMM_HMMObservationSequence_getLongestSequence (HMM me, HMMObservationSequence thee, integer symbolNumber) {
  311. autoStringsIndex si = HMM_HMMObservationSequence_to_StringsIndex (me, thee);
  312. // TODO
  313. (void) symbolNumber;
  314. return 1;
  315. }
  316. integer HMMObservationSequenceBag_getLongestSequence (HMMObservationSequenceBag me) {
  317. integer longest = 0;
  318. for (integer i = 1; i <= my size; i ++) {
  319. HMMObservationSequence thee = my at [i];
  320. if (thy rows.size > longest)
  321. longest = thy rows.size;
  322. }
  323. return longest;
  324. }
  325. autoHMMStateSequence HMMStateSequence_create (integer numberOfItems) {
  326. try {
  327. autoHMMStateSequence me = Thing_new (HMMStateSequence);
  328. my strings = autostring32vector (numberOfItems);
  329. return me;
  330. } catch (MelderError) {
  331. Melder_throw (U"HMMStateSequence not created.");
  332. }
  333. }
  334. autoStrings HMMStateSequence_to_Strings (HMMStateSequence me) {
  335. try {
  336. autoStrings thee = Thing_new (Strings);
  337. my structStrings :: v_copy (thee.get());
  338. return thee;
  339. } catch (MelderError) {
  340. Melder_throw (me, U": no Strings created.");
  341. }
  342. }
  343. /**************** HMM ******************************/
  344. void structHMM :: v_info () {
  345. structDaata :: v_info ();
  346. MelderInfo_writeLine (U"Number of states: ", numberOfStates);
  347. for (integer i = 1; i <= numberOfStates; i ++) {
  348. HMMState hmms = our states->at [i];
  349. MelderInfo_writeLine (U" ", hmms -> label.get());
  350. }
  351. MelderInfo_writeLine (U"Number of symbols: ", numberOfObservationSymbols);
  352. for (integer i = 1; i <= numberOfObservationSymbols; i ++) {
  353. HMMObservation hmms = our observationSymbols->at [i];
  354. MelderInfo_writeLine (U" ", hmms -> label.get());
  355. }
  356. }
  357. static void HMM_init (HMM me, integer numberOfStates, integer numberOfObservationSymbols, int leftToRight) {
  358. my numberOfStates = numberOfStates;
  359. my numberOfObservationSymbols = numberOfObservationSymbols;
  360. my componentStorage = 1;
  361. my leftToRight = leftToRight;
  362. my states = HMMStateList_create ();
  363. my observationSymbols = HMMObservationList_create ();
  364. my transitionProbs = NUMmatrix<double> (0, numberOfStates, 1, numberOfStates + 1);
  365. my emissionProbs = NUMmatrix<double> (1, numberOfStates, 1, numberOfObservationSymbols);
  366. }
  367. autoHMM HMM_create (int leftToRight, integer numberOfStates, integer numberOfObservationSymbols) {
  368. try {
  369. autoHMM me = Thing_new (HMM);
  370. HMM_init (me.get(), numberOfStates, numberOfObservationSymbols, leftToRight);
  371. HMM_setDefaultStates (me.get());
  372. HMM_setDefaultObservations (me.get());
  373. HMM_setDefaultTransitionProbs (me.get());
  374. HMM_setDefaultStartProbs (me.get());
  375. HMM_setDefaultEmissionProbs (me.get());
  376. return me;
  377. } catch (MelderError) {
  378. Melder_throw (U"HMM not created.");
  379. }
  380. }
  381. void HMM_setDefaultStates (HMM me) {
  382. for (integer i = 1; i <= my numberOfStates; i ++) {
  383. autoHMMState hmms = HMMState_create (Melder_cat (U"S", i));
  384. HMM_addState_move (me, hmms.move());
  385. }
  386. }
  387. autoHMM HMM_createFullContinuousModel (int leftToRight, integer numberOfStates, integer numberOfObservationSymbols, integer numberOfFeatureStreams, integer *dimensionOfStream, integer *numberOfGaussiansforStream) {
  388. (void) leftToRight;
  389. (void) numberOfStates;
  390. (void) numberOfObservationSymbols;
  391. (void) numberOfFeatureStreams;
  392. (void) dimensionOfStream;
  393. (void) numberOfGaussiansforStream;
  394. return autoHMM();
  395. }
  396. autoHMM HMM_createContinuousModel (int leftToRight, integer numberOfStates, integer numberOfObservationSymbols, integer numberOfMixtureComponentsPerSymbol, integer componentDimension, integer componentStorage) {
  397. try {
  398. autoHMM me = Thing_new (HMM);
  399. HMM_init (me.get(), numberOfStates, numberOfObservationSymbols, leftToRight);
  400. my numberOfMixtureComponents = numberOfMixtureComponentsPerSymbol;
  401. my componentDimension = componentDimension;
  402. my componentStorage = componentStorage;
  403. for (integer i = 1; i <= numberOfStates; i ++) {
  404. autoHMMState state = HMMState_create (Melder_cat (U"S", i));
  405. HMM_addState_move (me.get(), state.move());
  406. }
  407. for (integer j = 1; j <= numberOfObservationSymbols; j ++) {
  408. autoHMMObservation obs = HMMObservation_create (Melder_cat (U"s", j), numberOfMixtureComponentsPerSymbol, componentDimension, componentStorage);
  409. HMM_addObservation_move (me.get(), obs.move());
  410. }
  411. HMM_setDefaultTransitionProbs (me.get());
  412. HMM_setDefaultStartProbs (me.get());
  413. HMM_setDefaultEmissionProbs (me.get());
  414. HMM_setDefaultMixingProbabilities (me.get());
  415. return me;
  416. } catch (MelderError) {
  417. Melder_throw (U"Continuous model HMM not created.");
  418. }
  419. }
  420. // for a simple non-hidden model leave either states empty or symbols empty !!!
  421. autoHMM HMM_createSimple (int leftToRight, conststring32 states_string, conststring32 symbols_string) {
  422. try {
  423. autostring32vector states = STRVECtokenize (states_string);
  424. autostring32vector symbols = STRVECtokenize (symbols_string);
  425. autoHMM me = Thing_new (HMM);
  426. Melder_require (states.size > 0 || symbols.size > 0,
  427. U"The states and symbols should not both be empty.");
  428. if (symbols.size <= 0) {
  429. symbols = STRVECclone (states.get());
  430. my notHidden = 1;
  431. }
  432. if (states.size <= 0) {
  433. states = STRVECclone (symbols.get());
  434. my notHidden = 1;
  435. }
  436. HMM_init (me.get(), states.size, symbols.size, leftToRight);
  437. for (integer istate = 1; istate <= states.size; istate ++) {
  438. autoHMMState state = HMMState_create (states [istate].get());
  439. HMM_addState_move (me.get(), state.move());
  440. }
  441. for (integer isymbol = 1; isymbol <= symbols.size; isymbol ++) {
  442. autoHMMObservation symbol = HMMObservation_create (symbols [isymbol].get(), 0, 0, 0);
  443. HMM_addObservation_move (me.get(), symbol.move());
  444. }
  445. HMM_setDefaultTransitionProbs (me.get());
  446. HMM_setDefaultStartProbs (me.get());
  447. HMM_setDefaultEmissionProbs (me.get());
  448. return me;
  449. } catch (MelderError) {
  450. Melder_throw (U"Simple HMM not created.");
  451. }
  452. }
  453. void HMM_setDefaultObservations (HMM me) {
  454. conststring32 def = my notHidden ? U"S" : U"s";
  455. for (integer i = 1; i <= my numberOfObservationSymbols; i ++) {
  456. autoHMMObservation hmms = HMMObservation_create (Melder_cat (def, i), 0, 0, 0);
  457. HMM_addObservation_move (me, hmms.move());
  458. }
  459. }
  460. void HMM_setDefaultTransitionProbs (HMM me) {
  461. for (integer i = 1; i <= my numberOfStates; i ++) {
  462. double p = ( my leftToRight ? 1.0 / (my numberOfStates - i + 1.0) : 1.0 / my numberOfStates );
  463. for (integer j = 1; j <= my numberOfStates; j ++)
  464. my transitionProbs [i] [j] = ( my leftToRight && j < i ? 0.0 : p );
  465. }
  466. // leftToRight must have end state!
  467. if (my leftToRight) my transitionProbs [my numberOfStates] [my numberOfStates] =
  468. my transitionProbs [my numberOfStates] [my numberOfStates + 1] = 0.5;
  469. }
  470. void HMM_setDefaultStartProbs (HMM me) {
  471. double p = 1.0 / my numberOfStates;
  472. for (integer j = 1; j <= my numberOfStates; j ++)
  473. my transitionProbs [0] [j] = p;
  474. }
  475. void HMM_setDefaultEmissionProbs (HMM me) {
  476. double p = 1.0 / my numberOfObservationSymbols;
  477. for (integer i = 1; i <= my numberOfStates; i ++) {
  478. for (integer j = 1; j <= my numberOfObservationSymbols; j ++)
  479. my emissionProbs [i] [j] = ( my notHidden ? ( i == j ? 1.0 : 0.0 ) : p );
  480. }
  481. }
  482. void HMM_setDefaultMixingProbabilities (HMM me) {
  483. double mp = 1.0 / my numberOfMixtureComponents;
  484. for (integer is = 1; is <= my numberOfObservationSymbols; is ++) {
  485. HMMObservation hmmo = my observationSymbols->at [is];
  486. for (integer im = 1; im <= my numberOfMixtureComponents; im ++)
  487. hmmo -> gm -> mixingProbabilities [im] = mp;
  488. }
  489. }
  490. void HMM_setStartProbabilities (HMM me, conststring32 probs) {
  491. try {
  492. autoVEC p = NUMwstring_to_probs (probs, my numberOfStates);
  493. for (integer i = 1; i <= my numberOfStates; i ++)
  494. my transitionProbs [0] [i] = p [i];
  495. } catch (MelderError) {
  496. Melder_throw (me, U": no start probabilities set.");
  497. }
  498. }
  499. void HMM_setTransitionProbabilities (HMM me, integer state_number, conststring32 state_probs) {
  500. try {
  501. Melder_require (state_number <= my states->size,
  502. U"State number should not exceed ", my states->size, U".");
  503. autoVEC p = NUMwstring_to_probs (state_probs, my numberOfStates);
  504. for (integer i = 1; i <= my numberOfStates + 1; i ++)
  505. my transitionProbs [state_number] [i] = p [i];
  506. } catch (MelderError) {
  507. Melder_throw (me, U": no transition probabilities set.");
  508. }
  509. }
  510. void HMM_setEmissionProbabilities (HMM me, integer state_number, conststring32 emission_probs) {
  511. try {
  512. Melder_require (state_number <= my states->size,
  513. U"State number should not exceed ", my states->size, U".");
  514. Melder_require (! my notHidden,
  515. U"The emission probabilities of this model are fixed.");
  516. autoVEC p = NUMwstring_to_probs (emission_probs, my numberOfObservationSymbols);
  517. for (integer i = 1; i <= my numberOfObservationSymbols; i ++) {
  518. my emissionProbs [state_number] [i] = p [i];
  519. }
  520. } catch (MelderError) {
  521. Melder_throw (me, U": no emission probabilities set.");
  522. }
  523. }
  524. void HMM_addObservation_move (HMM me, autoHMMObservation thee) {
  525. integer ns = my observationSymbols->size + 1;
  526. Melder_require (ns <= my numberOfObservationSymbols, U"Observation list is full.");
  527. my observationSymbols -> addItemAtPosition_move (thee.move(), ns);
  528. }
  529. void HMM_addState_move (HMM me, autoHMMState thee) {
  530. integer ns = my states->size + 1;
  531. Melder_require (ns <= my numberOfStates, U"States list is full.");
  532. my states -> addItemAtPosition_move (thee.move(), ns);
  533. }
  534. autoTableOfReal HMM_extractTransitionProbabilities (HMM me) {
  535. try {
  536. autoTableOfReal thee = TableOfReal_create (my numberOfStates + 1, my numberOfStates + 1);
  537. for (integer is = 1; is <= my numberOfStates; is ++) {
  538. HMMState hmms = my states->at [is];
  539. TableOfReal_setRowLabel (thee.get(), is + 1, hmms -> label.get());
  540. TableOfReal_setColumnLabel (thee.get(), is, hmms -> label.get());
  541. for (integer js = 1; js <= my numberOfStates; js ++)
  542. thy data [is + 1] [js] = my transitionProbs [is] [js];
  543. }
  544. TableOfReal_setRowLabel (thee.get(), 1, U"START");
  545. TableOfReal_setColumnLabel (thee.get(), my numberOfStates + 1, U"END");
  546. for (integer is = 1; is <= my numberOfStates; is ++) {
  547. thy data [1] [is] = my transitionProbs [0] [is];
  548. thy data [is + 1] [my numberOfStates + 1] = my transitionProbs [is] [my numberOfStates + 1];
  549. }
  550. return thee;
  551. } catch (MelderError) {
  552. Melder_throw (me, U": no transition probabilities extracted.");
  553. }
  554. }
  555. autoTableOfReal HMM_extractEmissionProbabilities (HMM me) {
  556. try {
  557. autoTableOfReal thee = TableOfReal_create (my numberOfStates, my numberOfObservationSymbols);
  558. for (integer js = 1; js <= my numberOfObservationSymbols; js ++) {
  559. HMMObservation hmms = my observationSymbols->at [js];
  560. TableOfReal_setColumnLabel (thee.get(), js, hmms -> label.get());
  561. }
  562. for (integer is = 1; is <= my numberOfStates; is ++) {
  563. HMMState hmms = my states->at [is];
  564. TableOfReal_setRowLabel (thee.get(), is, hmms -> label.get());
  565. for (integer js = 1; js <= my numberOfObservationSymbols; js ++)
  566. thy data [is] [js] = my emissionProbs [is] [js];
  567. }
  568. return thee;
  569. } catch (MelderError) {
  570. Melder_throw (me, U": no emission probabilities extracted.");
  571. };
  572. }
  573. double HMM_getExpectedValueOfDurationInState (HMM me, integer istate) {
  574. if (istate < 0 || istate > my numberOfStates)
  575. return undefined;
  576. return 1.0 / (1.0 - my transitionProbs [istate] [istate]);
  577. }
  578. double HMM_getProbabilityOfStayingInState (HMM me, integer istate, integer numberOfTimeUnits) {
  579. if (istate < 0 || istate > my numberOfStates)
  580. return undefined;
  581. return pow (my transitionProbs [istate] [istate], numberOfTimeUnits - 1.0) * (1.0 - my transitionProbs [istate] [istate]);
  582. }
  583. double HMM_HMM_getCrossEntropy (HMM me, HMM thee, integer observationLength, int symmetric) {
  584. double ce1 = HMM_HMM_getCrossEntropy_asym (me, thee, observationLength);
  585. if (! symmetric || isundef (ce1))
  586. return ce1;
  587. double ce2 = HMM_HMM_getCrossEntropy_asym (thee, me, observationLength);
  588. if (isundef (ce2))
  589. return ce2;
  590. return (ce1 + ce2) / 2.0;
  591. }
  592. double HMM_HMM_HMMObservationSequence_getCrossEntropy (HMM me, HMM thee, HMMObservationSequence him) {
  593. double ce1 = HMM_HMMObservationSequence_getCrossEntropy (me, him);
  594. if (isundef (ce1))
  595. return ce1;
  596. double ce2 = HMM_HMMObservationSequence_getCrossEntropy (thee, him);
  597. if (isundef (ce2))
  598. return ce2;
  599. return (ce1 + ce2) / 2.0;
  600. }
  601. void HMM_draw (HMM me, Graphics g, int garnish) {
  602. double xwidth = sqrt (my numberOfStates);
  603. double rstate = 0.3 / xwidth, r = xwidth / 3.0;
  604. double xmax = 1.2 * xwidth / 2.0, xmin = -xmax, ymin = xmin, ymax = xmax;
  605. autoNUMvector<double> xs (1, my numberOfStates);
  606. autoNUMvector<double> ys (1, my numberOfStates);
  607. Graphics_setWindow (g, xmin, xmax, ymin, ymax);
  608. // heuristic: all states on a circle until we have a better graph drawing algorithm.
  609. xs [1] = ys [1] = 0;
  610. if (my numberOfStates > 1) {
  611. for (integer is = 1; is <= my numberOfStates; is ++) {
  612. double alpha = - NUMpi + NUMpi * 2.0 * (is - 1) / my numberOfStates;
  613. xs [is] = r * cos (alpha); ys [is] = r * sin (alpha);
  614. }
  615. }
  616. // reorder the positions such that state number 1 is most left and last state number is right.
  617. // if > 5 may be one in the middle with the most connections
  618. // ...
  619. // find fontsize
  620. int fontSize = Graphics_inqFontSize (g);
  621. conststring32 widest_label = U"";
  622. double max_width = 0.0;
  623. for (integer is = 1; is <= my numberOfStates; is ++) {
  624. HMMState hmms = my states->at [is];
  625. double w = ( hmms -> label ? Graphics_textWidth (g, hmms -> label.get()) : 0.0 );
  626. if (w > max_width) {
  627. widest_label = hmms -> label.get();
  628. max_width = w;
  629. }
  630. }
  631. int new_fontSize = fontSize;
  632. while (max_width > 2.0 * rstate && new_fontSize > 4) {
  633. new_fontSize --;
  634. Graphics_setFontSize (g, new_fontSize);
  635. max_width = Graphics_textWidth (g, widest_label);
  636. }
  637. Graphics_setFontSize (g, new_fontSize);
  638. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_HALF);
  639. for (integer is = 1; is <= my numberOfStates; is ++) {
  640. HMMState hmms = my states->at [is];
  641. Graphics_circle (g, xs [is], ys [is], rstate);
  642. Graphics_text (g, xs [is], ys [is], hmms -> label.get());
  643. }
  644. // draw connections from is to js
  645. // 1 -> 2 / 2-> : increase / decrease angle between 1 and 2 with pi /10
  646. // use cos(a+b) and cos(a -b) rules
  647. double cosb = cos (NUMpi / 10.0), sinb = sin (NUMpi / 10.0);
  648. for (integer is = 1; is <= my numberOfStates; is ++) {
  649. double x1 = xs [is], y1 = ys [is];
  650. for (integer js = 1; js <= my numberOfStates; js ++) {
  651. if (my transitionProbs [is] [js] > 0.0 && is != js) {
  652. double x2 = xs [js], y2 = ys [js];
  653. double dx = x2 - x1, dy = y2 - y1, h = sqrt (dx * dx + dy * dy), cosa = dx / h, sina = dy / h;
  654. double cosabp = cosa * cosb - sina * sinb, cosabm = cosa * cosb + sina * sinb;
  655. double sinabp = cosa * sinb + sina * cosb, sinabm = -cosa * sinb + sina * cosb;
  656. Graphics_arrow (g, x1 + rstate * cosabp, y1 + rstate * sinabp, x2 - rstate * cosabm, y2 - rstate * sinabm);
  657. }
  658. if (is == js) {
  659. double dx = - x1, dy = - y1, h = sqrt (dx * dx + dy * dy), cosa = dx / h, sina = dy / h;
  660. Graphics_doubleArrow (g, x1 - rstate * cosa, y1 - rstate * sina, x1 - 1.4 * rstate * cosa, y1 - 1.4 * rstate * sina);
  661. }
  662. }
  663. }
  664. if (garnish) {
  665. Graphics_drawInnerBox (g);
  666. }
  667. }
  668. void HMM_unExpandPCA (HMM me) {
  669. if (my componentDimension <= 0)
  670. return; // nothing to do
  671. for (integer is = 1; is <= my numberOfObservationSymbols; is ++) {
  672. HMMObservation s = my observationSymbols->at [is];
  673. GaussianMixture_unExpandPCA (s -> gm.get());
  674. }
  675. }
  676. autoHMMObservationSequence HMM_to_HMMObservationSequence (HMM me, integer startState, integer numberOfItems) {
  677. try {
  678. autoHMMObservationSequence thee = HMMObservationSequence_create (numberOfItems, my componentDimension);
  679. integer istate = startState == 0 ? NUMgetIndexFromProbability (my transitionProbs [0], my numberOfStates, NUMrandomUniform (0.0, 1.0)) : startState;
  680. if (my componentDimension > 0) {
  681. autoVEC obs (my componentDimension, kTensorInitializationType::RAW);
  682. autoVEC buf (my componentDimension, kTensorInitializationType::RAW);
  683. for (integer i = 1; i <= numberOfItems; i ++) {
  684. // Emit a symbol from istate
  685. integer isymbol = NUMgetIndexFromProbability (my emissionProbs [istate], my numberOfObservationSymbols, NUMrandomUniform (0.0, 1.0));
  686. HMMObservation s = my observationSymbols->at [isymbol];
  687. char32 *name;
  688. GaussianMixture_generateOneVector_inline (s -> gm.get(), obs.get(), &name, buf.get());
  689. for (integer j = 1; j <= my componentDimension; j ++)
  690. Table_setNumericValue (thee.get(), i, 1 + j, obs [j]);
  691. Table_setStringValue (thee.get(), i, 1, s -> label.get());
  692. // get next state
  693. istate = NUMgetIndexFromProbability (my transitionProbs [istate], my numberOfStates + 1, NUMrandomUniform (0.0, 1.0));
  694. if (istate == my numberOfStates + 1) { // final state
  695. for (integer j = numberOfItems; j > i; j --)
  696. HMMObservationSequence_removeObservation (thee.get(), j);
  697. break;
  698. }
  699. }
  700. } else {
  701. for (integer i = 1; i <= numberOfItems; i ++) {
  702. // Emit a symbol from istate
  703. integer isymbol = NUMgetIndexFromProbability (my emissionProbs [istate], my numberOfObservationSymbols, NUMrandomUniform (0.0, 1.0));
  704. HMMObservation s = my observationSymbols->at [isymbol];
  705. Table_setStringValue (thee.get(), i, 1, s -> label.get());
  706. // get next state
  707. istate = NUMgetIndexFromProbability (my transitionProbs [istate], my numberOfStates + 1, NUMrandomUniform (0.0, 1.0));
  708. if (istate == my numberOfStates + 1) { // final state
  709. for (integer j = numberOfItems; j > i; j --)
  710. HMMObservationSequence_removeObservation (thee.get(), j);
  711. break;
  712. }
  713. }
  714. }
  715. HMM_unExpandPCA (me);
  716. return thee;
  717. } catch (MelderError) {
  718. HMM_unExpandPCA (me);
  719. Melder_throw (me, U":no HMMObservationSequence created.");
  720. }
  721. }
  722. autoHMMBaumWelch HMM_forward (HMM me, integer *obs, integer nt) {
  723. try {
  724. autoHMMBaumWelch thee = HMMBaumWelch_create (my numberOfStates, my numberOfObservationSymbols, nt);
  725. HMM_HMMBaumWelch_forward (me, thee.get(), obs);
  726. return thee;
  727. } catch (MelderError) {
  728. Melder_throw (me, U": no HMMBaumWelch created.");
  729. }
  730. }
  731. autoHMMViterbi HMM_to_HMMViterbi (HMM me, integer *obs, integer ntimes) {
  732. try {
  733. autoHMMViterbi thee = HMMViterbi_create (my numberOfStates, ntimes);
  734. HMM_HMMViterbi_decode (me, thee.get(), obs);
  735. return thee;
  736. } catch (MelderError) {
  737. Melder_throw (me, U": no HMMViterbi created.");
  738. }
  739. }
  740. void HMMBaumWelch_reInit (HMMBaumWelch me) {
  741. my totalNumberOfSequences = 0;
  742. my lnProb = 0.0;
  743. /*
  744. The _num and _denum matrices are asigned as += in the iteration loop and therefore need to be zeroed
  745. at the start of each new iteration.
  746. The elements of alpha, beta, scale, gamma & xi are always calculated directly and need not be
  747. initialised.
  748. */
  749. for (integer is = 0; is <= my numberOfStates; is ++) {
  750. for (integer js = 1; js <= my numberOfStates + 1; js ++) {
  751. my aij_num [is] [js] = 0.0;
  752. my aij_denom [is] [js] = 0.0;
  753. }
  754. }
  755. for (integer is = 1; is <= my numberOfStates; is ++) {
  756. for (integer js = 1; js <= my numberOfSymbols; js ++) {
  757. my bik_num [is] [js] = 0.0;
  758. my bik_denom [is] [js] = 0.0;
  759. }
  760. }
  761. }
  762. void HMM_HMMObservationSequenceBag_learn (HMM me, HMMObservationSequenceBag thee, double delta_lnp, double minProb, int info) {
  763. try {
  764. // act as if all observation sequences are in memory
  765. integer capacity = HMMObservationSequenceBag_getLongestSequence (thee);
  766. autoHMMBaumWelch bw = HMMBaumWelch_create (my numberOfStates, my numberOfObservationSymbols, capacity);
  767. bw -> minProb = minProb;
  768. if (info) {
  769. MelderInfo_open ();
  770. }
  771. integer iter = 0;
  772. double lnp;
  773. do {
  774. lnp = bw -> lnProb;
  775. HMMBaumWelch_reInit (bw.get());
  776. for (integer ios = 1; ios <= thy size; ios ++) {
  777. HMMObservationSequence hmm_os = thy at [ios];
  778. autoStringsIndex si = HMM_HMMObservationSequence_to_StringsIndex (me, hmm_os); // TODO outside the loop or more efficiently
  779. integer *obs = si -> classIndex.at;
  780. integer nobs = si -> numberOfItems; // convenience
  781. // Interpretation of unknowns: end of sequence
  782. integer istart = 1, iend = nobs;
  783. while (istart <= nobs) {
  784. while (istart <= nobs && obs [istart] == 0) {
  785. istart ++;
  786. };
  787. if (istart > nobs) {
  788. break;
  789. }
  790. iend = istart + 1;
  791. while (iend <= nobs && obs [iend] != 0) {
  792. iend ++;
  793. }
  794. iend --;
  795. bw -> numberOfTimes = iend - istart + 1;
  796. bw -> totalNumberOfSequences ++;
  797. HMM_HMMBaumWelch_forward (me, bw.get(), obs + istart - 1); // get new alphas
  798. HMM_HMMBaumWelch_backward (me, bw.get(), obs + istart - 1); // get new betas
  799. HMMBaumWelch_getGamma (bw.get());
  800. HMM_HMMBaumWelch_getXi (me, bw.get(), obs + istart - 1);
  801. HMM_HMMBaumWelch_addEstimate (me, bw.get(), obs + istart - 1);
  802. istart = iend + 1;
  803. }
  804. }
  805. // we have processed all observation sequences, now it is time to estimate new probabilities.
  806. iter ++;
  807. HMM_HMMBaumWelch_reestimate (me, bw.get());
  808. if (info) {
  809. MelderInfo_writeLine (U"Iteration: ", iter, U" ln(prob): ", bw -> lnProb);
  810. }
  811. } while (fabs ((lnp - bw -> lnProb) / bw -> lnProb) > delta_lnp);
  812. if (info) {
  813. MelderInfo_writeLine (U"******** Learning summary *********");
  814. MelderInfo_writeLine (U" Processed ", thy size, U" sequences,");
  815. MelderInfo_writeLine (U" consisting of ", bw -> totalNumberOfSequences, U" observation sequences.");
  816. MelderInfo_writeLine (U" Longest observation sequence had ", capacity, U" items");
  817. MelderInfo_close();
  818. }
  819. } catch (MelderError) {
  820. Melder_throw (me, U" & ", thee, U": not learned.");
  821. }
  822. }
  823. // xc1 < xc2
  824. void HMM_HMMStateSequence_drawTrellis (HMM me, HMMStateSequence thee, Graphics g, int connect, int garnish) {
  825. integer numberOfTimes = thy numberOfStrings;
  826. autoStringsIndex si = HMM_HMMStateSequence_to_StringsIndex (me, thee);
  827. double xmin = 0.0, xmax = numberOfTimes + 1.0, ymin = 0.5, ymax = my numberOfStates + 0.5;
  828. Graphics_setInner (g);
  829. Graphics_setWindow (g, xmin, xmax, ymin, ymax);
  830. double r = 0.2 / (numberOfTimes > my numberOfStates ? numberOfTimes : my numberOfStates);
  831. for (integer it = 1; it <= numberOfTimes; it ++) {
  832. for (integer js = 1; js <= my numberOfStates; js ++) {
  833. double xc = it, yc = js, x2 = it, y2 = js;
  834. Graphics_circle (g, xc, yc, r);
  835. if (it > 1) {
  836. for (integer is = 1; is <= my numberOfStates; is ++) {
  837. bool indexedConnection = si -> classIndex [it - 1] == is && si -> classIndex [it] == js;
  838. Graphics_setLineWidth (g, indexedConnection ? 2.0 : 1.0);
  839. Graphics_setLineType (g, indexedConnection ? Graphics_DRAWN : Graphics_DOTTED);
  840. double x1 = it - 1, y1 = is;
  841. if (connect || indexedConnection) {
  842. double a = (y1 - y2) / (x1 - x2), b = y1 - a * x1;
  843. // double xs11 = x1 - r / (a * a + 1), ys11 = a * xs11 + b;
  844. double xs12 = x1 + r / (a * a + 1), ys12 = a * xs12 + b;
  845. double xs21 = x2 - r / (a * a + 1), ys21 = a * xs21 + b;
  846. // double xs22 = x2 + r / (a * a + 1), ys22 = a * xs22 + b;
  847. Graphics_line (g, xs12, ys12, xs21, ys21);
  848. }
  849. }
  850. }
  851. }
  852. }
  853. Graphics_unsetInner (g);
  854. Graphics_setLineWidth (g, 1.0);
  855. Graphics_setLineType (g, Graphics_DRAWN);
  856. if (garnish) {
  857. Graphics_drawInnerBox (g);
  858. for (integer js = 1; js <= my numberOfStates; js ++) {
  859. HMMState hmms = my states->at [js];
  860. Graphics_markLeft (g, js, false, false, false, hmms -> label.get());
  861. }
  862. Graphics_marksBottomEvery (g, 1.0, 1.0, true, true, false);
  863. Graphics_textBottom (g, true, U"Time index");
  864. }
  865. }
  866. void HMM_drawBackwardProbabilitiesIllustration (Graphics g, bool garnish) {
  867. double xmin = 0.0, xmax = 1.0, ymin = 0.0, ymax = 1.0;
  868. Graphics_setWindow (g, xmin, xmax, ymin, ymax);
  869. double xleft = 0.1, xright = 0.9, r = 0.03;
  870. integer np = 6;
  871. double dy = (1.0 - 0.3) / (np - 1);
  872. double x0 = xleft, y0 = 0.5;
  873. Graphics_setWindow (g, xmin, xmax, ymin, ymax);
  874. Graphics_circle (g, x0, y0, r);
  875. double x = xright, y = 0.9;
  876. for (integer i = 1; i <= np; i ++) {
  877. if (i < 4 or i == np) {
  878. Graphics_circle (g, x, y, r);
  879. double xx = x0 - x, yy = y - y0;
  880. double c = sqrt (xx * xx + yy * yy);
  881. double cosa = xx / c, sina = yy / c;
  882. Graphics_line (g, x0 - r * cosa, y0 + r * sina, x + r * cosa, y - r * sina);
  883. } else if (i == 4) {
  884. double ddy = 3*dy/4;
  885. Graphics_fillCircle (g, x, y + dy - ddy, 0.5 * r);
  886. Graphics_fillCircle (g, x, y + dy - 2 * ddy, 0.5 * r);
  887. Graphics_fillCircle (g, x, y + dy - 3 * ddy, 0.5 * r);
  888. }
  889. y -= dy;
  890. }
  891. if (garnish) {
  892. double x1 = xright + 1.5 * r, x2 = x1 - 0.2, y1 = 0.9;
  893. Graphics_setTextAlignment (g, Graphics_LEFT, Graphics_HALF);
  894. Graphics_text (g, x1, y1, U"%s__1_");
  895. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_HALF);
  896. Graphics_text (g, x2, y1, U"%a__%i1_");
  897. y1 = 0.9 - dy;
  898. Graphics_setTextAlignment (g, Graphics_LEFT, Graphics_HALF);
  899. Graphics_text (g, x1, y1, U"%s__2_");
  900. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_HALF);
  901. Graphics_text (g, x2, y1, U"%a__%i2_");
  902. y1 = 0.9 - (np - 1) * dy;
  903. Graphics_setTextAlignment (g, Graphics_LEFT, Graphics_HALF);
  904. Graphics_text (g, x1, y1, U"%s__%N_");
  905. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_HALF);
  906. Graphics_text (g, x2, y1, U"%a__%%iN%_");
  907. Graphics_setTextAlignment (g, Graphics_RIGHT, Graphics_HALF);
  908. Graphics_text (g, x0 - 1.5 * r, y0, U"%s__%i_");
  909. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_BOTTOM);
  910. Graphics_text (g, x0, 0.0, U"%t");
  911. Graphics_text (g, x, 0.0, U"%t+1");
  912. double y3 = 0.10;
  913. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_HALF);
  914. Graphics_text (g, x0, y3, U"%\\be__%t_(%i)%");
  915. Graphics_text (g, x, y3, U"%\\be__%t+1_(%j)");
  916. }
  917. }
  918. void HMM_drawForwardProbabilitiesIllustration (Graphics g, bool garnish) {
  919. double xmin = 0.0, xmax = 1.0, ymin = 0.0, ymax = 1.0;
  920. Graphics_setWindow (g, xmin, xmax, ymin, ymax);
  921. double xleft = 0.1, xright = 0.9, r = 0.03;
  922. integer np = 6;
  923. double dy = (1.0 - 0.3) / (np - 1);
  924. double x0 = xright, y0 = 0.5;
  925. Graphics_setWindow (g, xmin, xmax, ymin, ymax);
  926. Graphics_circle (g, x0, y0, r);
  927. double x = xleft, y = 0.9;
  928. for (integer i = 1; i <= np; i ++) {
  929. if (i < 4 or i == np) {
  930. Graphics_circle (g, x, y, r);
  931. double xx = x0 - x, yy = y - y0;
  932. double c = sqrt (xx * xx + yy * yy);
  933. double cosa = xx / c, sina = yy / c;
  934. Graphics_line (g, x0 - r * cosa, y0 + r * sina, x + r * cosa, y - r * sina);
  935. } else if (i == 4) {
  936. double ddy = 3.0 * dy / 4.0;
  937. Graphics_fillCircle (g, x, y + dy - ddy, 0.5 * r);
  938. Graphics_fillCircle (g, x, y + dy - 2 * ddy, 0.5 * r);
  939. Graphics_fillCircle (g, x, y + dy - 3 * ddy, 0.5 * r);
  940. }
  941. y -= dy;
  942. }
  943. if (garnish) {
  944. double x1 = xleft - 1.5 * r, x2 = x1 + 0.2, y1 = 0.9;
  945. Graphics_setTextAlignment (g, Graphics_RIGHT, Graphics_HALF);
  946. Graphics_text (g, x1, y1, U"%s__1_");
  947. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_HALF);
  948. Graphics_text (g, x2, y1, U"%a__1%j_");
  949. y1 = 0.9 - dy;
  950. Graphics_setTextAlignment (g, Graphics_RIGHT, Graphics_HALF);
  951. Graphics_text (g, x1, y1, U"%s__2_");
  952. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_HALF);
  953. Graphics_text (g, x2, y1, U"%a__2%j_");
  954. y1 = 0.9 - (np - 1) * dy;
  955. Graphics_setTextAlignment (g, Graphics_RIGHT, Graphics_HALF);
  956. Graphics_text (g, x1, y1, U"%s__%N_");
  957. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_HALF);
  958. Graphics_text (g, x2, y1, U"%a__%%Nj%_");
  959. Graphics_setTextAlignment (g, Graphics_LEFT, Graphics_HALF);
  960. Graphics_text (g, x0 + 1.5 * r, y0, U"%s__%j_");
  961. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_BOTTOM);
  962. Graphics_text (g, x, 0.0, U"%t");
  963. Graphics_text (g, x0, 0.0, U"%t+1");
  964. double y3 = 0.10;
  965. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_HALF);
  966. Graphics_text (g, x, y3, U"%\\al__%t_(%i)%");
  967. Graphics_text (g, x0, y3, U"%\\al__%t+1_(%j)");
  968. }
  969. }
  970. void HMM_drawForwardAndBackwardProbabilitiesIllustration (Graphics g, bool garnish) {
  971. double xfrac = 0.1, xs = 1.0 / (0.5 - xfrac), r = 0.03;
  972. Graphics_Viewport vp = Graphics_insetViewport (g, 0.0, 0.5-xfrac, 0.0, 1.0);
  973. HMM_drawForwardProbabilitiesIllustration (g, false);
  974. Graphics_resetViewport (g, vp);
  975. Graphics_insetViewport (g, 0.5 + xfrac, 1.0, 0.0, 1.0);
  976. HMM_drawBackwardProbabilitiesIllustration (g, false);
  977. Graphics_resetViewport (g, vp);
  978. Graphics_setWindow (g, 0.0, xs, 0.0, 1.0);
  979. if (garnish) {
  980. double rx1 = 1.0 + xs * 2.0 * xfrac + 0.1, rx2 = rx1 + 0.9 - 0.1, y1 = 0.1;
  981. Graphics_line (g, 0.9 + r, 0.5, rx1 - r, 0.5);
  982. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_BOTTOM);
  983. Graphics_text (g, 0.9, 0.5 + r, U"%s__%i_");
  984. Graphics_text (g, rx1, 0.5 + r, U"%s__%j_");
  985. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_TOP);
  986. Graphics_text (g, 1.0 + xfrac * xs, 0.5, U"%a__%%ij%_%b__%j_(O__%t+1_)");
  987. Graphics_setTextAlignment (g, Graphics_CENTRE, Graphics_BOTTOM);
  988. Graphics_text (g, 0.1, 0.0, U"%t-1");
  989. Graphics_text (g, 0.9, 0.0, U"%t");
  990. Graphics_text (g, rx1, 0.0, U"%t+1");
  991. Graphics_text (g, rx2, 0.0, U"%t+2");
  992. Graphics_setLineType (g, Graphics_DASHED);
  993. double x4 = rx1 - 0.06, x3 = 0.9 + 0.06;
  994. Graphics_line (g, x3, 0.7, x3, 0.0);
  995. Graphics_line (g, x4, 0.7, x4, 0.0);
  996. Graphics_setLineType (g, Graphics_DRAWN);
  997. Graphics_arrow (g, x4, y1, x4 + 0.2, y1);
  998. Graphics_arrow (g, x3, y1, x3 - 0.2, y1);
  999. Graphics_setTextAlignment (g, Graphics_RIGHT, Graphics_BOTTOM);
  1000. Graphics_text (g, x3 - 0.01, y1, U"\\al__%t_(i)");
  1001. Graphics_setTextAlignment (g, Graphics_LEFT, Graphics_BOTTOM);
  1002. Graphics_text (g, x4 + 0.01, y1, U"\\be__%t+1_(j)");
  1003. }
  1004. }
  1005. void HMM_HMMBaumWelch_getXi (HMM me, HMMBaumWelch thee, integer *obs) {
  1006. for (integer it = 1; it <= thy numberOfTimes - 1; it ++) {
  1007. double sum = 0.0;
  1008. for (integer is = 1; is <= thy numberOfStates; is ++) {
  1009. for (integer js = 1; js <= thy numberOfStates; js ++) {
  1010. thy xi [it] [is] [js] = thy alpha [is] [it] * thy beta [js] [it + 1] *
  1011. my transitionProbs [is] [js] * my emissionProbs [js] [ obs [it + 1] ];
  1012. sum += thy xi [it] [is] [js];
  1013. }
  1014. }
  1015. for (integer is = 1; is <= my numberOfStates; is ++) {
  1016. for (integer js = 1; js <= my numberOfStates; js ++) {
  1017. thy xi [it] [is] [js] /= sum;
  1018. }
  1019. }
  1020. }
  1021. }
  1022. void HMM_HMMBaumWelch_addEstimate (HMM me, HMMBaumWelch thee, integer *obs) {
  1023. for (integer is = 1; is <= my numberOfStates; is ++) {
  1024. // only for valid start states with p > 0
  1025. if (my transitionProbs [0] [is] > 0.0) {
  1026. thy aij_num [0] [is] += thy gamma [is] [1];
  1027. thy aij_denom [0] [is] += 1.0;
  1028. }
  1029. }
  1030. for (integer is = 1; is <= my numberOfStates; is ++) {
  1031. double gammasum = 0.0;
  1032. for (integer it = 1; it <= thy numberOfTimes - 1; it ++) {
  1033. gammasum += thy gamma [is] [it];
  1034. }
  1035. for (integer js = 1; js <= my numberOfStates; js ++) {
  1036. double xisum = 0.0;
  1037. for (integer it = 1; it <= thy numberOfTimes - 1; it ++) {
  1038. xisum += thy xi [it] [is] [js];
  1039. }
  1040. // zero probs signal invalid connections, don't reestimate
  1041. if (my transitionProbs [is] [js] > 0.0) {
  1042. thy aij_num [is] [js] += xisum;
  1043. thy aij_denom [is] [js] += gammasum;
  1044. }
  1045. }
  1046. /*
  1047. Only reestimate the emissionProbs for a hidden markov model.
  1048. A not hidden model is emulated with fixed emissionProbs.
  1049. */
  1050. if (! my notHidden) {
  1051. gammasum += thy gamma [is] [thy numberOfTimes]; // now sum all, add last term
  1052. for (integer k = 1; k <= my numberOfObservationSymbols; k ++) {
  1053. double gammasum_k = 0.0;
  1054. for (integer it = 1; it <= thy numberOfTimes; it ++) {
  1055. if (obs [it] == k) {
  1056. gammasum_k += thy gamma [is] [it];
  1057. }
  1058. }
  1059. // only reestimate probs > 0 !
  1060. if (my emissionProbs [is] [k] > 0.0) {
  1061. thy bik_num [is] [k] += gammasum_k;
  1062. thy bik_denom [is] [k] += gammasum;
  1063. }
  1064. }
  1065. }
  1066. // For a left-to-right model the final state determines the transition prob to go to the END state
  1067. if (my leftToRight) {
  1068. thy aij_num [is] [my numberOfStates + 1] += thy gamma [is] [thy numberOfTimes];
  1069. thy aij_denom [is] [my numberOfStates + 1] += 1.0;
  1070. }
  1071. }
  1072. }
  1073. void HMM_HMMBaumWelch_reestimate (HMM me, HMMBaumWelch thee) {
  1074. double p;
  1075. /*
  1076. If we only have a couple of training sequences and they all happen to start with the same symbol,
  1077. one or more starting probabilities can be zero.
  1078. What to do with the P_k (see Rabiner formulas 109-110)?
  1079. */
  1080. for (integer is = 1; is <= my numberOfStates; is ++) {
  1081. /*
  1082. If we have not enough observation sequences it can happen that some probabilities
  1083. become zero. This means that for some future observation sequences the probability evaluation
  1084. returns p=0 for sequences where these transitions turn up. This makes recognition impossible and also comparisons between models are difficult.
  1085. We can prevent this from happening by asumimg a minimal probability for valid transitions
  1086. i.e. which have initially p > 0.
  1087. */
  1088. if (my transitionProbs [0] [is] > 0.0) {
  1089. p = thy aij_num [0] [is] / thy aij_denom [0] [is];
  1090. my transitionProbs [0] [is] = p > 0.0 ? p : thy minProb;
  1091. }
  1092. for (integer js = 1; js <= my numberOfStates; js ++) {
  1093. if (my transitionProbs [is] [js] > 0.0) {
  1094. p = thy aij_num [is] [js] / thy aij_denom [is] [js];
  1095. my transitionProbs [is] [js] = p > 0.0 ? p : thy minProb;
  1096. }
  1097. }
  1098. if (! my notHidden) {
  1099. for (integer k = 1; k <= my numberOfObservationSymbols; k ++) {
  1100. if (my emissionProbs [is] [k] > 0.0) {
  1101. p = thy bik_num [is] [k] / thy bik_denom [is] [k];
  1102. my emissionProbs [is] [k] = p > 0.0 ? p : thy minProb;
  1103. }
  1104. }
  1105. }
  1106. if (my leftToRight && my transitionProbs [is] [my numberOfStates + 1] > 0.0) {
  1107. p = thy aij_num [is] [my numberOfStates + 1] / thy aij_denom [is] [my numberOfStates + 1];
  1108. my transitionProbs [is] [my numberOfStates + 1] = p > 0.0 ? p : thy minProb;
  1109. }
  1110. }
  1111. }
  1112. void HMM_HMMBaumWelch_forward (HMM me, HMMBaumWelch thee, integer *obs) {
  1113. // initialise at t = 1 & scale
  1114. thy scale [1] = 0.0;
  1115. for (integer js = 1; js <= my numberOfStates; js ++) {
  1116. thy alpha [js] [1] = my transitionProbs [0] [js] * my emissionProbs [js] [obs [1]];
  1117. thy scale [1] += thy alpha [js] [1];
  1118. }
  1119. for (integer js = 1; js <= my numberOfStates; js ++) {
  1120. thy alpha [js] [1] /= thy scale [1];
  1121. }
  1122. // recursion
  1123. for (integer it = 2; it <= thy numberOfTimes; it ++) {
  1124. thy scale [it] = 0.0;
  1125. for (integer js = 1; js <= my numberOfStates; js ++) {
  1126. double sum = 0.0;
  1127. for (integer is = 1; is <= my numberOfStates; is ++) {
  1128. sum += thy alpha [is] [it - 1] * my transitionProbs [is] [js];
  1129. }
  1130. thy alpha [js] [it] = sum * my emissionProbs [js] [obs [it]];
  1131. thy scale [it] += thy alpha [js] [it];
  1132. }
  1133. for (integer js = 1; js <= my numberOfStates; js ++) {
  1134. thy alpha [js] [it] /= thy scale [it];
  1135. }
  1136. }
  1137. for (integer it = 1; it <= thy numberOfTimes; it ++) {
  1138. thy lnProb += log (thy scale [it]);
  1139. }
  1140. }
  1141. void HMM_HMMBaumWelch_backward (HMM me, HMMBaumWelch thee, integer *obs) {
  1142. for (integer is = 1; is <= my numberOfStates; is ++) {
  1143. thy beta [is] [thy numberOfTimes] = 1.0 / thy scale [thy numberOfTimes];
  1144. }
  1145. for (integer it = thy numberOfTimes - 1; it >= 1; it --) {
  1146. for (integer is = 1; is <= my numberOfStates; is ++) {
  1147. longdouble sum = 0.0;
  1148. for (integer js = 1; js <= my numberOfStates; js ++) {
  1149. sum += thy beta [js] [it + 1] * my transitionProbs [is] [js] * my emissionProbs [js] [obs [it + 1]];
  1150. }
  1151. thy beta [is] [it] = sum / thy scale [it];
  1152. }
  1153. }
  1154. }
  1155. /*************************** HMM decoding ***********************************/
  1156. // precondition: valid symbols, i.e. 1 <= o [i] <= my numberOfSymbols for i=1..nt
  1157. void HMM_HMMViterbi_decode (HMM me, HMMViterbi thee, integer *obs) {
  1158. integer ntimes = thy numberOfTimes;
  1159. // initialisation
  1160. for (integer is = 1; is <= my numberOfStates; is ++) {
  1161. thy viterbi [is] [1] = my transitionProbs [0] [is] * my emissionProbs [is] [obs [1]];
  1162. thy bp [is] [1] = 0;
  1163. }
  1164. // recursion
  1165. for (integer it = 2; it <= ntimes; it ++) {
  1166. for (integer is = 1; is <= my numberOfStates; is ++) {
  1167. // all transitions isp -> is from previous time to current
  1168. double max_score = -1; // any negative number is ok
  1169. for (integer isp = 1; isp <= my numberOfStates; isp ++) {
  1170. double score = thy viterbi [isp] [it - 1] * my transitionProbs [isp] [is]; // * my emissionProbs [is] [obs [it]]
  1171. if (score > max_score) {
  1172. max_score = score;
  1173. thy bp [is] [it] = isp;
  1174. }
  1175. }
  1176. thy viterbi [is] [it] = max_score * my emissionProbs [is] [ obs [it] ];
  1177. }
  1178. }
  1179. // path starts at state with best end probability
  1180. thy path [ntimes] = 1;
  1181. thy prob = thy viterbi [1] [ntimes];
  1182. for (integer is = 2; is <= my numberOfStates; is ++) {
  1183. if (thy viterbi [is] [ntimes] > thy prob) {
  1184. thy prob = thy viterbi [thy path [ntimes] = is] [ntimes];
  1185. }
  1186. }
  1187. // trace back and get path
  1188. for (integer it = ntimes; it > 1; it --) {
  1189. thy path [it - 1] = thy bp [thy path [it]] [it];
  1190. }
  1191. }
  1192. autoHMMStateSequence HMM_HMMObservationSequence_to_HMMStateSequence (HMM me, HMMObservationSequence thee) {
  1193. try {
  1194. autoStringsIndex si = HMM_HMMObservationSequence_to_StringsIndex (me, thee);
  1195. integer *obs = si -> classIndex.at; // convenience
  1196. integer numberOfUnknowns = StringsIndex_countItems (si.get(), 0);
  1197. Melder_require (numberOfUnknowns == 0, U"Unknown observation symbol(s) (# = ", numberOfUnknowns, U").");
  1198. integer numberOfTimes = thy rows.size;
  1199. autoHMMViterbi v = HMM_to_HMMViterbi (me, obs, numberOfTimes);
  1200. autoHMMStateSequence him = HMMStateSequence_create (numberOfTimes);
  1201. // trace the path and get states
  1202. for (integer it = 1; it <= numberOfTimes; it ++) {
  1203. HMMState hmms = my states->at [v -> path [it]];
  1204. his strings [it] = Melder_dup (hmms -> label.get());
  1205. his numberOfStrings ++;
  1206. }
  1207. return him;
  1208. } catch (MelderError) {
  1209. Melder_throw (me, U": no HMMStateSequence created.");
  1210. }
  1211. }
  1212. double HMM_HMMStateSequence_getProbability (HMM me, HMMStateSequence thee) {
  1213. autoStringsIndex si = HMM_HMMStateSequence_to_StringsIndex (me, thee);
  1214. integer numberOfUnknowns = StringsIndex_countItems (si.get(), 0);
  1215. integer *index = si -> classIndex.at;
  1216. if (index == 0) {
  1217. return undefined;
  1218. }
  1219. if (numberOfUnknowns > 0) {
  1220. Melder_warning (U"Unknown states (# = ", numberOfUnknowns, U").");
  1221. return undefined;
  1222. }
  1223. double p0 = my transitionProbs [0] [index [1]];
  1224. Melder_require (p0 > 0.0, U"You should not start with a zero probability state.");
  1225. double lnp = log (p0);
  1226. for (integer it = 2; it <= thy numberOfStrings; it ++) {
  1227. lnp += log (my transitionProbs [index [it - 1]] [index [it]]);
  1228. }
  1229. return lnp;
  1230. }
  1231. double HMM_getProbabilityAtTimeBeingInState (HMM me, integer itime, integer istate) {
  1232. if (istate < 1 || istate > my numberOfStates) {
  1233. return undefined;
  1234. }
  1235. autoNUMvector<double> scale (1, itime);
  1236. autoNUMvector<double> alpha_t (1, my numberOfStates);
  1237. autoNUMvector<double> alpha_tm1 (1, my numberOfStates);
  1238. for (integer js = 1; js <= my numberOfStates; js ++) {
  1239. alpha_t [js] = my transitionProbs [0] [js];
  1240. scale [1] += alpha_t [js];
  1241. }
  1242. for (integer js = 1; js <= my numberOfStates; js ++) {
  1243. alpha_t [js] /= scale [1];
  1244. }
  1245. // recursion
  1246. for (integer it = 2; it <= itime; it ++) {
  1247. for (integer js = 1; js <= my numberOfStates; js ++) {
  1248. alpha_tm1 [js] = alpha_t [js];
  1249. }
  1250. for (integer js = 1; js <= my numberOfStates; js ++) {
  1251. longdouble sum = 0.0;
  1252. for (integer is = 1; is <= my numberOfStates; is ++) {
  1253. sum += alpha_tm1 [is] * my transitionProbs [is] [js];
  1254. }
  1255. alpha_t [js] = sum;
  1256. scale [it] += alpha_t [js];
  1257. }
  1258. for (integer js = 1; js <= my numberOfStates; js ++) {
  1259. alpha_t [js] /= scale [it];
  1260. }
  1261. }
  1262. longdouble lnp = 0.0;
  1263. for (integer it = 1; it <= itime; it ++) {
  1264. lnp += log (scale [it]);
  1265. }
  1266. lnp = alpha_t [istate] > 0 ? lnp + log (alpha_t [istate]) : -INFINITY; // p = 0 -> ln(p)=-infinity // ppgb FIXME infinity is een laag getal
  1267. return lnp;
  1268. }
  1269. double HMM_getProbabilityAtTimeBeingInStateEmittingSymbol (HMM me, integer itime, integer istate, integer isymbol) {
  1270. // for a notHidden model emissionProbs may be zero!
  1271. if (isymbol < 1 || isymbol > my numberOfObservationSymbols || my emissionProbs [istate] [isymbol] == 0) {
  1272. return undefined;
  1273. }
  1274. double lnp = HMM_getProbabilityAtTimeBeingInState (me, itime, istate);
  1275. return ( isundef (lnp) ? undefined : lnp + log (my emissionProbs [istate] [isymbol]) );
  1276. }
  1277. double HMM_getProbabilityOfObservations (HMM me, integer *obs, integer numberOfTimes) {
  1278. autoNUMvector <double> scale (1, numberOfTimes);
  1279. autoNUMvector <double> alpha_t (1, my numberOfStates);
  1280. autoNUMvector <double> alpha_tm1 (1, my numberOfStates);
  1281. // initialise
  1282. for (integer js = 1; js <= my numberOfStates; js ++) {
  1283. alpha_t [js] = my transitionProbs [0] [js] * my emissionProbs [js] [obs [1]];
  1284. scale [1] += alpha_t [js];
  1285. }
  1286. Melder_require (scale [1] > 0.0, U"The observation sequence should not start with a symbol whose state has zero starting probability.");
  1287. for (integer js = 1; js <= my numberOfStates; js ++) {
  1288. alpha_t [js] /= scale [1];
  1289. }
  1290. // recursion
  1291. for (integer it = 2; it <= numberOfTimes; it ++) {
  1292. for (integer js = 1; js <= my numberOfStates; js ++) {
  1293. alpha_tm1 [js] = alpha_t [js];
  1294. }
  1295. for (integer js = 1; js <= my numberOfStates; js ++) {
  1296. longdouble sum = 0.0;
  1297. for (integer is = 1; is <= my numberOfStates; is ++) {
  1298. sum += alpha_tm1 [is] * my transitionProbs [is] [js];
  1299. }
  1300. alpha_t [js] = sum * my emissionProbs [js] [obs [it]];
  1301. scale [it] += alpha_t [js];
  1302. }
  1303. if (scale [it] <= 0.0) {
  1304. return -INFINITY;
  1305. }
  1306. for (integer js = 1; js <= my numberOfStates; js ++) {
  1307. alpha_t [js] /= scale [it];
  1308. }
  1309. }
  1310. double lnp = 0.0;
  1311. for (integer it = 1; it <= numberOfTimes; it ++) {
  1312. lnp += log (scale [it]);
  1313. }
  1314. return lnp;
  1315. }
  1316. double HMM_HMMObservationSequence_getProbability (HMM me, HMMObservationSequence thee) {
  1317. autoStringsIndex si = HMM_HMMObservationSequence_to_StringsIndex (me, thee);
  1318. integer *index = si -> classIndex.at;
  1319. integer numberOfUnknowns = StringsIndex_countItems (si.get(), 0);
  1320. Melder_require (numberOfUnknowns == 0, U"Unknown observations (# = ", numberOfUnknowns, U").");
  1321. return HMM_getProbabilityOfObservations (me, index, thy rows.size);
  1322. }
  1323. double HMM_HMMObservationSequence_getCrossEntropy (HMM me, HMMObservationSequence thee) {
  1324. double lnp = HMM_HMMObservationSequence_getProbability (me, thee);
  1325. return isundef (lnp) ? undefined :
  1326. -lnp / (NUMln10 * HMMObservationSequence_getNumberOfObservations (thee));
  1327. }
  1328. double HMM_HMMObservationSequence_getPerplexity (HMM me, HMMObservationSequence thee) {
  1329. double ce = HMM_HMMObservationSequence_getCrossEntropy (me, thee);
  1330. return isundef (ce) ? undefined : pow (2.0, ce);
  1331. }
  1332. autoHMM HMM_createFromHMMObservationSequence (HMMObservationSequence me, integer numberOfStates, int leftToRight) {
  1333. try {
  1334. autoHMM thee = Thing_new (HMM);
  1335. autoStrings s = HMMObservationSequence_to_Strings (me);
  1336. autoDistributions d = Strings_to_Distributions (s.get());
  1337. integer numberOfObservationSymbols = d -> numberOfRows;
  1338. thy notHidden = numberOfStates < 1;
  1339. numberOfStates = numberOfStates > 0 ? numberOfStates : numberOfObservationSymbols;
  1340. HMM_init (thee.get(), numberOfStates, numberOfObservationSymbols, leftToRight);
  1341. for (integer i = 1; i <= numberOfObservationSymbols; i ++) {
  1342. conststring32 label = d -> rowLabels [i].get();
  1343. autoHMMObservation hmmo = HMMObservation_create (label, 0, 0, 0);
  1344. HMM_addObservation_move (thee.get(), hmmo.move());
  1345. if (thy notHidden) {
  1346. autoHMMState hmms = HMMState_create (label);
  1347. HMM_addState_move (thee.get(), hmms.move());
  1348. }
  1349. }
  1350. if (! thy notHidden) {
  1351. HMM_setDefaultStates (thee.get());
  1352. }
  1353. HMM_setDefaultTransitionProbs (thee.get());
  1354. HMM_setDefaultStartProbs (thee.get());
  1355. HMM_setDefaultEmissionProbs (thee.get());
  1356. return thee;
  1357. } catch (MelderError) {
  1358. Melder_throw (me, U": no HMM created.");
  1359. }
  1360. }
  1361. autoTableOfReal HMMObservationSequence_to_TableOfReal_transitions (HMMObservationSequence me, int probabilities) {
  1362. try {
  1363. autoStrings thee = HMMObservationSequence_to_Strings (me);
  1364. autoTableOfReal him = Strings_to_TableOfReal_transitions (thee.get(), probabilities);
  1365. return him;
  1366. } catch (MelderError) {
  1367. Melder_throw (me, U": no transitions created.");
  1368. }
  1369. }
  1370. autoStringsIndex HMM_HMMObservationSequence_to_StringsIndex (HMM me, HMMObservationSequence thee) {
  1371. try {
  1372. autoStrings classes = Thing_new (Strings);
  1373. classes -> strings = autostring32vector (my numberOfObservationSymbols);
  1374. for (integer is = 1; is <= my numberOfObservationSymbols; is ++) {
  1375. HMMObservation hmmo = my observationSymbols->at [is];
  1376. classes -> strings [is] = Melder_dup (hmmo -> label.get());
  1377. classes -> numberOfStrings ++;
  1378. }
  1379. autoStrings obs = HMMObservationSequence_to_Strings (thee);
  1380. autoStringsIndex him = Stringses_to_StringsIndex (obs.get(), classes.get());
  1381. return him;
  1382. } catch (MelderError) {
  1383. Melder_throw (me, U": no StringsIndex created.");
  1384. }
  1385. }
  1386. autoStringsIndex HMM_HMMStateSequence_to_StringsIndex (HMM me, HMMStateSequence thee) {
  1387. try {
  1388. autoStrings classes = Thing_new (Strings);
  1389. classes -> strings = autostring32vector (my numberOfObservationSymbols);
  1390. for (integer is = 1; is <= my numberOfStates; is ++) {
  1391. HMMState hmms = my states->at [is];
  1392. classes -> strings [is] = Melder_dup (hmms -> label.get());
  1393. classes -> numberOfStrings ++;
  1394. }
  1395. autoStrings sts = HMMStateSequence_to_Strings (thee);
  1396. autoStringsIndex him = Stringses_to_StringsIndex (sts.get(), classes.get());
  1397. return him;
  1398. } catch (MelderError) {
  1399. Melder_throw (me, U": no StringsIndex created.");
  1400. }
  1401. }
  1402. autoTableOfReal HMM_HMMObservationSequence_to_TableOfReal_transitions (HMM me, HMMObservationSequence thee, int probabilities) {
  1403. try {
  1404. autoStringsIndex si = HMM_HMMObservationSequence_to_StringsIndex (me, thee);
  1405. autoTableOfReal him = StringsIndex_to_TableOfReal_transitions (si.get(), probabilities);
  1406. return him;
  1407. } catch (MelderError) {
  1408. Melder_throw (me, U": no transition table created for HMMObservationSequence.");
  1409. }
  1410. }
  1411. autoTableOfReal HMM_HMMStateSequence_to_TableOfReal_transitions (HMM me, HMMStateSequence thee, int probabilities) {
  1412. try {
  1413. autoStringsIndex si = HMM_HMMStateSequence_to_StringsIndex (me, thee);
  1414. autoTableOfReal him = StringsIndex_to_TableOfReal_transitions (si.get(), probabilities);
  1415. return him;
  1416. } catch (MelderError) {
  1417. Melder_throw (me, U": no transition table created for HMMStateSequence.");
  1418. }
  1419. }
  1420. autoTableOfReal StringsIndex_to_TableOfReal_transitions (StringsIndex me, int probabilities) {
  1421. try {
  1422. integer numberOfTypes = my classes->size;
  1423. autoTableOfReal thee = TableOfReal_create (numberOfTypes + 1, numberOfTypes + 1);
  1424. for (integer i = 1; i <= numberOfTypes; i ++) {
  1425. SimpleString s = (SimpleString) my classes->at [i];
  1426. TableOfReal_setRowLabel (thee.get(), i, s -> string.get());
  1427. TableOfReal_setColumnLabel (thee.get(), i, s -> string.get());
  1428. }
  1429. for (integer i = 2; i <= my numberOfItems; i ++) {
  1430. if (my classIndex [i - 1] > 0 && my classIndex [i] > 0) { // a zero is a restart!
  1431. thy data [my classIndex [i-1]] [my classIndex [i]] ++;
  1432. }
  1433. }
  1434. longdouble sum = 0.0;
  1435. for (integer i = 1; i <= numberOfTypes; i ++) {
  1436. double rowSum = 0.0, colSum = 0.0;
  1437. for (integer j = 1; j <= numberOfTypes; j ++) {
  1438. rowSum += thy data [i] [j];
  1439. }
  1440. thy data [i] [numberOfTypes + 1] = rowSum;
  1441. for (integer j = 1; j <= numberOfTypes; j ++) {
  1442. colSum += thy data [j] [i];
  1443. }
  1444. thy data [numberOfTypes + 1] [i] = colSum;
  1445. sum += colSum;
  1446. }
  1447. thy data [numberOfTypes + 1] [numberOfTypes + 1] = sum;
  1448. if (probabilities && sum > 0.0) {
  1449. for (integer i = 1; i <= numberOfTypes; i ++) {
  1450. if (thy data [i] [numberOfTypes + 1] > 0.0) {
  1451. for (integer j = 1; j <= numberOfTypes; j ++) {
  1452. thy data [i] [j] /= thy data [i] [numberOfTypes + 1];
  1453. }
  1454. }
  1455. }
  1456. for (integer i = 1; i <= numberOfTypes; i ++) {
  1457. thy data [i] [numberOfTypes + 1] /= sum;
  1458. thy data [numberOfTypes + 1] [i] /= sum;
  1459. }
  1460. }
  1461. return thee;
  1462. } catch (MelderError) {
  1463. Melder_throw (me, U": no transition table created.");
  1464. }
  1465. }
  1466. autoTableOfReal Strings_to_TableOfReal_transitions (Strings me, int probabilities) {
  1467. try {
  1468. autoStringsIndex him = Strings_to_StringsIndex (me);
  1469. autoTableOfReal thee = StringsIndex_to_TableOfReal_transitions (him.get(), probabilities);
  1470. return thee;
  1471. } catch (MelderError) {
  1472. Melder_throw (me, U": no transition table created.");
  1473. }
  1474. }
  1475. /* End of file HMM.cpp */