/* [extraction artifact removed: embedded file-size header and page line-number runs] */
  1. /* KNN.cpp
  2. *
  3. * Copyright (C) 2008 Ola So"der, 2010-2012,2016-2018 Paul Boersma
  4. *
  5. * This code is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation; either version 2 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This code is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  13. * General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this work. If not, see <http://www.gnu.org/licenses/>.
  17. */
  18. /*
  19. * os 2008/05/29 initial release
  20. * pb 2010/06/06 removed some array-creations-on-the-stack
  21. * pb 2011/04/12 C++
  22. * pb 2011/04/13 removed several memory leaks
  23. * pb 2011/07/07 some exception safety
  24. */
  25. #include "KNN.h"
  26. #include "KNN_threads.h"
  27. #include "OlaP.h"
  28. #include "oo_DESTROY.h"
  29. #include "KNN_def.h"
  30. #include "oo_COPY.h"
  31. #include "KNN_def.h"
  32. #include "oo_EQUAL.h"
  33. #include "KNN_def.h"
  34. #include "oo_CAN_WRITE_AS_ENCODING.h"
  35. #include "KNN_def.h"
  36. #include "oo_WRITE_TEXT.h"
  37. #include "KNN_def.h"
  38. #include "oo_WRITE_BINARY.h"
  39. #include "KNN_def.h"
  40. #include "oo_READ_TEXT.h"
  41. #include "KNN_def.h"
  42. #include "oo_READ_BINARY.h"
  43. #include "KNN_def.h"
  44. #include "oo_DESCRIPTION.h"
  45. #include "KNN_def.h"
// Praat class boilerplate: registers KNN as a subclass of Daata (struct version 0).
Thing_implement (KNN, Daata, 0);
  47. /////////////////////////////////////////////////////////////////////////////////////////////
  48. // Praat specifics //
  49. /////////////////////////////////////////////////////////////////////////////////////////////
/*
	Praat "Info" handler: writes a short summary of this classifier,
	currently just the number of stored training instances.
*/
void structKNN :: v_info ()
{
	structDaata :: v_info ();   // let the base class print its standard header first
	MelderInfo_writeLine (U"Size of instancebase: ", nInstances);
}
  55. /////////////////////////////////////////////////////////////////////////////////////////////
  56. // Creation //
  57. /////////////////////////////////////////////////////////////////////////////////////////////
  58. autoKNN KNN_create ()
  59. {
  60. try {
  61. autoKNN me = Thing_new (KNN);
  62. my nInstances = 0;
  63. return me;
  64. } catch (MelderError) {
  65. Melder_throw (U"KNN classifier not created.");
  66. }
  67. }
  68. /////////////////////////////////////////////////////////////////////////////////////////////
  69. // Learning //
  70. /////////////////////////////////////////////////////////////////////////////////////////////
  71. int KNN_learn
  72. (
  73. ///////////////////////////////
  74. // Parameters //
  75. ///////////////////////////////
  76. KNN me, // the classifier to be trained
  77. //
  78. PatternList p, // source pattern
  79. //
  80. Categories c, // target categories
  81. //
  82. int method, // method <- REPLACE or APPEND
  83. //
  84. int ordering // ordering <- SHUFFLE?
  85. )
  86. {
  87. if (c->size == p->ny) // the number of input vectors must
  88. { // equal the number of categories.
  89. switch (method)
  90. {
  91. case kOla_REPLACE: // in REPLACE mode simply
  92. // dispose of the current
  93. my input = Data_copy (p); // LEAK
  94. my output = Data_copy (c);
  95. my nInstances = c->size;
  96. break;
  97. case kOla_APPEND: // in APPEND mode a new
  98. // instance base is formed
  99. // by merging the new and
  100. // the old.
  101. //
  102. if (p->nx == (my input)->nx) // the number of features
  103. // of the old and new
  104. // instance base must
  105. // match; otherwise merging
  106. { // won't be possible.
  107. /*
  108. * Create without change.
  109. */
  110. autoMatrix matrix = Matrix_appendRows (my input.get(), p, classPatternList);
  111. autoPatternList tinput = matrix.static_cast_move <structPatternList>();
  112. autoCategories toutput = Data_copy (my output.get());
  113. toutput -> merge (c);
  114. /*
  115. * Change without error.
  116. */
  117. my input = tinput.move();
  118. my output = toutput.move();
  119. my nInstances += p -> ny;
  120. } else { // fail
  121. return kOla_DIMENSIONALITY_MISMATCH;
  122. }
  123. break;
  124. }
  125. if (ordering == kOla_SHUFFLE) // shuffle the instance base
  126. KNN_shuffleInstances(me);
  127. } else { // fail
  128. return kOla_PATTERN_CATEGORIES_MISMATCH;
  129. }
  130. return kOla_SUCCESS; // success
  131. }
  132. /////////////////////////////////////////////////////////////////////////////////////////////
  133. // Classification - To Categories //
  134. /////////////////////////////////////////////////////////////////////////////////////////////
/*
	Per-thread argument bundle for KNN_classifyToCategoriesAux.
	Each worker classifies rows istart..istop of `ps` and writes the winning
	instance indices into the shared `output` buffer (disjoint ranges per thread).
*/
typedef struct
{
	KNN me;               // the classifier being used (read-only)
	PatternList ps;       // the patterns to classify (read-only)
	integer * output;     // shared result buffer; slot y receives the chosen instance index for row y
	FeatureWeights fws;   // feature weights (read-only)
	integer k;            // number of neighbours to consult
	int dist;             // distance-weighting scheme (kOla_* constant)
	integer istart;       // first row (inclusive) handled by this thread
	integer istop;        // last row (inclusive) handled by this thread
} KNN_input_ToCategories_t;
  146. autoCategories KNN_classifyToCategories
  147. (
  148. ///////////////////////////////
  149. // Parameters //
  150. ///////////////////////////////
  151. KNN me, // the classifier being used
  152. //
  153. PatternList ps, // the pattern to classify
  154. //
  155. FeatureWeights fws, // feature weights
  156. //
  157. integer k, // the number of sought after neighbours
  158. //
  159. int dist // distance weighting
  160. )
  161. {
  162. int nthreads = KNN_getNumberOfCPUs();
  163. integer *outputindices = NUMvector <integer> (0, ps -> ny);
  164. integer chunksize = ps -> ny / nthreads;
  165. Melder_assert (nthreads > 0);
  166. Melder_assert (k > 0 && k <= my nInstances);
  167. if(chunksize < 1) {
  168. chunksize = 1;
  169. nthreads = ps -> ny;
  170. }
  171. integer istart = 1;
  172. integer istop = chunksize;
  173. autoCategories output = Categories_create ();
  174. KNN_input_ToCategories_t ** input = (KNN_input_ToCategories_t **) malloc (nthreads * sizeof (KNN_input_ToCategories_t *));
  175. if (! input)
  176. return autoCategories();
  177. for (int i = 0; i < nthreads; i ++) {
  178. input [i] = (KNN_input_ToCategories_t *) malloc (sizeof (KNN_input_ToCategories_t));
  179. if (! input [i]) {
  180. while (input [i --])
  181. free (input [i]);
  182. free (input);
  183. return autoCategories();
  184. }
  185. }
  186. for (int i = 0; i < nthreads; i ++) {
  187. input [i] -> me = me;
  188. input [i] -> ps = ps;
  189. input [i] -> output = outputindices;
  190. input [i] -> fws = fws;
  191. input [i] -> k = k;
  192. input [i] -> dist = dist;
  193. input [i] -> istart = istart;
  194. if (istop + chunksize > ps -> ny) {
  195. input [i] -> istop = ps -> ny;
  196. break;
  197. } else {
  198. input [i] -> istop = istop;
  199. istart = istop + 1;
  200. istop += chunksize;
  201. }
  202. }
  203. enum KNN_thread_status * error = (enum KNN_thread_status *) KNN_threadDistribution(KNN_classifyToCategoriesAux, (void **) input, nthreads);
  204. //void *error = KNN_classifyToCategoriesAux (input [0]);
  205. for (int i = 0; i < nthreads; i ++)
  206. free (input [i]);
  207. free (input);
  208. if (error) { // Something went very wrong, you ought to inform the user!
  209. free (error);
  210. return autoCategories();
  211. }
  212. for (integer i = 1; i <= ps -> ny; i ++)
  213. output -> addItem_move (Data_copy (my output->at [outputindices [i]]));
  214. NUMvector_free (outputindices, 0);
  215. return output;
  216. }
/*
	Worker-thread body for KNN_classifyToCategories.
	For each row y in [istart, istop] it finds the k nearest training
	instances, tallies their categories, applies the selected distance
	weighting, and stores the winning instance index in output[y].
	Always returns nullptr (the thread-status protocol reports errors
	elsewhere).
*/
void * KNN_classifyToCategoriesAux
(
	///////////////////////////////
	// Parameters //
	///////////////////////////////
	void * input
)
{
	// sanity-check the chunk boundaries handed to this thread
	Melder_assert(((KNN_input_ToCategories_t *) input)->istart > 0 &&
	((KNN_input_ToCategories_t *) input)->istop > 0 &&
	((KNN_input_ToCategories_t *) input)->istart <= ((KNN_input_ToCategories_t *) input)->ps->ny &&
	((KNN_input_ToCategories_t *) input)->istop <= ((KNN_input_ToCategories_t *) input)->ps->ny &&
	((KNN_input_ToCategories_t *) input)->istart <= ((KNN_input_ToCategories_t *) input)->istop);
	integer ncollected;
	integer ncategories;
	// per-thread scratch arrays, indices 0..k inclusive (NUMvector allocates both bounds)
	integer *indices = NUMvector <integer> (0, ((KNN_input_ToCategories_t *) input)->k);
	integer *freqindices = NUMvector <integer> (0, ((KNN_input_ToCategories_t *) input)->k);
	double *distances = NUMvector <double> (0, ((KNN_input_ToCategories_t *) input)->k);
	double *freqs = NUMvector <double> (0, ((KNN_input_ToCategories_t *) input)->k);
	for (integer y = ((KNN_input_ToCategories_t *) input)->istart; y <= ((KNN_input_ToCategories_t *) input)->istop; ++y)
	{
		/////////////////////////////////////////
		// Localizing the k nearest neighbours //
		/////////////////////////////////////////
		ncollected = KNN_kNeighbours (
		((KNN_input_ToCategories_t *) input)->ps,
		((KNN_input_ToCategories_t *) input)->me->input.get(),
		((KNN_input_ToCategories_t *) input)->fws, y,
		((KNN_input_ToCategories_t *) input)->k, indices, distances
		);
		/////////////////////////////////////////////////
		// Computing frequencies and average distances //
		/////////////////////////////////////////////////
		ncategories = KNN_kIndicesToFrequenciesAndDistances (
		((KNN_input_ToCategories_t *) input)->me->output.get(),
		((KNN_input_ToCategories_t *) input)->k,
		indices, distances, freqs, freqindices
		);
		////////////////////////
		// Distance weighting //
		////////////////////////
		// NOTE(review): the loops below index `distances` by category counter c,
		// although `distances` was filled per neighbour above. ncategories <= k,
		// so there is no out-of-bounds access, but the semantics look suspicious
		// — confirm against KNN_kIndicesToFrequenciesAndDistances, which may
		// rewrite `distances` per category.
		switch(((KNN_input_ToCategories_t *) input)->dist)
		{
			case kOla_DISTANCE_WEIGHTED_VOTING:
				for (integer c = 0; c < ncategories; ++c)
					freqs[c] *= 1 / OlaMAX(distances[c], kOla_MINFLOAT);
				break;
			case kOla_SQUARED_DISTANCE_WEIGHTED_VOTING:
				for (integer c = 0; c < ncategories; ++c)
					freqs[c] *= 1 / OlaMAX(OlaSQUARE(distances[c]), kOla_MINFLOAT);
		}
		KNN_normalizeFloatArray(freqs, ncategories);
		// record the training-instance index of the winning category for row y
		((KNN_input_ToCategories_t *) input)->output[y] = freqindices[KNN_max(freqs, ncategories)];
	}
	NUMvector_free (indices, 0);
	NUMvector_free (freqindices, 0);
	NUMvector_free (distances, 0);
	NUMvector_free (freqs, 0);
	return nullptr;
}
  277. ////////////////////////////////////////////////////////////////////////////////////////////
  278. // Classification - To TableOfReal //
  279. /////////////////////////////////////////////////////////////////////////////////////////////
/*
	Per-thread argument bundle for KNN_classifyToTableOfRealAux.
	Each worker fills rows istart..istop of the shared `output` table with
	per-category vote counts (later normalized per the weighting scheme).
*/
typedef struct {
	KNN me;                       // the classifier being used (read-only)
	PatternList ps;               // source patterns (read-only)
	Categories uniqueCategories;  // the distinct output categories (borrowed pointer)
	TableOfReal output;           // shared result table; threads write disjoint row ranges
	FeatureWeights fws;           // feature weights (read-only)
	integer k;                    // number of neighbours to consult
	int dist;                     // distance-weighting scheme (kOla_* constant)
	integer istart;               // first row (inclusive) handled by this thread
	integer istop;                // last row (inclusive) handled by this thread
} KNN_input_ToTableOfReal_t;
  291. autoTableOfReal KNN_classifyToTableOfReal
  292. (
  293. ///////////////////////////////
  294. // Parameters //
  295. ///////////////////////////////
  296. KNN me, // the classifier being used
  297. //
  298. PatternList ps, // source PatternList
  299. //
  300. FeatureWeights fws, // feature weights
  301. //
  302. integer k, // the number of sought after neighbours
  303. //
  304. int dist // distance weighting
  305. )
  306. {
  307. int nthreads = KNN_getNumberOfCPUs();
  308. integer chunksize = ps->ny / nthreads;
  309. autoCategories uniqueCategories = Categories_selectUniqueItems (my output.get());
  310. integer ncategories = uniqueCategories->size;
  311. Melder_assert (nthreads > 0);
  312. Melder_assert (ncategories > 0);
  313. Melder_assert (k > 0 && k <= my nInstances);
  314. if (! ncategories)
  315. return autoTableOfReal();
  316. if (chunksize < 1) {
  317. chunksize = 1;
  318. nthreads = ps -> ny;
  319. }
  320. integer istart = 1;
  321. integer istop = chunksize;
  322. KNN_input_ToTableOfReal_t ** input = (KNN_input_ToTableOfReal_t **) malloc (nthreads * sizeof (KNN_input_ToTableOfReal_t *));
  323. if (! input)
  324. return autoTableOfReal();
  325. autoTableOfReal output = TableOfReal_create (ps -> ny, ncategories);
  326. for (integer i = 1; i <= ncategories; i ++)
  327. TableOfReal_setColumnLabel (output.get(), i, uniqueCategories->at [i] -> string.get());
  328. for (int i = 0; i < nthreads; i ++) {
  329. input [i] = (KNN_input_ToTableOfReal_t *) malloc (sizeof (KNN_input_ToTableOfReal_t));
  330. if (! input [i]) {
  331. while (input [i --]) // ppgb FIXME: is always false
  332. free (input [i]); // ppgb FIXME: the stopping condition is incorrect
  333. free (input);
  334. return autoTableOfReal();
  335. }
  336. }
  337. for (int i = 0; i < nthreads; i ++) {
  338. input [i] -> me = me;
  339. input [i] -> ps = ps;
  340. input [i] -> output = output.get(); // YUCK: reference copy
  341. input [i] -> uniqueCategories = uniqueCategories.releaseToAmbiguousOwner();
  342. input [i] -> fws = fws;
  343. input [i] -> k = k;
  344. input [i] -> dist = dist;
  345. input [i] -> istart = istart;
  346. if (istop + chunksize > ps -> ny) {
  347. input [i] -> istop = ps -> ny;
  348. break;
  349. } else {
  350. input [i] -> istop = istop;
  351. istart = istop + 1;
  352. istop += chunksize;
  353. }
  354. }
  355. enum KNN_thread_status * error = (enum KNN_thread_status *) KNN_threadDistribution (KNN_classifyToTableOfRealAux, (void **) input, nthreads);
  356. for (int i = 0; i < nthreads; i ++)
  357. free (input [i]);
  358. free (input);
  359. if (error) // Something went very wrong, you ought to inform the user!
  360. {
  361. free (error);
  362. return autoTableOfReal();
  363. }
  364. return output;
  365. }
/*
	Worker-thread body for KNN_classifyToTableOfReal.
	Pass 1: for each row y in [istart, istop], find the k nearest training
	instances and increment output[y][j] for every neighbour belonging to
	unique category j. Pass 2: apply the selected distance weighting and
	normalize each row to sum to 1. Always returns nullptr.
*/
void * KNN_classifyToTableOfRealAux
(
	///////////////////////////////
	// Parameters //
	///////////////////////////////
	void * void_input
)
{
	KNN_input_ToTableOfReal_t *input = (KNN_input_ToTableOfReal_t *) void_input;
	integer ncategories = input -> uniqueCategories->size;
	// scratch arrays, indices 0..k inclusive
	autoNUMvector <integer> indices ((integer) 0, input -> k);
	autoNUMvector <double> distances ((integer) 0, input -> k);
	for (integer y = input -> istart; y <= input -> istop; y ++) {
		KNN_kNeighbours (input -> ps, input -> me -> input.get(), input -> fws, y, input -> k, indices.peek(), distances.peek());
		// tally one vote per neighbour per matching unique category
		for (integer i = 0; i < input -> k; i ++) {
			for (integer j = 1; j <= ncategories; j ++) {
				if (FeatureWeights_areFriends (input -> me -> output->at [indices [i]], input -> uniqueCategories->at [j]))
					input -> output -> data [y] [j] += 1.0;
			}
		}
	}
	// NOTE(review): the weighting below indexes `distances` by category c,
	// but `distances` holds per-neighbour values from the LAST row processed
	// above; if ncategories > k this even reads past the allocation. Looks
	// like a latent defect — confirm intended semantics before changing.
	// NOTE(review): each row is divided by its own sum; a row with no
	// friendly neighbours would divide by zero — presumably the tallies are
	// always positive; verify.
	switch (input -> dist) {
		case kOla_DISTANCE_WEIGHTED_VOTING:
			for (integer y = input -> istart; y <= input -> istop; y ++) {
				longdouble sum = 0.0;
				for (integer c = 1; c <= ncategories; c ++) {
					input -> output -> data [y] [c] *= 1.0 / OlaMAX (distances [c], kOla_MINFLOAT);
					sum += input -> output -> data [y] [c];
				}
				for (integer c = 1; c <= ncategories; c ++)
					input -> output -> data [y] [c] /= sum;
			}
			break;
		case kOla_SQUARED_DISTANCE_WEIGHTED_VOTING:
			for (integer y = input -> istart; y <= input -> istop; y ++) {
				longdouble sum = 0.0;
				for (integer c = 1; c <= ncategories; c ++) {
					input -> output -> data [y] [c] *= 1.0 / OlaMAX (OlaSQUARE (distances [c]), kOla_MINFLOAT);
					sum += input -> output -> data [y] [c];
				}
				for (integer c = 1; c <= ncategories; c ++)
					input -> output -> data [y] [c] /= sum;
			}
			break;
		case kOla_FLAT_VOTING:
			for (integer y = input -> istart; y <= input -> istop; y ++) {
				longdouble sum = 0.0;
				for (integer c = 1; c <= ncategories; c ++)
					sum += input -> output -> data [y] [c];
				for (integer c = 1; c <= ncategories; c ++)
					input -> output -> data [y] [c] /= sum;
			}
	}
	return nullptr;
}
  421. //////////////////////////////////////////////////////////////////////////////////////////////
  422. // Classification - Folding //
  423. /////////////////////////////////////////////////////////////////////////////////////////////
/*
	Classify the rows in the fold [begin, end] of `ps`, excluding that same
	range from the set of candidate neighbours (cross-validation helper).
	Returns one Categories item per row of the fold.
*/
autoCategories KNN_classifyFold
(
	///////////////////////////////
	// Parameters //
	///////////////////////////////
	KNN me, // the classifier being used
	//
	PatternList ps, // source PatternList
	//
	FeatureWeights fws, // feature weights
	//
	integer k, // the number of sought after neighbours
	//
	int dist, // distance weighting
	//
	integer begin, // fold start, inclusive [...
	//
	integer end // fold end, inclusive ...]
	//
	)
{
	Melder_assert(k > 0 && k <= ps->ny);
	Melder_assert(end > 0 && end <= ps->ny);
	Melder_assert(begin > 0 && begin <= ps->ny);
	if (begin > end)
		OlaSWAP(integer, begin, end);   // normalize so that begin <= end
	// cap k at the number of instances left after removing the fold
	if (k > my nInstances - (end - begin))
		k = my nInstances - (end - begin);
	integer ncollected;
	integer ncategories;
	// scratch arrays, indices 0..k (resp. 0..ps->ny) inclusive
	autoNUMvector <integer> indices ((integer) 0, k);
	autoNUMvector <integer> freqindices ((integer) 0, k);
	autoNUMvector <double> distances ((integer) 0, k);
	autoNUMvector <double> freqs ((integer) 0, k);
	autoNUMvector <integer> outputindices ((integer) 0, ps->ny);
	integer noutputindices = 0;
	for (integer y = begin; y <= end; y ++)
	{
		/////////////////////////////////////////
		// Localizing the k nearest neighbours //
		/////////////////////////////////////////
		ncollected = KNN_kNeighboursSkipRange (ps, my input.get(), fws, y, k, indices.peek(), distances.peek(), begin, end);
		/////////////////////////////////////////////////
		// Computing frequencies and average distances //
		/////////////////////////////////////////////////
		ncategories = KNN_kIndicesToFrequenciesAndDistances (my output.get(), k, indices.peek(), distances.peek(), freqs.peek(), freqindices.peek());
		////////////////////////
		// Distance weighting //
		////////////////////////
		// NOTE(review): distances[] is indexed by category counter c here,
		// mirroring KNN_classifyToCategoriesAux — confirm against
		// KNN_kIndicesToFrequenciesAndDistances (ncategories <= k, so no OOB).
		switch (dist)
		{
			case kOla_DISTANCE_WEIGHTED_VOTING:
				for (integer c = 0; c < ncategories; c ++)
					freqs [c] *= 1.0 / OlaMAX (distances [c], kOla_MINFLOAT);
				break;
			case kOla_SQUARED_DISTANCE_WEIGHTED_VOTING:
				for (integer c = 0; c < ncategories; c ++)
					freqs [c] *= 1.0 / OlaMAX (OlaSQUARE (distances [c]), kOla_MINFLOAT);
		}
		KNN_normalizeFloatArray (freqs.peek(), ncategories);
		// keep the training-instance index of the winning category for this row
		outputindices [noutputindices ++] = freqindices [KNN_max (freqs.peek(), ncategories)];
	}
	// materialize the collected winners as a Categories collection
	autoCategories output = Categories_create ();
	for (integer o = 0; o < noutputindices; o ++)
		output -> addItem_move (Data_copy (my output->at [outputindices [o]]));
	return output;
}
  491. /////////////////////////////////////////////////////////////////////////////////////////////
  492. // Evaluation //
  493. /////////////////////////////////////////////////////////////////////////////////////////////
  494. double KNN_evaluate
  495. (
  496. ///////////////////////////////
  497. // Parameters //
  498. ///////////////////////////////
  499. KNN me, // the classifier being used
  500. //
  501. FeatureWeights fws, // feature weights
  502. //
  503. integer k, // the number of sought after neighbours
  504. //
  505. int dist, // distance weighting
  506. //
  507. int mode // kOla_TEN_FOLD_CROSS_VALIDATION / kOla_LEAVE_ONE_OUT
  508. //
  509. )
  510. {
  511. double correct = 0.0;
  512. integer adder;
  513. switch(mode)
  514. {
  515. case kOla_TEN_FOLD_CROSS_VALIDATION:
  516. adder = my nInstances / 10;
  517. break;
  518. case kOla_LEAVE_ONE_OUT:
  519. if (my nInstances > 1)
  520. adder = 1;
  521. else
  522. adder = 0;
  523. break;
  524. default:
  525. adder = 0;
  526. }
  527. if (adder == 0)
  528. return -1;
  529. for (integer begin = 1; begin <= my nInstances; begin += adder)
  530. {
  531. autoCategories c = KNN_classifyFold (me, my input.get(), fws, k, dist, begin, OlaMIN (begin + adder - 1, my nInstances));
  532. for (integer y = 1; y <= c->size; y ++)
  533. if (FeatureWeights_areFriends (c->at [y], my output->at [begin + y - 1]))
  534. correct += 1.0;
  535. }
  536. correct /= (double) my nInstances;
  537. return correct;
  538. }
  539. /////////////////////////////////////////////////////////////////////////////////////////////
  540. // Evaluation using a separate test set //
  541. /////////////////////////////////////////////////////////////////////////////////////////////
  542. double KNN_evaluateWithTestSet
  543. (
  544. ///////////////////////////////
  545. // Parameters //
  546. ///////////////////////////////
  547. KNN me, // the classifier being used
  548. //
  549. PatternList p, // The vectors of the test set
  550. //
  551. Categories c, // The categories of the test set
  552. //
  553. FeatureWeights fws, // feature weights
  554. //
  555. integer k, // the number of sought after neighbours
  556. //
  557. int dist // distance weighting
  558. //
  559. )
  560. {
  561. double correct = 0.0;
  562. autoCategories t = KNN_classifyToCategories (me, p, fws, k, dist);
  563. for (integer y = 1; y <= t->size; y ++)
  564. if (FeatureWeights_areFriends (t->at [y], c->at [y])) correct += 1.0;
  565. return correct / c->size;
  566. }
  567. /////////////////////////////////////////////////////////////////////////////////////////////
  568. // Model search //
  569. /////////////////////////////////////////////////////////////////////////////////////////////
/*
	Helper record for KNN_modelSearch: one sampled parameter setting and
	the evaluation score it achieved.
*/
typedef struct structsoil
{
	double performance;   // evaluation score (fraction correct) for this setting
	integer dist;         // index into the candidate distance-weighting array
	integer k;            // candidate neighbour count
} soil;
/*
	Stochastic search for good (k, distance-weighting) parameters.
	Repeatedly samples `nseeds` random settings inside a window centred on
	the best setting found so far, then shrinks the window by `rate` each
	round until it closes. On success, writes the best k into *k and the
	best weighting constant into *dist, and returns the best score.
	NOTE(review): assumes rate > 0 and *k >= 1 — a non-positive rate would
	never shrink `range` and loop forever; confirm callers guarantee this.
*/
double KNN_modelSearch
(
	///////////////////////////////
	// Parameters //
	///////////////////////////////
	KNN me, // the classifier being used
	//
	FeatureWeights fws, // feature weights
	//
	integer * k, // valid integer *, to hold the output value of k
	//
	int * dist, // valid int *, to hold the output value dist_weight
	//
	int mode, // evaluation mode
	//
	double rate, // learning rate
	//
	integer nseeds // the number of seeds to be used
	//
	)
{
	try {
		// the three candidate weighting schemes, indexed 0..2 by field[].dist
		int dists[] = {
			kOla_SQUARED_DISTANCE_WEIGHTED_VOTING,
			kOla_DISTANCE_WEIGHTED_VOTING,
			kOla_FLAT_VOTING
		};
		integer max = *k;   // upper bound for candidate k (caller's initial value)
		double range = (double) max / 2.0;   // half-width of the k sampling window
		double pivot = range;                // centre of the k sampling window
		double dpivot = 1.0;                 // centre of the weighting-index window
		double drange = 1.0;                 // half-width of the weighting-index window
		double drate = rate / range;         // shrink step for the weighting window
		soil best = { 0, Melder_iround (dpivot), Melder_iround (dpivot) };
		autoNUMvector <soil> field ((integer) 0, nseeds - 1);
		while (range > 0) {
			// sample nseeds settings uniformly within the current windows
			for (integer n = 0; n < nseeds; n++) {
				field[n].k = Melder_iround (NUMrandomUniform (OlaMAX (pivot - range, 1), OlaMIN (pivot + range, max)));
				field[n].dist = Melder_iround (NUMrandomUniform (OlaMAX (dpivot - drange, 0), OlaMIN (dpivot + drange, 2)));
				field[n].performance = KNN_evaluate (me, fws, field[n].k, dists[field[n].dist], mode);
			}
			// pick the best sample of this round
			integer maxindex = 0;
			for (integer n = 1; n < nseeds; n ++)
				if (field [n]. performance > field [maxindex]. performance) maxindex = n;
			// re-centre the windows if it beats the best so far
			if (field [maxindex]. performance > best. performance) {
				pivot = field[maxindex].k;
				dpivot = field[maxindex].dist;
				best.performance = field[maxindex].performance;
				best.dist = field[maxindex].dist;
				best.k = field[maxindex].k;
			}
			range -= rate;    // shrink both windows
			drange -= drate;
		}
		*k = best.k;
		*dist = dists[best.dist];
		return best.performance;
	} catch (MelderError) {
		Melder_throw (me, U" & ", fws, U": model search not performed.");
	}
}
  637. /////////////////////////////////////////////////////////////////////////////////////////////
  638. // Euclidean distance //
  639. /////////////////////////////////////////////////////////////////////////////////////////////
  640. double KNN_distanceEuclidean
  641. (
  642. PatternList ps, // PatternList 1
  643. //
  644. PatternList pt, // PatternList 2
  645. //
  646. FeatureWeights fws, // Feature weights
  647. //
  648. integer rows, // Vector index of pattern 1
  649. //
  650. integer rowt // Vector index of pattern 2
  651. )
  652. {
  653. double distance = 0.0;
  654. for (integer x = 1; x <= ps->nx; ++x)
  655. distance += OlaSQUARE ((ps->z[rows][x] - pt->z[rowt][x]) * fws->fweights->data[1][x]);
  656. return sqrt (distance);
  657. }
  658. /////////////////////////////////////////////////////////////////////////////////////////////
  659. // Manhattan distance //
  660. /////////////////////////////////////////////////////////////////////////////////////////////
  661. double KNN_distanceManhattan
  662. (
  663. PatternList ps, // PatternList 1
  664. //
  665. PatternList pt, // PatternList 2
  666. //
  667. integer rows, // Vector index of pattern 1
  668. //
  669. integer rowt // Vector index of pattern 2
  670. //
  671. )
  672. {
  673. longdouble distance = 0.0;
  674. for (integer x = 1; x <= ps->nx; x ++)
  675. distance += fabs (ps->z[rows][x] - pt->z[rowt][x]);
  676. return (double) distance;
  677. }
  678. /////////////////////////////////////////////////////////////////////////////////////////////
  679. // Find longest distance //
  680. /////////////////////////////////////////////////////////////////////////////////////////////
  681. integer KNN_max
  682. (
  683. double * distances, // an array of distances containing ...
  684. //
  685. integer ndistances // ndistances distances
  686. //
  687. )
  688. {
  689. integer maxndx = 0;
  690. for (integer maxc = 1; maxc < ndistances; maxc ++) {
  691. if (distances[maxc] > distances[maxndx])
  692. maxndx = maxc;
  693. }
  694. return maxndx;
  695. }
////////////////////////////////////////////////////////////////////////////////////////////
// Locate k neighbours, skip one + disposal of distance                                    //
/////////////////////////////////////////////////////////////////////////////////////////////

// Collects into `indices` the indices of (up to) k nearest neighbours of instance jy of
// pattern j among the instances of pattern p, never selecting jy itself nor the instance
// `skipper`. Distances are weighted Euclidean; the distances themselves are discarded
// ("disposal of distance"). Returns the number of neighbours actually found (< k only
// when p has too few admissible instances).
integer KNN_kNeighboursSkip
(
    ///////////////////////////////
    // Parameters                //
    ///////////////////////////////

    PatternList j,        // source pattern
                          //
    PatternList p,        // target pattern (where neighbours are sought for)
                          //
    FeatureWeights fws,   // feature weights
                          //
    integer jy,           // source instance index
                          //
    integer k,            // the number of sought after neighbours
                          //
    integer * indices,    // memory space to contain the indices of
                          // the k neighbours
                          //
    integer skipper       // the index of the instance to be skipped
                          //
)
{
    integer maxi;    // 0-based slot of the most distant neighbour collected so far
    integer dc = 0;  // number of neighbours collected so far
    integer py = 1;  // 1-based scan position in the target pattern

    // Scratch distances for the collected slots; autoNUMvector zero-initialises,
    // so any unfilled tail reads as 0.0 in KNN_max below.
    autoNUMvector <double> distances ((integer) 0, k - 1);

    Melder_assert (jy > 0 && jy <= j -> ny);
    Melder_assert (k > 0 && k <= p -> ny);
    Melder_assert (skipper <= p -> ny);

    // Phase 1: accept the first k admissible instances unconditionally.
    while (dc < k && py <= p -> ny) {
        if (py != jy && py != skipper) {   // an instance is never its own neighbour
            distances [dc] = KNN_distanceEuclidean (j, p, fws, jy, py);
            indices [dc] = py;
            ++ dc;
        }
        ++ py;
    }

    // Phase 2: replace the farthest of the collected k whenever a closer one is found.
    maxi = KNN_max (distances.peek(), k);
    while (py <= p -> ny) {
        if (py != jy && py != skipper) {
            double d = KNN_distanceEuclidean (j, p, fws, jy, py);
            if (d < distances [maxi]) {
                distances [maxi] = d;
                indices [maxi] = py;
                maxi = KNN_max (distances.peek(), k);
            }
        }
        ++ py;
    }
    return OlaMIN (k, dc);   // the number of neighbours actually found
}
//////////////////////////////////////////////////////////////////////////////////
// Locate the k nearest neighbours, exclude instances within the range defined  //
// by [begin ... end]                                                           //
//////////////////////////////////////////////////////////////////////////////////

// As KNN_kNeighbours, but the instances with indices in the closed range
// [begin .. end] of the target pattern are excluded from the search. The scan
// walks circularly through p starting just past `end` (expression
// (end + py) % p->ny + 1) and stops when the wrap-around reaches `begin`, so the
// excluded range is simply never visited. Returns the number of neighbours found.
integer KNN_kNeighboursSkipRange
(
    ///////////////////////////////
    // Parameters                //
    ///////////////////////////////

    PatternList j,        // source-pattern (where the unknown is located)
                          //
    PatternList p,        // target pattern (where neighbours are sought for)
                          //
    FeatureWeights fws,   // feature weights
                          //
    integer jy,           // the index of the unknown instance in the source pattern
                          //
    integer k,            // the number of sought after neighbours
                          //
    integer * indices,    // a pointer to a memory-space big enough for k integers
                          // representing indices to the k neighbours in the
                          // target pattern
                          //
    double * distances,   // a pointer to a memory-space big enough for k
                          // doubles representing the distances to the k
                          // neighbours
                          //
    integer begin,        // an index indicating the first instance in the
                          // target pattern to be excluded from the search
                          //
    integer end           // an index indicating the last instance in the
                          // range of excluded instances in the target
                          // pattern
)
{
    ///////////////////////////////
    // Private variables         //
    ///////////////////////////////

    integer maxi;     // slot of the most distant neighbour among the k nearest
                      //
    integer dc = 0;   // fetch counter
                      //
    integer py = 0;   // circular scan offset, see comment above

    Melder_assert (jy > 0 && jy <= j -> ny);
    Melder_assert (k > 0 && k <= p -> ny);
    Melder_assert (end > 0 && end <= j -> ny);
    Melder_assert (begin > 0 && begin <= j -> ny);

    // Phase 1: the first k neighbours found are the nearest found so far.
    while (dc < k && (end + py) % p -> ny + 1 != begin) {
        if ((end + py) % p -> ny + 1 != jy) {   // no instance is its own neighbour
            distances [dc] = KNN_distanceEuclidean (j, p, fws, jy, (end + py) % p -> ny + 1);
            indices [dc] = (end + py) % p -> ny + 1;
            ++ dc;
        }
        ++ py;
    }

    // Phase 2: accept only instances less distant than the least near one found so far.
    // NOTE(review): if fewer than k admissible instances exist, KNN_max scans k slots of
    // the caller-supplied `distances` buffer of which only dc were written — callers
    // apparently must hand in zeroed memory; verify at the call sites.
    maxi = KNN_max (distances, k);
    while ((end + py) % p -> ny + 1 != begin) {
        if ((end + py) % p -> ny + 1 != jy) {
            double d = KNN_distanceEuclidean (j, p, fws, jy, (end + py) % p -> ny + 1);
            if (d < distances [maxi]) {
                distances [maxi] = d;
                indices [maxi] = (end + py) % p -> ny + 1;
                maxi = KNN_max (distances, k);
            }
        }
        ++ py;
    }
    return OlaMIN (k, dc);   // return the number of found neighbours
}
/////////////////////////////////////////////////////////////////////////////////////////////
// Locate k neighbours                                                                     //
/////////////////////////////////////////////////////////////////////////////////////////////

// Collects the (up to) k nearest neighbours of instance jy of pattern j among the
// instances of pattern p, excluding jy itself, writing their indices and weighted
// Euclidean distances into the caller-supplied buffers. Returns the number of
// neighbours actually found; when none is found, indices[0] is parked at jy and
// 0 is returned.
integer KNN_kNeighbours
(
    ///////////////////////////////
    // Parameters                //
    ///////////////////////////////

    PatternList j,        // source-pattern (where the unknown is located)
                          //
    PatternList p,        // target pattern (where neighbours are sought for)
                          //
    FeatureWeights fws,   // feature weights
                          //
    integer jy,           // the index of the unknown instance in the source pattern
                          //
    integer k,            // the number of sought after neighbours
                          //
    integer * indices,    // a pointer to a memory-space big enough for k integers
                          // representing indices to the k neighbours in the
                          // target pattern
    double * distances    // a pointer to a memory-space big enough for k
                          // doubles representing the distances to the k
                          // neighbours
                          //
)
{
    integer maxi;     // slot of the most distant of the collected neighbours
    integer dc = 0;   // number of neighbours collected so far
    integer py = 1;   // scan position in the target pattern

    Melder_assert (jy > 0 && jy <= j -> ny);
    Melder_assert (k > 0 && k <= p -> ny);
    Melder_assert (indices);
    Melder_assert (distances);

    // Phase 1: accept the first k admissible instances unconditionally.
    while (dc < k && py <= p -> ny) {
        if (py != jy) {   // no instance is its own neighbour
            distances [dc] = KNN_distanceEuclidean (j, p, fws, jy, py);
            indices [dc] = py;
            ++ dc;
        }
        ++ py;
    }

    // Phase 2: replace the farthest collected neighbour whenever a closer one appears.
    // NOTE(review): when k == p->ny and jy lies in p's range, only k - 1 slots get
    // filled above and KNN_max reads the unfilled tail of the caller's buffer —
    // callers appear to pass zeroed memory; verify.
    maxi = KNN_max (distances, k);
    while (py <= p -> ny) {
        if (py != jy) {
            double d = KNN_distanceEuclidean (j, p, fws, jy, py);
            if (d < distances [maxi]) {
                distances [maxi] = d;
                indices [maxi] = py;
                maxi = KNN_max (distances, k);
            }
        }
        ++ py;
    }

    integer ret = OlaMIN (k, dc);
    if (ret < 1) {
        // Nothing collected: leave a valid instance index in slot 0 for downstream
        // readers, and report zero neighbours.
        indices [0] = jy;
        return 0;
    }
    else
        return ret;
}
  882. /////////////////////////////////////////////////////////////////////////////////////////////
  883. // Locating k (nearest) friends //
  884. /////////////////////////////////////////////////////////////////////////////////////////////
  885. integer KNN_kFriends
  886. (
  887. ///////////////////////////////
  888. // Parameters //
  889. ///////////////////////////////
  890. PatternList j, // source pattern
  891. //
  892. PatternList p, // target pattern (where friends are sought for)
  893. //
  894. Categories c, // categories
  895. //
  896. integer jy, // the index of the source instance
  897. //
  898. integer k, // the number of sought after friends
  899. //
  900. integer * indices // a pointer to a memory-space big enough for k integers
  901. // representing indices to the k friends in the
  902. // target pattern
  903. )
  904. {
  905. integer maxi;
  906. integer dc = 0;
  907. integer py = 1;
  908. autoNUMvector <double> distances ((integer) 0, k - 1);
  909. Melder_assert (jy <= j -> ny && k <= p -> ny && k > 0);
  910. Melder_assert (indices);
  911. while (dc < k && py < p -> ny) {
  912. if (jy != py && FeatureWeights_areFriends (c->at [jy], c->at [py])) {
  913. distances[dc] = KNN_distanceManhattan (j, p, jy, py);
  914. indices[dc] = py;
  915. dc ++;
  916. }
  917. ++ py;
  918. }
  919. maxi = KNN_max (distances.peek(), k);
  920. while (py <= p -> ny) {
  921. if (jy != py && FeatureWeights_areFriends (c->at [jy], c->at [py])) {
  922. double d = KNN_distanceManhattan (j, p, jy, py);
  923. if (d < distances [maxi]) {
  924. distances [maxi] = d;
  925. indices [maxi] = py;
  926. maxi = KNN_max (distances.peek(), k);
  927. }
  928. }
  929. ++ py;
  930. }
  931. return OlaMIN (k, dc);
  932. }
/////////////////////////////////////////////////////////////////////////////////////////////
// Computing the distance to the nearest enemy                                             //
/////////////////////////////////////////////////////////////////////////////////////////////

double KNN_nearestEnemy
(
    ///////////////////////////////
    // Parameters                //
    ///////////////////////////////

    PatternList j,   // source-pattern
                     //
    PatternList p,   // target pattern (where the enemy is sought)
                     //
    Categories c,    // categories
                     //
    integer jy       // the index of the source instance
                     //
)
{
    // NOTE(review): despite the name, this function keeps the LARGEST Manhattan
    // distance among enemies (`newdist > distance`), and the running value is seeded
    // with the distance to instance 1 of p regardless of that instance's class.
    // Preserved as-is — verify against the callers' expectations before changing.
    double distance = KNN_distanceManhattan (j, p, jy, 1);
    Melder_assert (jy > 0 && jy <= j -> ny);
    for (integer y = 2; y <= p -> ny; y ++)
    {
        if (FeatureWeights_areEnemies (c -> at [jy], c -> at [y])) {
            double newdist = KNN_distanceManhattan (j, p, jy, y);
            if (newdist > distance)
                distance = newdist;
        }
    }
    return distance;
}
  963. /////////////////////////////////////////////////////////////////////////////////////////////
  964. // Computing the number of friends among k neighbours //
  965. /////////////////////////////////////////////////////////////////////////////////////////////
  966. integer KNN_friendsAmongkNeighbours
  967. (
  968. ///////////////////////////////
  969. // Parameters //
  970. ///////////////////////////////
  971. PatternList j, // source-pattern
  972. //
  973. PatternList p, // target pattern (where friends are sought for)
  974. //
  975. Categories c, // categories
  976. //
  977. integer jy, // the index of the source instance
  978. //
  979. integer k // k (!)
  980. //
  981. )
  982. {
  983. autoNUMvector <double> distances ((integer) 0, k - 1);
  984. autoNUMvector <integer> indices ((integer) 0, k - 1);
  985. integer friends = 0;
  986. Melder_assert (jy > 0 && jy <= j->ny && k <= p->ny && k > 0);
  987. autoFeatureWeights fws = FeatureWeights_create (p -> nx);
  988. integer ncollected = KNN_kNeighbours (j, p, fws.get(), jy, k, indices.peek(), distances.peek());
  989. while (ncollected --) {
  990. if (FeatureWeights_areFriends (c->at [jy], c->at [indices [ncollected]]))
  991. friends ++;
  992. }
  993. return friends;
  994. }
  995. /////////////////////////////////////////////////////////////////////////////////////////////
  996. // Locating k unique (nearest) enemies //
  997. /////////////////////////////////////////////////////////////////////////////////////////////
  998. integer KNN_kUniqueEnemies
  999. (
  1000. ///////////////////////////////
  1001. // Parameters //
  1002. ///////////////////////////////
  1003. PatternList j, // source-pattern
  1004. //
  1005. PatternList p, // target pattern (where friends are sought for)
  1006. //
  1007. Categories c, // categories
  1008. //
  1009. integer jy, // the index of the source instance
  1010. //
  1011. integer k, // k (!)
  1012. //
  1013. integer *indices // a memory space to hold the indices of the
  1014. // located enemies
  1015. //
  1016. )
  1017. {
  1018. integer maxi;
  1019. integer dc = 0;
  1020. integer py = 1;
  1021. double *distances = NUMvector <double> (0, k - 1);
  1022. Melder_assert (jy <= j->ny);
  1023. Melder_assert (k <= p->ny);
  1024. Melder_assert (k > 0);
  1025. Melder_assert (indices);
  1026. while (dc < k && py <= p -> ny) {
  1027. if (FeatureWeights_areEnemies (c->at [jy], c->at [py])) {
  1028. int hasfriend = 0;
  1029. for (integer sc = 0; sc < dc; ++sc) {
  1030. if (FeatureWeights_areFriends (c->at [py], c->at [indices [sc]]))
  1031. hasfriend = 1;
  1032. }
  1033. if (!hasfriend) {
  1034. distances[dc] = KNN_distanceManhattan(j, p, jy, py);
  1035. indices[dc] = py;
  1036. ++ dc;
  1037. }
  1038. }
  1039. ++ py;
  1040. }
  1041. maxi = KNN_max(distances, k);
  1042. while (py <= p->ny) {
  1043. if (FeatureWeights_areEnemies (c->at [jy], c->at [py])) {
  1044. int hasfriend = 0;
  1045. for (integer sc = 0; sc < dc; ++sc) {
  1046. if (FeatureWeights_areFriends (c->at [py], c->at [indices[sc]]))
  1047. hasfriend = 1;
  1048. }
  1049. if (! hasfriend) {
  1050. double d = KNN_distanceManhattan(j, p, jy, py);
  1051. if (d < distances[maxi] && FeatureWeights_areFriends (c->at [jy], c->at [py])) {
  1052. distances[maxi] = d;
  1053. indices[maxi] = py;
  1054. maxi = KNN_max(distances, k);
  1055. }
  1056. }
  1057. }
  1058. ++ py;
  1059. }
  1060. NUMvector_free (distances, 0);
  1061. return OlaMIN (k, dc);
  1062. }
  1063. /////////////////////////////////////////////////////////////////////////////////////////////
  1064. // Compute dissimilarity matrix //
  1065. /////////////////////////////////////////////////////////////////////////////////////////////
  1066. autoDissimilarity KNN_patternToDissimilarity
  1067. (
  1068. ///////////////////////////////
  1069. // Parameters //
  1070. ///////////////////////////////
  1071. PatternList p, // PatternList
  1072. //
  1073. FeatureWeights fws // Feature weights
  1074. //
  1075. )
  1076. {
  1077. autoDissimilarity output = Dissimilarity_create (p -> ny);
  1078. for (integer y = 1; y <= p -> ny; ++ y)
  1079. for (integer x = 1; x <= p -> ny; ++ x)
  1080. output -> data [y] [x] = KNN_distanceEuclidean (p, p, fws, y, x);
  1081. return output;
  1082. }
/////////////////////////////////////////////////////////////////////////////////////////////
// Compute frequencies                                                                     //
/////////////////////////////////////////////////////////////////////////////////////////////

// Collapses k neighbour indices into one entry per distinct category:
// freqindices[i] holds a representative instance index for category slot i,
// freqs[i] the number of neighbours of that category, and distances[i] a running
// mean of their distances. The `distances` array is compacted in place, which is
// safe because the write position (ncategories) never exceeds the read position (y).
// Returns the number of distinct categories found.
integer KNN_kIndicesToFrequenciesAndDistances
(
    ///////////////////////////////
    // Parameters                //
    ///////////////////////////////

    Categories c,          // Source categories
                           //
    integer k,             // k (!)
                           //
    integer * indices,     // In: indices
                           //
    double * distances,    // In: distances per index / Out: mean distance per category
                           //
    double * freqs,        // Out: frequencies (double, sic!)
                           //
    integer *freqindices   // Out: one representative index per category slot
)
{
    integer ncategories = 0;
    Melder_assert (k <= c -> size && k > 0);
    Melder_assert (distances && indices && freqs && freqindices);
    for (integer y = 0; y < k; ++ y)
    {
        // Does neighbour y belong to a category already collected?
        int hasfriend = 0;
        integer ifriend = 0;
        while (ifriend < ncategories)
        {
            if (FeatureWeights_areFriends (c -> at [indices [y]], c -> at [freqindices [ifriend]]))
            {
                hasfriend = 1;
                break;
            }
            ++ ifriend;
        }
        if (! hasfriend)
        {
            // New category: open a fresh slot.
            freqindices [ncategories] = indices [y];
            freqs [ncategories] = 1;
            distances [ncategories] = distances [y];
            ncategories ++;
        }
        else
        {
            // Known category: bump its count and fold the new distance into the mean.
            ++ freqs [ifriend];
            // NOTE(review): a per-category incremental mean would divide by that
            // category's own count (freqs [ifriend]), not by ncategories + 1 —
            // verify the intended statistic before changing.
            distances [ifriend] += (distances [y] - distances [ifriend]) / (ncategories + 1);
        }
    }
    return (ncategories);
}
  1135. /////////////////////////////////////////////////////////////////////////////////////////////
  1136. // Normalize array //
  1137. /////////////////////////////////////////////////////////////////////////////////////////////
  1138. void KNN_normalizeFloatArray
  1139. (
  1140. ///////////////////////////////
  1141. // Parameters //
  1142. ///////////////////////////////
  1143. double * array, // Array to be normalized
  1144. //
  1145. integer n // The number of elements
  1146. // in the array
  1147. )
  1148. {
  1149. integer c = 0;
  1150. longdouble sum = 0.0;
  1151. while (c < n)
  1152. sum += array [c ++]; // this sums over array [0 .. n-1]
  1153. while (c --)
  1154. array [c] /= sum; // this scales array [0 .. n-1]
  1155. }
/////////////////////////////////////////////////////////////////////////////////////////////
// Remove instance                                                                         //
/////////////////////////////////////////////////////////////////////////////////////////////

// Removes instance y from the classifier's instance base: the input pattern is
// rebuilt without row y and the matching output category is dropped.
void KNN_removeInstance
(
    ///////////////////////////////
    // Parameters                //
    ///////////////////////////////

    KNN me,     // Classifier
                //
    integer y   // Index of the instance to be purged
                //
)
{
    // Special case: removing the sole remaining instance empties the classifier.
    if (y == 1 && my nInstances == 1) {
        my nInstances = 0;
        my input.reset();
        my output.reset();
        return;
    }

    Melder_assert (y > 0 && y <= my nInstances);

    // Copy every row except y into a fresh, one-row-smaller pattern.
    autoPatternList newPattern = PatternList_create (my nInstances - 1, my input -> nx);
    integer yt = 1;   // write row in the new pattern
    for (integer cy = 1; cy <= my nInstances; cy ++) {
        if (cy != y) {
            for (integer cx = 1; cx <= my input -> nx; cx ++)
                newPattern -> z [yt] [cx] = my input -> z [cy] [cx];
            yt ++;
        }
    }
    my input = newPattern.move();
    my output -> removeItem (y);
    my nInstances --;
}
/////////////////////////////////////////////////////////////////////////////////////////////
// Shuffle instances                                                                       //
/////////////////////////////////////////////////////////////////////////////////////////////

// Randomly reorders the classifier's instance base by repeatedly drawing a uniformly
// random instance from the shrinking original base into a fresh pattern/category pair,
// then swapping the new pair in.
void KNN_shuffleInstances
(
    ///////////////////////////////
    // Parameters                //
    ///////////////////////////////

    KNN me   // Classifier whose instance
             // base is to be shuffled
)
{
    if (my nInstances < 2)
        return;   // it takes at least two to tango

    // NOTE(review): each draw calls KNN_removeInstance, which rebuilds the entire
    // remaining pattern, so this shuffle is O(n^2) in the number of instances —
    // presumably acceptable for the instance-base sizes involved; verify if used
    // on large bases.
    autoPatternList new_input = PatternList_create (my nInstances, my input -> nx);
    autoCategories new_output = Categories_create ();
    integer y = 1;   // write row in the shuffled pattern
    while (my nInstances)
    {
        integer pick = NUMrandomInteger (1, my nInstances);
        new_output -> addItem_move (Data_copy (my output -> at [pick]));
        for (integer x = 1; x <= my input -> nx; x ++)
            new_input -> z [y] [x] = my input -> z [pick] [x];
        KNN_removeInstance (me, pick);   // shrinks my nInstances by one
        y ++;
    }
    // Install the shuffled copies (the original base is empty at this point).
    my nInstances = new_output -> size;
    my input = std::move (new_input);
    my output = std::move (new_output);
}
  1220. /////////////////////////////////////////////////////////////////////////////////////////////
  1221. // KNN to Permutation (experimental) //
  1222. /////////////////////////////////////////////////////////////////////////////////////////////
  1223. autoPermutation KNN_SA_ToPermutation
  1224. (
  1225. ///////////////////////////////
  1226. // Parameters //
  1227. ///////////////////////////////
  1228. KNN me, // the classifier being used
  1229. //
  1230. integer tries, //
  1231. //
  1232. integer iterations, //
  1233. //
  1234. double step_size, //
  1235. //
  1236. double boltzmann_c, //
  1237. //
  1238. double temp_start, //
  1239. //
  1240. double damping_f, //
  1241. //
  1242. double temp_stop //
  1243. //
  1244. )
  1245. {
  1246. gsl_rng * r;
  1247. const gsl_rng_type * T;
  1248. KNN_SA_t * istruct = KNN_SA_t_create (my input.get());
  1249. autoPermutation result = Permutation_create (my nInstances);
  1250. gsl_siman_params_t params = { (int) tries, (int) iterations, step_size, boltzmann_c, temp_start, damping_f, temp_stop};
  1251. gsl_rng_env_setup();
  1252. T = gsl_rng_default;
  1253. r = gsl_rng_alloc(T);
  1254. gsl_siman_solve(r,
  1255. istruct,
  1256. KNN_SA_t_energy,
  1257. KNN_SA_t_step,
  1258. KNN_SA_t_metric,
  1259. nullptr, // KNN_SA_t_print
  1260. KNN_SA_t_copy,
  1261. KNN_SA_t_copy_construct,
  1262. KNN_SA_t_destroy,
  1263. 0,
  1264. params);
  1265. for (integer i = 1; i <= my nInstances; ++i)
  1266. result->p[i] = istruct->indices[i];
  1267. KNN_SA_t_destroy(istruct);
  1268. return result;
  1269. }
  1270. double KNN_SA_t_energy
  1271. (
  1272. ///////////////////////////////
  1273. // Parameters //
  1274. ///////////////////////////////
  1275. void * istruct
  1276. )
  1277. {
  1278. if(((KNN_SA_t *) istruct)->p->ny < 2)
  1279. return(0);
  1280. double eCost = 0;
  1281. for (integer i = 1; i <= ((KNN_SA_t *) istruct)->p->ny; ++i)
  1282. {
  1283. /* fast and sloppy version */
  1284. double jDist = 0;
  1285. double kDist = 0;
  1286. integer j = i - 1 > 0 ? i - 1 : ((KNN_SA_t *) istruct)->p->ny;
  1287. integer k = i + 1 <= ((KNN_SA_t *) istruct)->p->ny ? i + 1 : 1;
  1288. for (integer x = 1; x <= ((KNN_SA_t *) istruct)->p->nx; ++x)
  1289. {
  1290. jDist += OlaSQUARE(((KNN_SA_t *) istruct)->p->z[((KNN_SA_t *) istruct)->indices[i]][x] -
  1291. ((KNN_SA_t *) istruct)->p->z[((KNN_SA_t *) istruct)->indices[j]][x]);
  1292. kDist += OlaSQUARE(((KNN_SA_t *) istruct)->p->z[((KNN_SA_t *) istruct)->indices[i]][x] -
  1293. ((KNN_SA_t *) istruct)->p->z[((KNN_SA_t *) istruct)->indices[k]][x]);
  1294. }
  1295. eCost += ((sqrt(jDist) + sqrt(kDist)) / 2 - eCost) / i;
  1296. }
  1297. return(eCost);
  1298. }
  1299. double KNN_SA_t_metric
  1300. (
  1301. void * istruct1,
  1302. void * istruct2
  1303. )
  1304. {
  1305. double result = 0;
  1306. for (integer i = ((KNN_SA_t *) istruct1)->p->ny; i >= 1; --i)
  1307. if(((KNN_SA_t *) istruct1)->indices[i] != ((KNN_SA_t *) istruct2)->indices[i])
  1308. ++result;
  1309. return result;
  1310. }
  1311. void KNN_SA_t_print (void * istruct) {
  1312. Melder_casual (U"\n");
  1313. for (integer i = 1; i <= ((KNN_SA_t *) istruct) -> p -> ny; i ++)
  1314. Melder_casual (((KNN_SA_t *) istruct) -> indices [i]);
  1315. Melder_casual (U"\n");
  1316. }
  1317. void KNN_SA_t_step
  1318. (
  1319. const gsl_rng * r,
  1320. void * istruct,
  1321. double step_size
  1322. )
  1323. {
  1324. integer i1 = Melder_iround ((((KNN_SA_t *) istruct) -> p -> ny - 1) * gsl_rng_uniform (r)) + 1;
  1325. integer i2 = (i1 + Melder_iround (step_size * gsl_rng_uniform (r))) % ((KNN_SA_t *) istruct) -> p -> ny + 1;
  1326. if (i1 == i2)
  1327. return;
  1328. if (i1 > i2)
  1329. OlaSWAP (integer, i1, i2);
  1330. integer partitions[i2 - i1 + 1];
  1331. KNN_SA_partition(((KNN_SA_t *) istruct)->p, i1, i2, partitions);
  1332. for (integer r, l = 1, stop = i2 - i1 + 1; l < stop; l ++)
  1333. {
  1334. while (l < stop && partitions [l] == 1)
  1335. l ++;
  1336. r = l + 1;
  1337. while (r <= stop && partitions [r] == 2)
  1338. r ++;
  1339. if (r == stop)
  1340. break;
  1341. OlaSWAP (integer, ((KNN_SA_t *) istruct) -> indices [i1], ((KNN_SA_t *) istruct) -> indices [i2]);
  1342. }
  1343. }
  1344. void KNN_SA_t_copy
  1345. (
  1346. void * istruct_src,
  1347. void * istruct_dest
  1348. )
  1349. {
  1350. ((KNN_SA_t *) istruct_dest)->p = ((KNN_SA_t *) istruct_src)->p;
  1351. for (integer i = 1; i <= ((KNN_SA_t *) istruct_dest)->p->ny; ++i)
  1352. ((KNN_SA_t *) istruct_dest)->indices[i] = ((KNN_SA_t *) istruct_src)->indices[i];
  1353. }
  1354. void * KNN_SA_t_copy_construct
  1355. (
  1356. void * istruct
  1357. )
  1358. {
  1359. KNN_SA_t * result = (KNN_SA_t *) malloc(sizeof(KNN_SA_t));
  1360. result->p = ((KNN_SA_t *) istruct)->p;
  1361. result->indices = (integer *) malloc (sizeof(integer) * (result->p->ny + 1));
  1362. for (integer i = 1; i <= result->p->ny; ++i)
  1363. result->indices[i] = ((KNN_SA_t *) istruct)->indices[i];
  1364. return((void *) result);
  1365. }
  1366. KNN_SA_t * KNN_SA_t_create
  1367. (
  1368. PatternList p
  1369. )
  1370. {
  1371. KNN_SA_t * result = (KNN_SA_t *) malloc(sizeof(KNN_SA_t));
  1372. result->p = p;
  1373. result->indices = (integer *) malloc (sizeof (integer) * (p->ny + 1));
  1374. for (integer i = 1; i <= p->ny; ++i)
  1375. result->indices[i] = i;
  1376. return(result);
  1377. }
  1378. void KNN_SA_t_destroy
  1379. (
  1380. void * istruct
  1381. )
  1382. {
  1383. free(((KNN_SA_t *) istruct)->indices);
  1384. free((KNN_SA_t *) istruct);
  1385. }
  1386. void KNN_SA_partition
  1387. (
  1388. ///////////////////////////////
  1389. // Parameters //
  1390. ///////////////////////////////
  1391. PatternList p, //
  1392. //
  1393. integer i1, // i1 < i2
  1394. //
  1395. integer i2, //
  1396. //
  1397. integer * result // [0] not used
  1398. //
  1399. )
  1400. {
  1401. integer c1 = Melder_iround (NUMrandomUniform (i1, i2)); // BUG: probably incorrect (the edges have half-probability)
  1402. integer c2 = Melder_iround (NUMrandomUniform (i1, i2));
  1403. double *p1 = NUMvector <double> (1, p->nx);
  1404. double *p2 = NUMvector <double> (1, p->nx);
  1405. for (integer x = 1; x <= p->nx; ++x)
  1406. {
  1407. p1[x] = p->z[c1][x];
  1408. p2[x] = p->z[c2][x];
  1409. }
  1410. for (bool converging = true; converging; )
  1411. {
  1412. double d1, d2;
  1413. converging = false;
  1414. for (integer i = i1, j = 1; i <= i2; ++i)
  1415. {
  1416. d1 = 0.0;
  1417. d2 = 0.0;
  1418. for (integer x = 1; x <= p->nx; ++x)
  1419. {
  1420. d1 += OlaSQUARE(p->z[i][x] - p1[x]);
  1421. d2 += OlaSQUARE(p->z[i][x] - p2[x]);
  1422. }
  1423. d1 = sqrt (d1);
  1424. d2 = sqrt (d2);
  1425. if (d1 < d2)
  1426. {
  1427. if (result [j] != 1)
  1428. {
  1429. converging = true;
  1430. result [j] = 1;
  1431. }
  1432. }
  1433. else
  1434. {
  1435. if (result [j] != 2)
  1436. {
  1437. converging = true;
  1438. result [j] = 2;
  1439. }
  1440. }
  1441. j ++;
  1442. }
  1443. for (integer x = 1; x <= p -> nx; x ++)
  1444. {
  1445. p1 [x] = 0.0;
  1446. p2 [x] = 0.0;
  1447. }
  1448. for (integer i = i1, j = 1, j1 = 1, j2 = 1; i <= i2; i ++)
  1449. {
  1450. if (result [j] == 1)
  1451. {
  1452. for (integer x = 1; x <= p->nx; x ++)
  1453. p1[x] += (p->z[i][x] - p1[x]) / j1;
  1454. j1 ++;
  1455. }
  1456. else
  1457. {
  1458. for (integer x = 1; x <= p -> nx; x ++)
  1459. p2[x] += (p->z[i][x] - p2[x]) / j2;
  1460. j2 ++;
  1461. }
  1462. j ++;
  1463. }
  1464. }
  1465. NUMvector_free (p1, 1);
  1466. NUMvector_free (p2, 1);
  1467. }
  1468. /* End of file KNN.cpp */