fern.h 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  /* Code for making/predicting by single fern
     Copyright 2011-2018 Miron B. Kursa
     This file is part of rFerns R package.
     rFerns is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
     rFerns is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
     You should have received a copy of the GNU General Public License along with rFerns. If not, see http://www.gnu.org/licenses/.
  */
  8. void makeFern(DATASET_,FERN_,uint *restrict bag,score_t *restrict oobPrMatrix,uint *restrict idx,SIMP_,R_){
  9. for(uint e=0;e<N;e++) idx[e]=0;
  10. for(uint e=0;e<D;e++){
  11. //Select an attribute to make a split on
  12. uint E=splitAtts[e]=RINDEX(nX);
  13. switch(X[E].numCat){
  14. case 0:{
  15. //Make numerical split
  16. double *restrict x=(double*)(X[E].x);
  17. double threshold=.5*(x[RINDEX(N)]+x[RINDEX(N)]);
  18. for(uint ee=0;ee<N;ee++)
  19. idx[ee]+=(1<<e)*(x[ee]<threshold);
  20. thresholds[e].value=threshold;
  21. break;
  22. }
  23. case -1:{
  24. //Make integer split
  25. sint *restrict x=(sint*)(X[E].x);
  26. sint threshold=x[RINDEX(N)];
  27. for(uint ee=0;ee<N;ee++)
  28. idx[ee]+=(1<<e)*(x[ee]<threshold);
  29. thresholds[e].intValue=threshold;
  30. break;
  31. }
  32. default:{
  33. //Make categorical split
  34. uint *restrict x=(uint*)(X[E].x);
  35. mask mask=RMASK(X[E].numCat);
  36. for(uint ee=0;ee<N;ee++)
  37. idx[ee]+=(1<<e)*((mask&(1<<(x[ee])))>0);
  38. thresholds[e].selection=mask;
  39. }
  40. }
  41. }
  42. //Calculate scores
  43. uint objInLeafPerClass[twoToD*numC]; //Counts of classes in a each leaf
  44. uint objInLeaf[twoToD]; //Counts of objects in each leaf
  45. uint objInBagPerClass[numC]; //Counts of classes in a bag
  46. for(uint e=0;e<numC;e++)
  47. objInBagPerClass[e]=0;
  48. for(uint e=0;e<twoToD*numC;e++)
  49. objInLeafPerClass[e]=0;
  50. for(uint e=0;e<twoToD;e++)
  51. objInLeaf[e]=0;
  52. if(!multi){
  53. //=Many-class case=
  54. //Count
  55. for(uint e=0;e<N;e++){
  56. objInLeafPerClass[Y[e]+idx[e]*numC]+=bag[e];
  57. objInLeaf[idx[e]]+=bag[e];
  58. objInBagPerClass[Y[e]]+=bag[e];
  59. }
  60. //Calculate the scores
  61. for(uint e=0;e<twoToD;e++)
  62. for(uint ee=0;ee<numC;ee++)
  63. scores[ee+e*numC]=log(
  64. ((double)objInLeafPerClass[ee+e*numC]+1)/((double)objInLeaf[e]+numC)
  65. *
  66. ((double)N+numC)/((double)objInBagPerClass[ee]+1)
  67. );
  68. //Fill the OOB scores
  69. for(uint e=0;e<N;e++)
  70. for(uint ee=0;ee<numC;ee++)
  71. oobPrMatrix[e*numC+ee]=scores[idx[e]*numC+ee];
  72. }else{
  73. //=Multi-class case=
  74. //Count
  75. for(uint e=0;e<N;e++){
  76. objInLeaf[idx[e]]+=bag[e];
  77. for(uint ee=0;ee<numC;ee++){
  78. uint toCount=bag[e]*Y[ee*N+e];
  79. objInLeafPerClass[ee+idx[e]*numC]+=toCount;
  80. objInBagPerClass[ee]+=toCount;
  81. }
  82. }
  83. //Calculate the quotient scores
  84. for(uint e=0;e<twoToD;e++)
  85. for(uint ee=0;ee<numC;ee++){
  86. scores[ee+e*numC]=log(
  87. ((double)objInLeafPerClass[ee+e*numC]+1)/((double)objInLeaf[e]-objInLeafPerClass[ee+e*numC]+1)
  88. *
  89. ((double)N-objInBagPerClass[ee]+1)/((double)objInBagPerClass[ee]+1)
  90. );
  91. }
  92. //Fill the OOB quotient scores
  93. for(uint e=0;e<N;e++)
  94. for(uint ee=0;ee<numC;ee++)
  95. oobPrMatrix[e*numC+ee]=scores[idx[e]*numC+ee];
  96. }
  97. }
  98. void predictFernAdd(PREDSET_,FERN_,double *restrict ans,uint *restrict idx,SIMP_){
  99. for(uint e=0;e<N;e++) idx[e]=0;
  100. //ans is a matrix of N columns of length numC
  101. for(uint e=0;e<D;e++){
  102. uint E=splitAtts[e];
  103. switch(X[E].numCat){
  104. case 0:{
  105. //Make numerical split
  106. double *restrict x=(double*)(X[E].x);
  107. double threshold=thresholds[e].value;
  108. for(uint ee=0;ee<N;ee++)
  109. idx[ee]+=(1<<e)*(x[ee]<threshold);
  110. break;
  111. }
  112. case -1:{
  113. //Make integer split
  114. sint *restrict x=(sint*)(X[E].x);
  115. sint threshold=thresholds[e].intValue;
  116. for(uint ee=0;ee<N;ee++)
  117. idx[ee]+=(1<<e)*(x[ee]<threshold);
  118. break;
  119. }
  120. default:{
  121. //Make categorical split
  122. uint *restrict x=(uint*)(X[E].x);
  123. mask mask=thresholds[e].selection;
  124. for(uint ee=0;ee<N;ee++)
  125. idx[ee]+=(1<<e)*((mask&(1<<(x[ee])))>0);
  126. }
  127. }
  128. }
  129. //Fill ans with actual predictions
  130. for(uint e=0;e<N;e++)
  131. for(uint ee=0;ee<numC;ee++)
  132. ans[e*numC+ee]+=scores[idx[e]*numC+ee];
  133. }
  134. accLoss calcAccLossConsistent(DATASET_,uint E,FERN_,uint *bag,uint *idx,score_t *curPreds,uint numC,uint D,R_,uint consSeed,uint *idxP,uint *idxPP){
  135. //Generate idxP. To this end, implicitly generate a permuted version of the attribute E and build split on it; then
  136. //replace this split within idx to make a copy of the fern as if it was grown on a permuted E.
  137. //We also make idxPP as idxP in a plain importance calculation.
  138. //...yet RINDEX is consistent, i.e. returns the same permutation for the same E; threshold is not
  139. rng_t _rng2,*rng2=&_rng2;
  140. rng_t *rngO=rng;
  141. for(uint e=0;e<N;e++) idxPP[e]=(idxP[e]=idx[e]);
  142. for(uint e=0;e<D;e++) if(splitAtts[e]==E){
  143. //Re-seed; different order than in makeModel for fern is intentional
  144. SETSEEDEX(rng2,consSeed,E+1);
  145. //Back to business
  146. switch(X[E].numCat){
  147. case 0:{
  148. //Numerical split
  149. double *x=(double*)(X[E].x);
  150. double threshold=thresholds[e].value;
  151. for(uint ee=0;ee<N;ee++){
  152. rng=rng2;
  153. idxP[ee]=SET_BIT(idxP[ee],e,x[RINDEX(N)]<threshold);
  154. rng=rngO;
  155. idxPP[ee]=SET_BIT(idxPP[ee],e,x[RINDEX(N)]<threshold);
  156. }
  157. rng=rngO;
  158. break;
  159. }
  160. case -1:{
  161. //Integer split
  162. sint *x=(sint*)(X[E].x);
  163. sint threshold=thresholds[e].intValue;
  164. for(uint ee=0;ee<N;ee++){
  165. rng=rng2;
  166. idxP[ee]=SET_BIT(idxP[ee],e,x[RINDEX(N)]<threshold);
  167. rng=rngO;
  168. idxPP[ee]=SET_BIT(idxPP[ee],e,x[RINDEX(N)]<threshold);
  169. }
  170. rng=rngO;
  171. break;
  172. }
  173. default:{
  174. //Categorical split
  175. uint *x=(uint*)(X[E].x);
  176. mask mask=thresholds[e].selection;
  177. for(uint ee=0;ee<N;ee++){
  178. rng=rng2;
  179. idxP[ee]=SET_BIT(idxP[ee],e,GET_BIT(mask,x[RINDEX(N)]));
  180. rng=rngO;
  181. idxPP[ee]=SET_BIT(idxPP[ee],e,GET_BIT(mask,x[RINDEX(N)]));
  182. }
  183. rng=rngO;
  184. }
  185. }
  186. }
  187. rng=rngO;
  188. //Calculate leaves for this permuted fern; first init, ...
  189. uint twoToD=1<<(D);
  190. uint objInLeafPerClassP[twoToD*numC]; //Counts of classes in a each leaf
  191. uint objInLeafP[twoToD]; //Counts of objects in each leaf
  192. uint objInBagPerClassP[numC]; //Counts of classes in a bag
  193. for(uint e=0;e<numC;e++) objInBagPerClassP[e]=0;
  194. for(uint e=0;e<twoToD*numC;e++) objInLeafPerClassP[e]=0;
  195. for(uint e=0;e<twoToD;e++) objInLeafP[e]=0;
  196. //...then fill.
  197. for(uint e=0;e<N;e++){
  198. objInLeafPerClassP[Y[e]+idxP[e]*numC]+=bag[e];
  199. objInLeafP[idxP[e]]+=bag[e];
  200. objInBagPerClassP[Y[e]]+=bag[e];
  201. }
  202. //Combine into importance scores
  203. uint objInBag=0;
  204. double sumScoreOrig=0.;
  205. double sumScoreMixed=0.;
  206. double sumPermScore=0.;
  207. double sumPermScoreMixed=0.;
  208. for(uint e=0;e<N;e++){
  209. //Finish the score on the good class from
  210. double scoreTrueClassOrig=scores[idx[e]*numC+Y[e]];
  211. double scoreTrueClassMixed=scores[idxPP[e]*numC+Y[e]];
  212. double permScoreTrueClass=log(
  213. ((double)objInLeafPerClassP[Y[e]+idxP[e]*numC]+1)/((double)objInLeafP[idxP[e]]+numC)
  214. *
  215. ((double)N+numC)/((double)objInBagPerClassP[Y[e]]+1)
  216. );
  217. double permScoreTrueClassMixed=log(
  218. ((double)objInLeafPerClassP[Y[e]+idxPP[e]*numC]+1)/((double)objInLeafP[idxPP[e]]+numC)
  219. *
  220. ((double)N+numC)/((double)objInBagPerClassP[Y[e]]+1)
  221. );
  222. sumScoreOrig+=(!(bag[e]))*scoreTrueClassOrig;
  223. sumScoreMixed+=(!(bag[e]))*scoreTrueClassMixed;
  224. sumPermScore+=(!(bag[e]))*permScoreTrueClass;
  225. sumPermScoreMixed+=(!(bag[e]))*permScoreTrueClassMixed;
  226. objInBag+=!(bag[e]);
  227. }
  228. accLoss ans;
  229. ans.direct=(sumScoreOrig-sumScoreMixed)/((double)objInBag); //The same as in regular importance
  230. ans.shadow=(sumPermScore-sumPermScoreMixed)/((double)objInBag);
  231. return(ans);
  232. }
  233. accLoss calcAccLoss(DATASET_,uint E,FERN_,uint *bag,uint *idx,score_t *curPreds,uint numC,uint D,R_,uint *idxPerm){
  234. //Generate idxPerm, idx for permuted values of an attribute E
  235. for(uint e=0;e<N;e++) idxPerm[e]=idx[e];
  236. for(uint e=0;e<D;e++) if(splitAtts[e]==E){
  237. switch(X[E].numCat){
  238. case 0:{
  239. //Numerical split
  240. double *x=(double*)(X[E].x);
  241. double threshold=thresholds[e].value;
  242. for(uint ee=0;ee<N;ee++)
  243. idxPerm[ee]=SET_BIT(idxPerm[ee],e,x[RINDEX(N)]<threshold);
  244. break;
  245. }
  246. case -1:{
  247. //Integer split
  248. sint *x=(sint*)(X[E].x);
  249. sint threshold=thresholds[e].intValue;
  250. for(uint ee=0;ee<N;ee++)
  251. idxPerm[ee]=SET_BIT(idxPerm[ee],e,x[RINDEX(N)]<threshold);
  252. break;
  253. }
  254. default:{
  255. //Categorical split
  256. uint *x=(uint*)(X[E].x);
  257. mask mask=thresholds[e].selection;
  258. for(uint ee=0;ee<N;ee++)
  259. idxPerm[ee]=SET_BIT(idxPerm[ee],e,GET_BIT(mask,x[RINDEX(N)]));
  260. }
  261. }
  262. }
  263. uint objInBag=0;
  264. double wrongDiff=0;
  265. for(uint e=0;e<N;e++){
  266. wrongDiff+=
  267. (!(bag[e]))*
  268. (scores[idx[e]*numC+Y[e]]-scores[idxPerm[e]*numC+Y[e]]);
  269. objInBag+=!(bag[e]);
  270. }
  271. accLoss ans;
  272. ans.direct=wrongDiff/((double)objInBag);
  273. return(ans);
  274. }