;;; pruning-test.scm -- SRFI 64 unit tests for the tree-pruning procedures.
  1. (use-modules
  2. ;; SRFI 64 for unit testing facilities
  3. (srfi srfi-64)
  4. ;; SRFI 8 for `receive` form
  5. (srfi srfi-8)
  6. ;; utils - the code to be tested
  7. (decision-tree)
  8. ;; Utilities for testing
  9. (utils test)
  10. ;; Dependencies for testing the code to be tested
  11. (dataset)
  12. (data-point)
  13. (tree)
  14. (pruning))
  15. (define TEST-DATA
  16. (list #(2.771244718 1.784783929 0)
  17. #(1.728571309 1.169761413 0)
  18. #(3.678319846 2.81281357 0)
  19. #(3.961043357 2.61995032 0)
  20. #(2.999208922 2.209014212 0)
  21. #(7.497545867 3.162953546 1)
  22. #(9.00220326 3.339047188 1)
  23. #(7.444542326 0.476683375 1)
  24. #(10.12493903 3.234550982 1)
  25. #(6.642287351 3.319983761 1)))
  26. (define PRECISION (expt 10 -9))
  27. (test-begin "pruning-test")
  28. (test-group
  29. "count-leaves"
  30. (test-equal
  31. 2
  32. (count-leaves (make-node (list #(2.3 1.1 0)
  33. #(2.0 1.1 0)
  34. #(2.3 1.0 1)
  35. #(2.0 1.0 1)
  36. #(2.3 1.0 1)
  37. #(2.0 1.0 1)
  38. #(2.4 1.0 1))
  39. 1
  40. 1.1
  41. (make-leaf-node (list #(2.3 1.0 1)
  42. #(2.0 1.0 1)
  43. #(2.3 1.0 1)
  44. #(2.0 1.0 1)
  45. #(2.4 1.0 1)))
  46. (make-leaf-node (list #(2.3 1.1 0)
  47. #(2.0 1.1 0))))))
  48. (test-equal
  49. 3
  50. (count-leaves (make-node (list #(2.3 1.1 3.0 0)
  51. #(2.0 1.1 3.0 0)
  52. #(2.3 1.0 4.0 0)
  53. #(2.0 1.0 3.0 1)
  54. #(2.3 1.0 3.0 1)
  55. #(2.0 1.0 3.0 1)
  56. #(2.4 1.0 3.0 1))
  57. 1
  58. 1.1
  59. (make-node (list #(2.3 1.0 4.0 0)
  60. #(2.0 1.0 3.0 1)
  61. #(2.3 1.0 3.0 1)
  62. #(2.0 1.0 3.0 1)
  63. #(2.4 1.0 3.0 1))
  64. 2
  65. 4.0
  66. (make-leaf-node (list #(2.0 1.0 3.0 1)
  67. #(2.3 1.0 3.0 1)
  68. #(2.0 1.0 3.0 1)
  69. #(2.4 1.0 3.0 1)))
  70. (make-leaf-node (list #(2.3 1.0 4.0 0))))
  71. (make-leaf-node (list #(2.3 1.1 3.0 0)
  72. #(2.0 1.1 3.0 0)))))))
  73. (test-group
  74. "get-last-split-nodes"
  75. (test-equal
  76. (list (make-node (list #(2.3 1.0 4.0 0)
  77. #(2.0 1.0 3.0 1)
  78. #(2.3 1.0 3.0 1)
  79. #(2.0 1.0 3.0 1)
  80. #(2.4 1.0 3.0 1))
  81. 2
  82. 4.0
  83. (make-leaf-node (list #(2.0 1.0 3.0 1)
  84. #(2.3 1.0 3.0 1)
  85. #(2.0 1.0 3.0 1)
  86. #(2.4 1.0 3.0 1)))
  87. (make-leaf-node (list #(2.3 1.0 4.0 0)))))
  88. (get-last-split-nodes
  89. (make-node (list #(2.3 1.1 3.0 0)
  90. #(2.0 1.1 3.0 0)
  91. #(2.3 1.0 4.0 0)
  92. #(2.0 1.0 3.0 1)
  93. #(2.3 1.0 3.0 1)
  94. #(2.0 1.0 3.0 1)
  95. #(2.4 1.0 3.0 1))
  96. 1
  97. 1.1
  98. (make-node (list #(2.3 1.0 4.0 0)
  99. #(2.0 1.0 3.0 1)
  100. #(2.3 1.0 3.0 1)
  101. #(2.0 1.0 3.0 1)
  102. #(2.4 1.0 3.0 1))
  103. 2
  104. 4.0
  105. (make-leaf-node (list #(2.0 1.0 3.0 1)
  106. #(2.3 1.0 3.0 1)
  107. #(2.0 1.0 3.0 1)
  108. #(2.4 1.0 3.0 1)))
  109. (make-leaf-node (list #(2.3 1.0 4.0 0))))
  110. (make-leaf-node (list #(2.3 1.1 3.0 0)
  111. #(2.0 1.1 3.0 0)))))))
  112. (test-group
  113. "prune-node-from-tree"
  114. (let* ([tree (make-node (list #(1.0 1.0 0)
  115. #(1.2 1.0 0)
  116. #(1.1 1.0 0)
  117. #(1.4 1.0 0)
  118. #(1.2 1.0 0)
  119. #(1.2 1.0 0) ;
  120. #(2.3 1.1 0)
  121. #(2.0 1.1 0) ;;
  122. #(2.3 1.0 1)
  123. #(2.0 1.0 1)
  124. #(2.3 1.0 1)
  125. #(2.0 1.0 1)
  126. #(2.4 1.0 1))
  127. 0
  128. 2.0
  129. (make-leaf-node (list #(1.0 1.0 0)
  130. #(1.2 1.0 0)
  131. #(1.1 1.0 0)
  132. #(1.4 1.0 0)
  133. #(1.2 1.0 0)
  134. #(1.2 1.0 0)))
  135. (make-node (list #(2.3 1.1 0)
  136. #(2.0 1.1 0)
  137. #(2.3 1.0 1)
  138. #(2.0 1.0 1)
  139. #(2.3 1.0 1)
  140. #(2.0 1.0 1)
  141. #(2.4 1.0 1))
  142. 1
  143. 1.1
  144. (make-leaf-node (list #(2.3 1.0 1)
  145. #(2.0 1.0 1)
  146. #(2.3 1.0 1)
  147. #(2.0 1.0 1)
  148. #(2.4 1.0 1)))
  149. (make-leaf-node (list #(2.3 1.1 0)
  150. #(2.0 1.1 0)))))]
  151. [split-node (make-node (list #(2.3 1.1 0)
  152. #(2.0 1.1 0)
  153. #(2.3 1.0 1)
  154. #(2.0 1.0 1)
  155. #(2.3 1.0 1)
  156. #(2.0 1.0 1)
  157. #(2.4 1.0 1))
  158. 1
  159. 1.1
  160. (make-leaf-node (list #(2.3 1.0 1)
  161. #(2.0 1.0 1)
  162. #(2.3 1.0 1)
  163. #(2.0 1.0 1)
  164. #(2.4 1.0 1)))
  165. (make-leaf-node (list #(2.3 1.1 0)
  166. #(2.0 1.1 0))))]
  167. [tranformed-node (make-leaf-node-from-split-node split-node)])
  168. (test-equal
  169. (make-node (list #(1.0 1.0 0)
  170. #(1.2 1.0 0)
  171. #(1.1 1.0 0)
  172. #(1.4 1.0 0)
  173. #(1.2 1.0 0)
  174. #(1.2 1.0 0) ;
  175. #(2.3 1.1 0)
  176. #(2.0 1.1 0) ;;
  177. #(2.3 1.0 1)
  178. #(2.0 1.0 1)
  179. #(2.3 1.0 1)
  180. #(2.0 1.0 1)
  181. #(2.4 1.0 1))
  182. 0
  183. 2.0
  184. (make-leaf-node (list #(1.0 1.0 0)
  185. #(1.2 1.0 0)
  186. #(1.1 1.0 0)
  187. #(1.4 1.0 0)
  188. #(1.2 1.0 0)
  189. #(1.2 1.0 0)))
  190. tranformed-node)
  191. (prune-node-from-tree tree split-node))))
  192. (test-group
  193. "select-better-tree"
  194. (let ([tree (make-node (list #(1.0 0.10 0)
  195. #(1.1 0.11 0)
  196. #(1.2 0.12 0)
  197. #(1.3 0.13 0)
  198. #(1.4 0.14 0)
  199. #(1.5 0.15 0)
  200. #(1.6 0.16 0)
  201. #(1.7 0.17 0)
  202. #(1.8 0.18 0)
  203. #(1.9 0.19 0)
  204. #(2.0 0.20 0)
  205. #(2.1 0.21 0)
  206. #(2.2 0.22 0)
  207. #(2.3 0.23 0)
  208. #(2.4 0.24 0)
  209. #(2.5 0.25 0)
  210. #(2.6 0.26 0)
  211. #(2.7 0.27 0)
  212. #(3.0 0.10 0)
  213. #(3.0 0.20 1))
  214. 0
  215. 3.0
  216. (make-leaf-node (list #(1.0 0.10 0)
  217. #(1.1 0.11 0)
  218. #(1.2 0.12 0)
  219. #(1.3 0.13 0)
  220. #(1.4 0.14 0)
  221. #(1.5 0.15 0)
  222. #(1.6 0.16 0)
  223. #(1.7 0.17 0)
  224. #(1.8 0.18 0)
  225. #(1.9 0.19 0)
  226. #(2.0 0.20 0)
  227. #(2.1 0.21 0)
  228. #(2.2 0.22 0)
  229. #(2.3 0.23 0)
  230. #(2.4 0.24 0)
  231. #(2.5 0.25 0)
  232. #(2.6 0.26 0)
  233. #(2.7 0.27 0)))
  234. (make-node (list #(3.0 0.10 0) ; the node, which will be pruned away
  235. #(3.0 0.20 1))
  236. 1
  237. 0.2
  238. (make-leaf-node (list #(3.0 0.10 0)))
  239. (make-leaf-node (list #(3.0 0.20 1)))))]
  240. [pruned-tree (make-node (list #(1.0 0.10 0)
  241. #(1.1 0.11 0)
  242. #(1.2 0.12 0)
  243. #(1.3 0.13 0)
  244. #(1.4 0.14 0)
  245. #(1.5 0.15 0)
  246. #(1.6 0.16 0)
  247. #(1.7 0.17 0)
  248. #(1.8 0.18 0)
  249. #(1.9 0.19 0)
  250. #(2.0 0.20 0)
  251. #(2.1 0.21 0)
  252. #(2.2 0.22 0)
  253. #(2.3 0.23 0)
  254. #(2.4 0.24 0)
  255. #(2.5 0.25 0)
  256. #(2.6 0.26 0)
  257. #(2.7 0.27 0)
  258. #(3.0 0.10 0)
  259. #(3.0 0.20 1))
  260. 0
  261. 3.0
  262. (make-leaf-node (list #(1.0 0.10 0)
  263. #(1.1 0.11 0)
  264. #(1.2 0.12 0)
  265. #(1.3 0.13 0)
  266. #(1.4 0.14 0)
  267. #(1.5 0.15 0)
  268. #(1.6 0.16 0)
  269. #(1.7 0.17 0)
  270. #(1.8 0.18 0)
  271. #(1.9 0.19 0)
  272. #(2.0 0.20 0)
  273. #(2.1 0.21 0)
  274. #(2.2 0.22 0)
  275. #(2.3 0.23 0)
  276. #(2.4 0.24 0)
  277. #(2.5 0.25 0)
  278. #(2.6 0.26 0)
  279. #(2.7 0.27 0)))
  280. (make-leaf-node (list #(3.0 0.10 0) ; the pruned node
  281. #(3.0 0.20 1))))]
  282. ;; the pruning set is only +0.01 and +0.001 in all rows
  283. ;; so that the values do not pass decision boundaries
  284. [pruning-set (list #(1.01 0.101 0)
  285. #(1.11 0.111 0)
  286. #(1.21 0.121 0)
  287. #(1.31 0.131 0)
  288. #(1.41 0.141 0)
  289. #(1.51 0.151 0)
  290. #(1.61 0.161 0)
  291. #(1.71 0.171 0)
  292. #(1.81 0.181 0)
  293. #(1.91 0.191 0)
  294. #(2.01 0.201 0)
  295. #(2.11 0.211 0)
  296. #(2.21 0.221 0)
  297. #(2.31 0.231 0)
  298. #(2.41 0.241 0)
  299. #(2.51 0.251 0)
  300. #(2.61 0.261 0)
  301. #(2.71 0.271 0)
  302. #(3.01 0.101 0)
  303. #(3.01 0.201 1))]
  304. [feature-column-indices (list 0 1)]
  305. [label-column-index 2]
  306. ;; 6% classification error tolerance,
  307. ;; so that 1 of 20 data points misclassification does not matter
  308. [accuracy-tolerance 0.06])
  309. ;; since the 5% improvement are below the tolerance for lost accuracy
  310. ;; when pruning, the tree should indeed be pruned.
  311. (test-equal
  312. pruned-tree
  313. (select-better-tree tree
  314. pruned-tree
  315. pruning-set
  316. feature-column-indices
  317. label-column-index
  318. accuracy-tolerance))
  319. ;; now try with a lower tolerance, which should not allow pruning
  320. (test-equal
  321. tree
  322. (select-better-tree tree
  323. pruned-tree
  324. pruning-set
  325. feature-column-indices
  326. label-column-index
  327. 0.04))))
  328. (test-group
  329. "prune-with-pruning-set"
  330. (let ([tree (make-node (list #(1.0 0.10 0)
  331. #(1.1 0.11 0)
  332. #(1.2 0.12 0)
  333. #(1.3 0.13 0)
  334. #(1.4 0.14 0)
  335. #(1.5 0.15 0)
  336. #(1.6 0.16 0)
  337. #(1.7 0.17 0)
  338. #(1.8 0.18 0)
  339. #(1.9 0.19 0)
  340. #(2.0 0.10 1)
  341. #(2.1 0.11 1)
  342. #(2.2 0.12 1)
  343. #(2.3 0.13 1)
  344. #(2.4 0.14 1)
  345. #(2.5 0.15 1)
  346. #(2.6 0.16 1)
  347. #(2.7 0.17 1)
  348. #(3.0 0.00 1)
  349. #(3.0 0.01 0))
  350. 0
  351. 2.0
  352. (make-leaf-node (list #(1.0 0.10 0)
  353. #(1.1 0.11 0)
  354. #(1.2 0.12 0)
  355. #(1.3 0.13 0)
  356. #(1.4 0.14 0)
  357. #(1.5 0.15 0)
  358. #(1.6 0.16 0)
  359. #(1.7 0.17 0)
  360. #(1.8 0.18 0)
  361. #(1.9 0.19 0)))
  362. (make-node (list #(2.0 0.10 1)
  363. #(2.1 0.11 1)
  364. #(2.2 0.12 1)
  365. #(2.3 0.13 1)
  366. #(2.4 0.14 1)
  367. #(2.5 0.15 1)
  368. #(2.6 0.16 1)
  369. #(2.7 0.17 1)
  370. #(3.0 0.00 1)
  371. #(3.0 0.01 0))
  372. 1
  373. 0.1
  374. (make-node (list #(3.0 0.00 1)
  375. #(3.0 0.01 0))
  376. 1
  377. 0.01
  378. (make-leaf-node (list #(3.0 0.00 1)))
  379. (make-leaf-node (list #(3.0 0.01 0))))
  380. (make-leaf-node (list #(2.0 0.10 1)
  381. #(2.1 0.11 1)
  382. #(2.2 0.12 1)
  383. #(2.3 0.13 1)
  384. #(2.4 0.14 1)
  385. #(2.5 0.15 1)
  386. #(2.6 0.16 1)
  387. #(2.7 0.17 1)))))]
  388. [pruned-tree (make-node (list #(1.0 0.10 0)
  389. #(1.1 0.11 0)
  390. #(1.2 0.12 0)
  391. #(1.3 0.13 0)
  392. #(1.4 0.14 0)
  393. #(1.5 0.15 0)
  394. #(1.6 0.16 0)
  395. #(1.7 0.17 0)
  396. #(1.8 0.18 0)
  397. #(1.9 0.19 0)
  398. #(2.0 0.10 1)
  399. #(2.1 0.11 1)
  400. #(2.2 0.12 1)
  401. #(2.3 0.13 1)
  402. #(2.4 0.14 1)
  403. #(2.5 0.15 1)
  404. #(2.6 0.16 1)
  405. #(2.7 0.17 1)
  406. #(3.0 0.00 1)
  407. #(3.0 0.01 0))
  408. 0
  409. 2.0
  410. (make-leaf-node (list #(1.0 0.10 0)
  411. #(1.1 0.11 0)
  412. #(1.2 0.12 0)
  413. #(1.3 0.13 0)
  414. #(1.4 0.14 0)
  415. #(1.5 0.15 0)
  416. #(1.6 0.16 0)
  417. #(1.7 0.17 0)
  418. #(1.8 0.18 0)
  419. #(1.9 0.19 0)))
  420. (make-leaf-node (list #(2.0 0.10 1)
  421. #(2.1 0.11 1)
  422. #(2.2 0.12 1)
  423. #(2.3 0.13 1)
  424. #(2.4 0.14 1)
  425. #(2.5 0.15 1)
  426. #(2.6 0.16 1)
  427. #(2.7 0.17 1)
  428. #(3.0 0.00 1)
  429. #(3.0 0.01 0))))]
  430. [pruning-set (list #(1.01 0.101 0)
  431. #(1.11 0.111 0)
  432. #(1.21 0.121 0)
  433. #(1.31 0.131 0)
  434. #(1.41 0.141 0)
  435. #(1.51 0.151 0)
  436. #(1.61 0.161 0)
  437. #(1.71 0.171 0)
  438. #(1.81 0.181 0)
  439. #(1.91 0.191 0)
  440. #(2.01 0.101 1)
  441. #(2.11 0.111 1)
  442. #(2.21 0.121 1)
  443. #(2.31 0.131 1)
  444. #(2.41 0.141 1)
  445. #(2.51 0.151 1)
  446. #(2.61 0.161 1)
  447. #(2.71 0.171 1)
  448. #(3.01 0.001 1)
  449. #(3.01 0.011 0))])
  450. (test-equal
  451. pruned-tree
  452. (prune-with-pruning-set tree
  453. pruning-set
  454. (list 0 1)
  455. 2
  456. #:tolerance 0.06)))
  457. (let ([tree (make-node (list #(1.0 0.10 0)
  458. #(1.1 0.11 0)
  459. #(1.2 0.12 0)
  460. #(1.3 0.13 0)
  461. #(1.4 0.14 0)
  462. #(1.5 0.15 0)
  463. #(1.6 0.16 0)
  464. #(1.7 0.17 0)
  465. #(1.8 0.18 0)
  466. #(1.9 0.19 0)
  467. #(2.0 0.10 1)
  468. #(2.1 0.11 1)
  469. #(2.2 0.12 1)
  470. #(2.3 0.13 1)
  471. #(2.4 0.14 1)
  472. #(2.5 0.15 1)
  473. #(2.6 0.16 1)
  474. #(2.7 0.17 1)
  475. #(3.0 0.00 1)
  476. #(3.0 0.01 0))
  477. 0
  478. 2.0
  479. (make-leaf-node (list #(1.0 0.10 0)
  480. #(1.1 0.11 0)
  481. #(1.2 0.12 0)
  482. #(1.3 0.13 0)
  483. #(1.4 0.14 0)
  484. #(1.5 0.15 0)
  485. #(1.6 0.16 0)
  486. #(1.7 0.17 0)
  487. #(1.8 0.18 0)
  488. #(1.9 0.19 0)))
  489. (make-node (list #(2.0 0.10 1)
  490. #(2.1 0.11 1)
  491. #(2.2 0.12 1)
  492. #(2.3 0.13 1)
  493. #(2.4 0.14 1)
  494. #(2.5 0.15 1)
  495. #(2.6 0.16 1)
  496. #(2.7 0.17 1)
  497. #(3.0 0.00 1)
  498. #(3.0 0.01 0))
  499. 1
  500. 0.1
  501. (make-node (list #(3.0 0.00 1)
  502. #(3.0 0.01 0))
  503. 1
  504. 0.01
  505. (make-leaf-node (list #(3.0 0.00 1)))
  506. (make-leaf-node (list #(3.0 0.01 0))))
  507. (make-leaf-node (list #(2.0 0.10 1)
  508. #(2.1 0.11 1)
  509. #(2.2 0.12 1)
  510. #(2.3 0.13 1)
  511. #(2.4 0.14 1)
  512. #(2.5 0.15 1)
  513. #(2.6 0.16 1)
  514. #(2.7 0.17 1)))))]
  515. [pruning-set (list #(1.01 0.101 0)
  516. #(1.11 0.111 0)
  517. #(1.21 0.121 0)
  518. #(1.31 0.131 0)
  519. #(1.41 0.141 0)
  520. #(1.51 0.151 0)
  521. #(1.61 0.161 0)
  522. #(1.71 0.171 0)
  523. #(1.81 0.181 0)
  524. #(1.91 0.191 0)
  525. #(2.01 0.101 1)
  526. #(2.11 0.111 1)
  527. #(2.21 0.121 1)
  528. #(2.31 0.131 1)
  529. #(2.41 0.141 1)
  530. #(2.51 0.151 1)
  531. #(2.61 0.161 1)
  532. #(2.71 0.171 1)
  533. #(3.01 0.001 1)
  534. #(3.01 0.011 0))])
  535. (test-equal
  536. tree
  537. (prune-with-pruning-set tree
  538. pruning-set
  539. (list 0 1)
  540. 2
  541. #:tolerance 0.04))))
  542. ;; TODO: missing test: traverse-collect-last-split-nodes
  543. (test-end "pruning-test")