PairDistribution.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. /* PairDistribution.cpp
  2. *
  3. * Copyright (C) 1997-2012,2013,2015,2016,2017 Paul Boersma
  4. *
  5. * This code is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation; either version 2 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This code is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. * See the GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this work. If not, see <http://www.gnu.org/licenses/>.
  17. */
  18. #include "PairDistribution.h"
  19. #include "oo_DESTROY.h"
  20. #include "PairDistribution_def.h"
  21. #include "oo_COPY.h"
  22. #include "PairDistribution_def.h"
  23. #include "oo_EQUAL.h"
  24. #include "PairDistribution_def.h"
  25. #include "oo_CAN_WRITE_AS_ENCODING.h"
  26. #include "PairDistribution_def.h"
  27. #include "oo_WRITE_TEXT.h"
  28. #include "PairDistribution_def.h"
  29. #include "oo_READ_TEXT.h"
  30. #include "PairDistribution_def.h"
  31. #include "oo_WRITE_BINARY.h"
  32. #include "PairDistribution_def.h"
  33. #include "oo_READ_BINARY.h"
  34. #include "PairDistribution_def.h"
  35. #include "oo_DESCRIPTION.h"
  36. #include "PairDistribution_def.h"
  /* Thing_implement is invoked once for each of the two classes this file implements. */
  /* NOTE(review): the second argument (Daata) is presumably the parent class and the third (0) the class version -- confirm against the macro's definition. */
  37. Thing_implement (PairProbability, Daata, 0);
  38. Thing_implement (PairDistribution, Daata, 0);
  39. void structPairDistribution :: v_info () {
  40. PairDistribution_Parent :: v_info ();
  41. MelderInfo_writeLine (U"Number of pairs: ", pairs.size);
  42. }
  43. autoPairProbability PairProbability_create (conststring32 string1, conststring32 string2, double weight) {
  44. autoPairProbability me = Thing_new (PairProbability);
  45. my string1 = Melder_dup (string1);
  46. my string2 = Melder_dup (string2);
  47. my weight = weight;
  48. return me;
  49. }
  50. autoPairDistribution PairDistribution_create () {
  51. try {
  52. autoPairDistribution me = Thing_new (PairDistribution);
  53. return me;
  54. } catch (MelderError) {
  55. Melder_throw (U"PairDistribution not created.");
  56. }
  57. }
  58. int PairProbability_compare (PairProbability me, PairProbability thee) noexcept {
  59. return str32cmp (my string1.get(), thy string1.get());
  60. }
  61. static void PairDistribution_checkSpecifiedPairNumber (PairDistribution me, integer pairNumber) {
  62. if (pairNumber < 1)
  63. Melder_throw (me, U": the specified pair number is ", pairNumber, U", but should be at least 1.");
  64. if (pairNumber > my pairs.size)
  65. Melder_throw (me, U": the specified pair number is ", pairNumber, U", but should be at most my number of pairs (", my pairs.size, U").");
  66. }
  67. conststring32 PairDistribution_getString1 (PairDistribution me, integer pairNumber) {
  68. try {
  69. PairDistribution_checkSpecifiedPairNumber (me, pairNumber);
  70. PairProbability prob = my pairs.at [pairNumber];
  71. return prob -> string1.get();
  72. } catch (MelderError) {
  73. Melder_throw (me, U": string1 not retrieved.");
  74. }
  75. }
  76. conststring32 PairDistribution_getString2 (PairDistribution me, integer pairNumber) {
  77. try {
  78. PairDistribution_checkSpecifiedPairNumber (me, pairNumber);
  79. PairProbability prob = my pairs.at [pairNumber];
  80. return prob -> string2.get();
  81. } catch (MelderError) {
  82. Melder_throw (me, U": string2 not retrieved.");
  83. }
  84. }
  85. double PairDistribution_getWeight (PairDistribution me, integer pairNumber) {
  86. try {
  87. PairDistribution_checkSpecifiedPairNumber (me, pairNumber);
  88. PairProbability prob = my pairs.at [pairNumber];
  89. return prob -> weight;
  90. } catch (MelderError) {
  91. Melder_throw (me, U": weight not retrieved.");
  92. }
  93. }
  94. void PairDistribution_add (PairDistribution me, conststring32 string1, conststring32 string2, double weight) {
  95. autoPairProbability pair = PairProbability_create (string1, string2, weight);
  96. my pairs.addItem_move (pair.move());
  97. }
  98. void PairDistribution_removeZeroWeights (PairDistribution me) {
  99. for (integer ipair = my pairs.size; ipair > 0; ipair --) {
  100. PairProbability prob = my pairs.at [ipair];
  101. if (prob -> weight <= 0.0) {
  102. my pairs.removeItem (ipair);
  103. }
  104. }
  105. }
  106. void PairDistribution_swapInputsAndOutputs (PairDistribution me) {
  107. for (integer ipair = my pairs.size; ipair > 0; ipair --) {
  108. PairProbability prob = my pairs.at [ipair];
  109. std::swap (prob -> string1, prob -> string2); // that this really swap?
  110. }
  111. }
  112. static double PairDistributions_getTotalWeight_checkPositive (PairDistribution me) {
  113. longdouble totalWeight = 0.0;
  114. for (integer ipair = 1; ipair <= my pairs.size; ipair ++) {
  115. PairProbability prob = my pairs.at [ipair];
  116. totalWeight += prob -> weight;
  117. }
  118. if (totalWeight <= 0.0) {
  119. Melder_throw (me, U": the total probability weight is ", Melder_half (totalWeight), U" but should be greater than zero for this operation.");
  120. }
  121. return (double) totalWeight;
  122. }
  123. void PairDistribution_to_Stringses (PairDistribution me, integer nout, autoStrings *strings1_out, autoStrings *strings2_out) {
  124. try {
  125. integer nin = my pairs.size, iin;
  126. if (nin < 1)
  127. Melder_throw (U"No candidates.");
  128. if (nout < 1)
  129. Melder_throw (U"Number of generated string pairs should be positive.");
  130. double total = PairDistributions_getTotalWeight_checkPositive (me);
  131. autoStrings strings1 = Thing_new (Strings);
  132. strings1 -> numberOfStrings = nout;
  133. strings1 -> strings = autostring32vector (nout);
  134. autoStrings strings2 = Thing_new (Strings);
  135. strings2 -> numberOfStrings = nout;
  136. strings2 -> strings = autostring32vector (nout);
  137. for (integer iout = 1; iout <= nout; iout ++) {
  138. do {
  139. double rand = NUMrandomUniform (0, total), sum = 0.0;
  140. for (iin = 1; iin <= nin; iin ++) {
  141. PairProbability prob = my pairs.at [iin];
  142. sum += prob -> weight;
  143. if (rand <= sum) break;
  144. }
  145. } while (iin > nin); /* Guard against rounding errors. */
  146. PairProbability prob = my pairs.at [iin];
  147. if (! prob -> string1 || ! prob -> string2)
  148. Melder_throw (U"No string in probability pair ", iin, U".");
  149. strings1 -> strings [iout] = Melder_dup (prob -> string1.get());
  150. strings2 -> strings [iout] = Melder_dup (prob -> string2.get());
  151. }
  152. *strings1_out = strings1.move();
  153. *strings2_out = strings2.move();
  154. } catch (MelderError) {
  155. Melder_throw (me, U": generation of Stringses not performed.");
  156. }
  157. }
  158. void PairDistribution_peekPair (PairDistribution me, conststring32 *out_string1, conststring32 *out_string2) {
  159. try {
  160. *out_string1 = *out_string2 = nullptr;
  161. double total = 0.0;
  162. integer nin = my pairs.size, iin;
  163. PairProbability prob;
  164. if (nin < 1) Melder_throw (U"No candidates.");
  165. for (iin = 1; iin <= nin; iin ++) {
  166. prob = my pairs.at [iin];
  167. total += prob -> weight;
  168. }
  169. do {
  170. double rand = NUMrandomUniform (0, total), sum = 0.0;
  171. for (iin = 1; iin <= nin; iin ++) {
  172. prob = my pairs.at [iin];
  173. sum += prob -> weight;
  174. if (rand <= sum) break;
  175. }
  176. } while (iin > nin); // guard against rounding errors
  177. prob = my pairs.at [iin];
  178. if (! prob -> string1 || ! prob -> string2) Melder_throw (U"No string in probability pair ", iin, U".");
  179. *out_string1 = prob -> string1.get();
  180. *out_string2 = prob -> string2.get();
  181. } catch (MelderError) {
  182. Melder_throw (me, U": pair not peeked.");
  183. }
  184. }
  185. static double PairDistribution_getFractionCorrect (PairDistribution me, int which) {
  186. try {
  187. double correct = 0.0;
  188. integer pairmin = 1, ipair;
  189. autoPairDistribution thee = Data_copy (me);
  190. thy pairs.sort (PairProbability_compare);
  191. double total = PairDistributions_getTotalWeight_checkPositive (thee.get());
  192. do {
  193. integer pairmax = pairmin;
  194. const conststring32 firstInput = thy pairs.at [pairmin] -> string1.get();
  195. for (ipair = pairmin + 1; ipair <= thy pairs.size; ipair ++) {
  196. PairProbability prob = thy pairs.at [ipair];
  197. if (! str32equ (prob -> string1.get(), firstInput)) {
  198. pairmax = ipair - 1;
  199. break;
  200. }
  201. }
  202. if (ipair > thy pairs.size) pairmax = thy pairs.size;
  203. if (which == 0) {
  204. double pmax = 0.0;
  205. for (ipair = pairmin; ipair <= pairmax; ipair ++) {
  206. PairProbability prob = thy pairs.at [ipair];
  207. double p = prob -> weight / total;
  208. if (p > pmax) pmax = p;
  209. }
  210. correct += pmax;
  211. } else {
  212. double sum = 0.0, p2 = 0.0;
  213. for (ipair = pairmin; ipair <= pairmax; ipair ++) {
  214. PairProbability prob = thy pairs.at [ipair];
  215. double p = prob -> weight / total;
  216. sum += p;
  217. p2 += p * p;
  218. }
  219. correct += p2 / sum;
  220. }
  221. pairmin = pairmax + 1;
  222. } while (pairmin <= thy pairs.size);
  223. return correct;
  224. } catch (MelderError) {
  225. Melder_throw (me, U": could not compute my fraction correct.");
  226. }
  227. }
  228. double PairDistribution_getFractionCorrect_maximumLikelihood (PairDistribution me) {
  229. return PairDistribution_getFractionCorrect (me, 0);
  230. }
  231. double PairDistribution_getFractionCorrect_probabilityMatching (PairDistribution me) {
  232. return PairDistribution_getFractionCorrect (me, 1);
  233. }
  234. double PairDistribution_Distributions_getFractionCorrect (PairDistribution me, Distributions dist, integer column) {
  235. try {
  236. double correct = 0.0;
  237. integer pairmin = 1;
  238. char32 string [1000];
  239. Distributions_checkSpecifiedColumnNumberWithinRange (dist, column);
  240. autoPairDistribution thee = Data_copy (me);
  241. thy pairs.sort (PairProbability_compare);
  242. double total = PairDistributions_getTotalWeight_checkPositive (thee.get());
  243. do {
  244. integer pairmax = pairmin, length, ipair;
  245. double sum = 0.0, sumDist = 0.0;
  246. const conststring32 firstInput = thy pairs.at [pairmin] -> string1.get();
  247. for (ipair = pairmin + 1; ipair <= thy pairs.size; ipair ++) {
  248. PairProbability prob = thy pairs.at [ipair];
  249. if (! str32equ (prob -> string1.get(), firstInput)) {
  250. pairmax = ipair - 1;
  251. break;
  252. }
  253. }
  254. if (ipair > thy pairs.size) pairmax = thy pairs.size;
  255. for (ipair = pairmin; ipair <= pairmax; ipair ++) {
  256. PairProbability prob = thy pairs.at [ipair];
  257. double p = prob -> weight / total, pout = 0.0;
  258. Melder_sprint (string, 1000, prob -> string1.get(), U" \\-> ", prob -> string2.get());
  259. for (integer idist = 1; idist <= dist -> numberOfRows; idist ++) {
  260. if (str32equ (string, dist -> rowLabels [idist].get())) {
  261. pout = dist -> data [idist] [column];
  262. break;
  263. }
  264. }
  265. sum += p * pout;
  266. }
  267. Melder_sprint (string, 1000, firstInput, U" \\-> ");
  268. length = str32len (string);
  269. for (integer idist = 1; idist <= dist -> numberOfRows; idist ++) {
  270. if (str32nequ (string, dist -> rowLabels [idist].get(), length)) {
  271. sumDist += dist -> data [idist] [column];
  272. }
  273. }
  274. if (sumDist != 0.0) correct += sum / sumDist;
  275. pairmin = pairmax + 1;
  276. } while (pairmin <= thy pairs.size);
  277. return correct;
  278. } catch (MelderError) {
  279. Melder_throw (me, U" & ", dist, U": could not compute our fraction correct.");
  280. }
  281. }
  282. autoTable PairDistribution_to_Table (PairDistribution me) {
  283. try {
  284. autoTable thee = Table_createWithColumnNames (my pairs.size, U"string1 string2 weight");
  285. for (integer ipair = 1; ipair <= my pairs.size; ipair ++) {
  286. PairProbability prob = my pairs.at [ipair];
  287. Table_setStringValue (thee.get(), ipair, 1, prob -> string1.get());
  288. Table_setStringValue (thee.get(), ipair, 2, prob -> string2.get());
  289. Table_setNumericValue (thee.get(), ipair, 3, prob -> weight);
  290. }
  291. return thee;
  292. } catch (MelderError) {
  293. Melder_throw (me, U": not converted to Table.");
  294. }
  295. }
  296. /* End of file PairDistribution.cpp */