;;; prediction-test.scm
;;; SRFI 64 unit tests for the decision-tree prediction procedures
;;; (predict-at-leaf-node, dataset-majority-prediction, predict, predict-dataset).
  1. (use-modules
  2. ;; SRFI 64 for unit testing facilities
  3. (srfi srfi-64)
  4. ;; SRFI 8 for `receive` form
  5. (srfi srfi-8)
  6. ;; utils - the code to be tested
  7. (decision-tree)
  8. ;; Utilities for testing
  9. (utils test)
  10. ;; Dependencies for testing the code to be tested
  11. (dataset)
  12. (metrics)
  13. (pruning)
  14. (prediction)
  15. (data-point)
  16. (tree))
  17. (define TEST-DATA
  18. (list #(2.771244718 1.784783929 0)
  19. #(1.728571309 1.169761413 0)
  20. #(3.678319846 2.81281357 0)
  21. #(3.961043357 2.61995032 0)
  22. #(2.999208922 2.209014212 0)
  23. #(7.497545867 3.162953546 1)
  24. #(9.00220326 3.339047188 1)
  25. #(7.444542326 0.476683375 1)
  26. #(10.12493903 3.234550982 1)
  27. #(6.642287351 3.319983761 1)))
  28. (define PRECISION (expt 10 -9))
  29. (test-begin "prediction-test")
  30. (test-group
  31. "predict-at-leaf-node"
  32. (test-equal
  33. 1
  34. (predict-at-leaf-node (make-leaf-node (list #(1.0 2.0 0)
  35. #(3.0 4.0 0)
  36. #(5.0 6.0 1)
  37. #(7.0 8.0 1)
  38. #(9.0 0.0 1)))
  39. 2))
  40. (test-equal
  41. 0
  42. (predict-at-leaf-node (make-leaf-node (list #(1.0 2.0 0)
  43. #(3.0 4.0 0)
  44. #(5.0 6.0 0)
  45. #(7.0 8.0 1)
  46. #(9.0 0.0 1)))
  47. 2)))
  48. (test-group
  49. "dataset-majority-prediction"
  50. (test-equal
  51. 1
  52. (dataset-majority-prediction (list #(2.3 1.1 0)
  53. #(2.0 1.1 0)
  54. #(2.3 1.0 1)
  55. #(2.0 1.0 1)
  56. #(2.3 1.0 1)
  57. #(2.0 1.0 1)
  58. #(2.4 1.0 1))
  59. 2))
  60. (test-equal
  61. 0
  62. (dataset-majority-prediction (list #(2.3 1.1 0)
  63. #(2.0 1.1 0)
  64. #(2.3 1.0 0)
  65. #(2.0 1.0 0)
  66. #(2.3 1.0 1)
  67. #(2.0 1.0 1)
  68. #(2.4 1.0 1))
  69. 2))
  70. (test-equal
  71. 0
  72. (dataset-majority-prediction (list #(2.3 1.1 0)
  73. #(2.0 1.1 0)
  74. #(2.3 1.0 0)
  75. #(2.0 1.0 0)
  76. #(2.3 1.0 1)
  77. #(2.0 1.0 1)
  78. #(2.4 1.0 1)
  79. #(2.4 1.0 1))
  80. 2)))
;; Disabled test group.  The #; prefix is a datum comment: the reader skips
;; the entire (test-group ...) form below, so none of it is evaluated.
;; NOTE(review): here make-node receives plain row lists as its last two
;; arguments, whereas the "predict" tests pass make-leaf-node results; that
;; mismatch may be why the group was disabled.  Before re-enabling, confirm
;; the expected child type and that node-majority-prediction exists.
#;(test-group
   "node-majority-prediction"
   (test-equal
    1
    (node-majority-prediction
     (make-node (list #(2.3 1.1 0)
                      #(2.0 1.1 0)
                      #(2.3 1.0 1)
                      #(2.0 1.0 1)
                      #(2.3 1.0 1)
                      #(2.0 1.0 1)
                      #(2.4 1.0 1))
                1
                1.1
                (list #(2.3 1.0 1)
                      #(2.0 1.0 1)
                      #(2.3 1.0 1)
                      #(2.0 1.0 1)
                      #(2.4 1.0 1))
                (list #(2.3 1.1 0)
                      #(2.0 1.1 0)))
     2))
   (test-equal
    0
    (node-majority-prediction
     (make-node (list #(2.3 1.1 1)
                      #(2.0 1.1 1)
                      #(2.3 1.0 0)
                      #(2.0 1.0 0)
                      #(2.3 1.0 0)
                      #(2.0 1.0 0)
                      #(2.4 1.0 0))
                1
                1.1
                (list #(2.3 1.0 0)
                      #(2.0 1.0 0)
                      #(2.3 1.0 0)
                      #(2.0 1.0 0)
                      #(2.4 1.0 0))
                (list #(2.3 1.1 1)
                      #(2.0 1.1 1)))
     2)))
  123. (test-group
  124. "predict"
  125. (let ([tree (make-node (list #(1.0 1.0 0)
  126. #(1.2 1.0 0)
  127. #(1.1 1.0 0)
  128. #(1.4 1.0 0)
  129. #(1.2 1.0 0)
  130. #(1.2 1.0 0) ;
  131. #(2.3 1.1 0)
  132. #(2.0 1.1 0) ;;
  133. #(2.3 1.0 1)
  134. #(2.0 1.0 1)
  135. #(2.3 1.0 1)
  136. #(2.0 1.0 1)
  137. #(2.4 1.0 1))
  138. 0 ;; split index
  139. 2.0 ;; split value
  140. (make-leaf-node (list #(1.0 1.0 0)
  141. #(1.2 1.0 0)
  142. #(1.1 1.0 0)
  143. #(1.4 1.0 0)
  144. #(1.2 1.0 0)
  145. #(1.2 1.0 0)))
  146. (make-node (list #(2.3 1.1 0)
  147. #(2.0 1.1 0)
  148. #(2.3 1.0 1)
  149. #(2.0 1.0 1)
  150. #(2.3 1.0 1)
  151. #(2.0 1.0 1)
  152. #(2.4 1.0 1))
  153. 1
  154. 1.1
  155. (make-leaf-node (list #(2.3 1.0 1)
  156. #(2.0 1.0 1)
  157. #(2.3 1.0 1)
  158. #(2.0 1.0 1)
  159. #(2.4 1.0 1)))
  160. (make-leaf-node (list #(2.3 1.1 0)
  161. #(2.0 1.1 0)))))])
  162. (test-equal (predict tree #(2.3 1.1 0) 2) 0)
  163. (test-equal (predict tree #(2.3 1.0 0) 2) 1)))
  164. (test-group
  165. "predict-dataset"
  166. (let ([tree (make-node (list #(1.0 1.0 0)
  167. #(1.2 1.0 0)
  168. #(1.1 1.0 0)
  169. #(1.4 1.0 0)
  170. #(1.2 1.0 0)
  171. #(1.2 1.0 0) ;
  172. #(2.3 1.1 0)
  173. #(2.0 1.1 0) ;;
  174. #(2.3 1.0 1)
  175. #(2.0 1.0 1)
  176. #(2.3 1.0 1)
  177. #(2.0 1.0 1)
  178. #(2.4 1.0 1))
  179. 0
  180. 2.0
  181. (make-leaf-node (list #(1.0 1.0 0)
  182. #(1.2 1.0 0)
  183. #(1.1 1.0 0)
  184. #(1.4 1.0 0)
  185. #(1.2 1.0 0)
  186. #(1.2 1.0 0)))
  187. (make-node (list #(2.3 1.1 0)
  188. #(2.0 1.1 0)
  189. #(2.3 1.0 1)
  190. #(2.0 1.0 1)
  191. #(2.3 1.0 1)
  192. #(2.0 1.0 1)
  193. #(2.4 1.0 1))
  194. 1
  195. 1.1
  196. (make-leaf-node (list #(2.3 1.0 1)
  197. #(2.0 1.0 1)
  198. #(2.3 1.0 1)
  199. #(2.0 1.0 1)
  200. #(2.4 1.0 1)))
  201. (make-leaf-node (list #(2.3 1.1 0)
  202. #(2.0 1.1 0)))))])
  203. (test-equal
  204. (list 0 1)
  205. (predict-dataset tree (list #(2.3 1.1 0)
  206. #(2.3 1.0 0))
  207. 2))))
  208. (test-group
  209. "predict-dataset"
  210. (let ([tree (make-node (list #(2.3 1.1 0)
  211. #(2.0 1.1 0)
  212. #(2.3 1.0 1)
  213. #(2.0 1.0 1)
  214. #(2.3 1.0 1)
  215. #(2.0 1.0 1)
  216. #(2.4 1.0 1))
  217. 1
  218. 1.1
  219. (make-leaf-node (list #(2.3 1.0 1)
  220. #(2.0 1.0 1)
  221. #(2.3 1.0 1)
  222. #(2.0 1.0 1)
  223. #(2.4 1.0 1)))
  224. (make-leaf-node (list #(2.3 1.1 0)
  225. #(2.0 1.1 0))))])
  226. (test-equal "prediction for a tiny dataset"
  227. (list 0 1 0 1)
  228. (predict-dataset tree
  229. (list #(2.4 1.2)
  230. #(1.9 0.9)
  231. #(3.0 3.0)
  232. #(0.0 0.5))
  233. 2))))
  234. (test-end "prediction-test")