decision-tree.rkt 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. #lang racket
  2. (require "csv-to-list.rkt"
  3. "utils.rkt"
  4. "data-representation-abstraction.rkt")
  5. (provide (all-defined-out))
;; Path of the banknote authentication data set (CSV with 5 columns:
;; four numeric features followed by an integer class label).
(define FILE-PATH "data_banknote_authentication.csv")
;; Per-column CSV converters: the four features become numbers; the class
;; label in the last column is additionally made exact (an integer).
(define COLUMN-CONVERTERS (list string->number
                                string->number
                                string->number
                                string->number
                                (lambda (a-class) (inexact->exact (string->number a-class)))))
;; The full data set, loaded as rows via csv-to-list.rkt.
(define data-set (all-rows FILE-PATH #:column-converters COLUMN-CONVERTERS))
;; Small hand-written development data set: each row vector holds two
;; numeric features and a class label (0 or 1) in the last position.
(define dev-data-set (list #(2.771244718 1.784783929 0)
                           #(1.728571309 1.169761413 0)
                           #(3.678319846 2.81281357 0)
                           #(3.961043357 2.61995032 0)
                           #(2.999208922 2.209014212 0)
                           #(7.497545867 3.162953546 1)
                           #(9.00220326 3.339047188 1)
                           #(7.444542326 0.476683375 1)
                           #(10.12493903 3.234550982 1)
                           #(6.642287351 3.319983761 1)))
  23. ;; ===============
  24. ;; DATA STRUCTURES
  25. ;; ===============
;; Result of evaluating one candidate split:
;;   index   - column index of the feature split on
;;   value   - threshold; rows with feature < value go into the first subset
;;   subsets - list of the two data subsets produced by the split
;;   cost    - gini index of the split (lower is better)
(struct Split (index value subsets cost)
  #:transparent)
;; One node of a learned decision tree:
;;   data                - data points that reached this node during fitting
;;   split-feature-index - column split on, or 'none for a leaf
;;   split-value         - threshold split on, or 'none for a leaf
;;   left / right        - child Nodes, or empty for a leaf
(struct Node (data split-feature-index split-value left right)
  #:transparent)
;; Creates a leaf node holding data: no split information and no children.
(define (make-leaf-node data)
  (Node data
        'none
        'none
        empty
        empty))
;; Collapses a split node into a leaf: keeps its data but drops the split
;; feature/value and both children. Used by pruning.
(define (make-leaf-node-from-split-node split-node)
  (struct-copy Node split-node
               [split-feature-index 'none]
               [split-value 'none]
               [left empty]
               [right empty]))
  42. (define (leaf-node? node)
  43. (and (data-empty? (Node-left node))
  44. (data-empty? (Node-right node))))
  45. (define (last-split-node? node)
  46. (cond [(leaf-node? node) false]
  47. [else
  48. (and (leaf-node? (Node-left node))
  49. (leaf-node? (Node-right node)))]))
  50. (define (node-majority-prediction node label-column-index)
  51. (data-majority-prediction (Node-data node) label-column-index))
  52. ;; =======================
  53. ;; DECISION TREE ALGORITHM
  54. ;; =======================
  55. (define (calc-proportion subset class-label label-column-index)
  56. (define (get-class-counter a-class-label)
  57. (lambda (row)
  58. (= a-class-label
  59. (data-point-get-col row label-column-index))))
  60. (cond [(data-empty? subset) 0]
  61. [else (let* ([row-count (data-length subset)]
  62. [class-count (count (get-class-counter class-label) subset)]
  63. [prop (/ class-count row-count)])
  64. (* prop (- 1.0 prop)))]))
  65. #|
  66. The procedure gini-index is used to evaluate the quality of a split.
  67. It is a cost function for a split.
  68. We want to keep the costs for splits low. (also: greedy)
  69. There are other ways of calculating the quality of a split, but for now we
  70. implement gini index.
  71. |#
  72. (define (gini-index subsets label-column-index)
  73. (for/sum ([subset (in-list subsets)])
  74. (for/sum ([label (in-list (list 0 1))])
  75. (calc-proportion subset
  76. label
  77. label-column-index))))
  78. (define (split-data data index value)
  79. (let-values ([(part1 part2)
  80. (data-partition (lambda (data-point)
  81. (< (data-point-get-col data-point index) value))
  82. data)])
  83. (list part1 part2)))
  84. (define (get-best-split data feature-column-indices label-column-index)
  85. (define-values (col-index value subsets cost)
  86. (for*/fold ([previous-best-index +inf.0]
  87. [previous-best-value +inf.0]
  88. [previous-best-subsets empty]
  89. [previous-best-cost +inf.0]) ; initial values / previous values
  90. ([col-index (in-range (sub1 (vector-length (data-first data))))]
  91. [value (in-list (data-get-col data col-index))])
  92. (let* ([current-value value]
  93. [current-index col-index]
  94. [current-subsets (split-data data col-index current-value)]
  95. [current-cost (gini-index current-subsets label-column-index)])
  96. (if (< current-cost previous-best-cost)
  97. (values current-index
  98. current-value
  99. current-subsets
  100. current-cost)
  101. (values previous-best-index
  102. previous-best-value
  103. previous-best-subsets
  104. previous-best-cost)))))
  105. (Split col-index value subsets cost))
  106. #|
  107. PREDICTING:
  108. - leaf node of the tree, majority class as prediction
  109. |#
  110. (define (predict-at-leaf-node leaf label-column-index)
  111. (node-majority-prediction leaf label-column-index))
  112. (define (fit #:train-data data
  113. #:feature-column-indices feature-column-indices
  114. #:label-column-index label-column-index
  115. #:max-depth [max-depth 6]
  116. #:min-data-points [min-data-points 12]
  117. #:min-data-points-ratio [min-data-points-ratio 0.02]
  118. #:min-impurity-split [min-impurity-split (expt 10 -7)]
  119. #:stop-at-no-impurity-improvement [stop-at-no-impurity-improvement true])
  120. (define all-data-length (data-length data))
  121. (define current-depth 1)
  122. #|
  123. STOP CRITERIA:
  124. - only one class in a subset (cannot be split any further and does not need to be split)
  125. - maximum tree depth reached
  126. - minimum number of data points in a subset
  127. - minimum ratio of data points in this subset
  128. |#
  129. (define (all-same-label? subset)
  130. (labels-elements-equal? (data-get-col subset label-column-index)))
  131. (define (insufficient-data-points-for-split? subset)
  132. (let ([number-of-data-points (data-length subset)])
  133. (or (<= number-of-data-points min-data-points)
  134. (< number-of-data-points 2))))
  135. (define (max-depth-reached? current-depth)
  136. (>= current-depth max-depth))
  137. (define (insufficient-data-points-ratio-for-split? subset)
  138. (<= (/ (data-length subset) all-data-length) min-data-points-ratio))
  139. (define (no-improvement? previous-split-impurity split-impurity)
  140. (and (<= previous-split-impurity split-impurity)
  141. stop-at-no-impurity-improvement))
  142. (define (insufficient-impurity? impurity)
  143. (< impurity min-impurity-split))
  144. #|
  145. Here we do the recursive splitting.
  146. |#
  147. (define (recursive-split subset current-depth previous-split-impurity)
  148. (display "recursive split on depth: ") (displayln current-depth)
  149. #|
  150. Before splitting further, we check for stopping early conditions.
  151. |#
  152. (cond
  153. [(max-depth-reached? current-depth)
  154. (displayln "STOPPING CONDITION: maximum depth")
  155. (displayln (string-append "INFO: still got "
  156. (number->string (data-length subset))
  157. " data points"))
  158. (make-leaf-node subset)]
  159. [(insufficient-data-points-for-split? subset)
  160. (displayln "STOPPING CONDITION: insuficient number of data points")
  161. (displayln (string-append "INFO: still got "
  162. (number->string (data-length subset))
  163. " data points"))
  164. (make-leaf-node subset)]
  165. [(insufficient-data-points-ratio-for-split? subset)
  166. (displayln "STOPPING CONDITION: insuficient ratio of data points")
  167. (displayln (string-append "INFO: still got "
  168. (number->string (data-length subset))
  169. " data points"))
  170. (make-leaf-node subset)]
  171. [(all-same-label? subset)
  172. (displayln "STOPPING CONDITION: all same label")
  173. (displayln (string-append "INFO: still got "
  174. (number->string (data-length subset))
  175. " data points"))
  176. (make-leaf-node subset)]
  177. [else
  178. (displayln (string-append "INFO: CONTINUING SPLITT: still got "
  179. (number->string (data-length subset))
  180. " data points"))
  181. ;; (display "input data for searching best split:") (displayln subset)
  182. (let* ([best-split (get-best-split subset
  183. feature-column-indices
  184. label-column-index)])
  185. (cond
  186. [(no-improvement? previous-split-impurity (Split-cost best-split))
  187. (displayln (string-append "STOPPING CONDITION: "
  188. "no improvement in impurity: previously: "
  189. (number->string previous-split-impurity) " "
  190. "now: "
  191. (number->string (Split-cost best-split))))
  192. (make-leaf-node subset)]
  193. [(insufficient-impurity? previous-split-impurity)
  194. (displayln "STOPPING CONDITION: not enough impurity for splitting further")
  195. (make-leaf-node subset)]
  196. [else
  197. #|
  198. Here are the recursive calls.
  199. This is not tail recursive, but since the data structure itself is recursive
  200. and we only have as many procedure calls as there are branches in the tree,
  201. it is OK to not be tail recursive here.
  202. |#
  203. (Node subset
  204. (Split-index best-split)
  205. (Split-value best-split)
  206. (recursive-split (car (Split-subsets best-split))
  207. (add1 current-depth)
  208. (Split-cost best-split))
  209. (recursive-split (cadr (Split-subsets best-split))
  210. (add1 current-depth)
  211. (Split-cost best-split)))]))]))
  212. (recursive-split data 1 1.0))
  213. (define (predict tree data-point label-column-index)
  214. #;(displayln tree)
  215. (cond [(leaf-node? tree)
  216. (node-majority-prediction tree label-column-index)]
  217. [else
  218. (cond [(< (data-point-get-col data-point (Node-split-feature-index tree))
  219. (Node-split-value tree))
  220. (predict (Node-left tree) data-point label-column-index)]
  221. [else (predict (Node-right tree) data-point label-column-index)])]))
  222. (define (data-predict tree data label-column-index)
  223. (data-map (lambda (data-point) (predict tree data-point label-column-index))
  224. data))
  225. (define (cross-validation-split data-set n-folds #:random-state [random-state false])
  226. (if random-state
  227. (random-seed random-state)
  228. (void))
  229. (let* ([shuffled-data-set (shuffle data-set)]
  230. [number-of-data-points (data-length shuffled-data-set)]
  231. [fold-size (exact-floor (/ number-of-data-points n-folds))])
  232. (split-into-chunks-of-size-n shuffled-data-set
  233. (exact-ceiling (/ number-of-data-points n-folds)))))
  234. (define (accuracy-metric actual-labels predicted-labels)
  235. (let ([correct-count (for/sum ([actual-label (in-list actual-labels)]
  236. [predicted-label (in-list predicted-labels)])
  237. (if (= actual-label predicted-label) 1 0))]
  238. [total-count (length actual-labels)])
  239. (/ correct-count total-count)))
  240. (define (leave-one-out-k-folds folds left-out-fold)
  241. (define leave-one-out-filter-procedure
  242. (lambda (fold)
  243. (not (equal? fold left-out-fold))))
  244. (filter leave-one-out-filter-procedure
  245. folds))
  246. (define (get-predictions tree data-set label-column-index)
  247. (for/list ([data-point data-set])
  248. (predict tree data-point label-column-index)))
;; evaluates the algorithm using cross validation split with n folds
;; Returns one accuracy value (exact rational in [0,1]) per fold.
;; The fit hyperparameters are forwarded unchanged to each per-fold fit.
(define (evaluate-algorithm #:data-set data-set
                            #:n-folds n-folds
                            #:feature-column-indices feature-column-indices
                            #:label-column-index label-column-index
                            #:max-depth [max-depth 6]
                            #:min-data-points [min-data-points 12]
                            #:min-data-points-ratio [min-data-points-ratio 0.02]
                            #:min-impurity-split [min-impurity-split (expt 10 -7)]
                            #:stop-at-no-impurity-improvement [stop-at-no-impurity-improvement true]
                            #:random-state [random-state false])
  (let ([folds (cross-validation-split data-set
                                       n-folds
                                       #:random-state random-state)])
    (for/list ([fold folds])
      ;; train on every fold except the current one, test on the current one
      (let* ([train-set (foldr append empty (leave-one-out-k-folds folds fold))]
             ;; test points with the label column removed
             [test-set (map (lambda (data-point)
                              (data-point-take-features data-point
                                                        label-column-index))
                            fold)]
             [actual-labels (data-get-col fold label-column-index)]
             [tree (fit #:train-data train-set
                        #:feature-column-indices feature-column-indices
                        #:label-column-index label-column-index
                        #:max-depth max-depth
                        #:min-data-points min-data-points
                        #:min-data-points-ratio min-data-points-ratio
                        #:min-impurity-split min-impurity-split
                        #:stop-at-no-impurity-improvement stop-at-no-impurity-improvement)]
             [predicted-labels (get-predictions tree test-set label-column-index)])
        #;(print-tree tree label-column-index)
        (accuracy-metric actual-labels predicted-labels)))))
  281. ;; displays a string representation of a learned decision tree
  282. (define (print-tree tree label-column-index)
  283. (define (tree->string tree depth)
  284. (cond [(leaf-node? tree)
  285. (string-append (n-times-string " " depth)
  286. "["
  287. (number->string
  288. (node-majority-prediction tree label-column-index))
  289. "]\n")]
  290. [else
  291. (string-append
  292. (string-append (n-times-string " " depth)
  293. "[feature:"
  294. (number->string (Node-split-feature-index tree))
  295. " < "
  296. (number->string (Node-split-value tree))
  297. "]\n")
  298. (tree->string (Node-left tree) (add1 depth))
  299. (tree->string (Node-right tree) (add1 depth)))]))
  300. (displayln (tree->string tree 0)))
  301. ;; =========================================================
  302. ;; PRUNING
  303. ;; =========================================================
  304. (define (count-leaves tree)
  305. (cond [(leaf-node? tree) 1]
  306. [else (+ (count-leaves (Node-left tree))
  307. (count-leaves (Node-right tree)))]))
;; Collects every "last split node" of tree (an internal node whose two
;; children are both leaves), left to right. These are the only nodes
;; pruning may collapse in one pass.
(define (get-last-split-nodes tree)
  (define (traverse-collect-last-split-nodes subtree)
    ;; NOTE: clause order matters. last-split-node? is tested before the
    ;; single-leaf-child clauses, so those clauses only fire for nodes
    ;; with exactly one leaf child -- such a node cannot itself be a last
    ;; split node and only its non-leaf side can contain one.
    (cond
      [(leaf-node? subtree) empty]
      [(last-split-node? subtree) (list subtree)]
      [(leaf-node? (Node-left subtree))
       (traverse-collect-last-split-nodes (Node-right subtree))]
      [(leaf-node? (Node-right subtree))
       (traverse-collect-last-split-nodes (Node-left subtree))]
      [else
       (append (traverse-collect-last-split-nodes (Node-left subtree))
               (traverse-collect-last-split-nodes (Node-right subtree)))]))
  ;; flatten is defensive here; the recursion already appends flat lists
  (flatten (traverse-collect-last-split-nodes tree)))
  321. #|This procedure returns the better tree according to the accuracy metric on the
  322. pruning set.|#
  323. (define (select-better-tree tree
  324. pruned-tree
  325. pruning-set
  326. feature-column-indices
  327. label-column-index
  328. accuracy-tolerance)
  329. (let ([actual-labels (data-get-col pruning-set
  330. label-column-index)]
  331. [tree-predicted-labels (data-predict tree
  332. pruning-set
  333. label-column-index)]
  334. [pruned-tree-predicted-labels (data-predict pruned-tree
  335. pruning-set
  336. label-column-index)])
  337. (let ([tree-accuracy (accuracy-metric actual-labels
  338. tree-predicted-labels)]
  339. [pruned-tree-accuracy (accuracy-metric actual-labels
  340. pruned-tree-predicted-labels)])
  341. #;(displayln (string-append "accuracy tree: " (number->string tree-accuracy)))
  342. #;(displayln (string-append "accuracy pruned-tree: " (number->string pruned-tree-accuracy)))
  343. (cond [(< (abs (- tree-accuracy pruned-tree-accuracy)) accuracy-tolerance)
  344. pruned-tree]
  345. [else tree]))))
  346. (define (prune-node-from-tree tree split-node)
  347. (cond [(leaf-node? tree) tree]
  348. [(equal? tree split-node)
  349. (make-leaf-node-from-split-node tree)]
  350. [else (struct-copy Node tree
  351. [left
  352. (prune-node-from-tree (Node-left tree)
  353. split-node)]
  354. [right
  355. (prune-node-from-tree (Node-right tree)
  356. split-node)])]))
;; Reduced-error style pruning: repeatedly tries to collapse each last
;; split node and keeps the pruned tree whenever select-better-tree
;; accepts it on the pruning set; stops once a full pass removes no leaves.
;; NOTE(review): feature-column-indices is only passed through to
;; select-better-tree.
(define (prune-with-pruning-set tree
                                pruning-set
                                feature-column-indices
                                label-column-index
                                #:tolerance [tolerance 0.0])
  ;; One pass: offer every current last split node for pruning, carrying
  ;; forward whichever tree select-better-tree preferred each time.
  (define (iter-split-nodes tree remaining-split-nodes)
    (cond [(empty? remaining-split-nodes) tree]
          [else
           #;(displayln "REMAINING-SPLIT-NODES:")
           #;(displayln remaining-split-nodes)
           (iter-split-nodes
            (select-better-tree tree
                                (prune-node-from-tree tree (first remaining-split-nodes))
                                pruning-set
                                feature-column-indices
                                label-column-index
                                tolerance)
            (rest remaining-split-nodes))]))
  ;; Keep running passes until the leaf count stops shrinking.
  (define (iter-trees tree tree-leaves#)
    (let* ([pruned-tree (iter-split-nodes tree (get-last-split-nodes tree))]
           [pruned-tree-leaves# (count-leaves pruned-tree)])
      ;;(displayln "tree: ") (displayln tree)
      ;;(displayln "pruned tree: ") (displayln pruned-tree)
      (cond
        ;; the previous pass removed NO leaves: none of the current last
        ;; split nodes may be pruned, so the pruning process is finished
        [(= pruned-tree-leaves# tree-leaves#)
         (displayln "STOPPING CONDITION (PRUNING): pruning further would decrease accuracy beyong tolerance")
         tree]
        ;; the previous pass removed leaves, so at least one new last
        ;; split node may have appeared and we need to try pruning again
        [else
         (displayln "CONTINUING PRUNING: tree lost nodes in previous iteration of pruning")
         (iter-trees pruned-tree pruned-tree-leaves#)])))
  (iter-trees tree (count-leaves tree)))
  394. #|
  395. - remove all splits with less improvement than x in cost?
  396. - but this can be done already with early stopping parameters!
  397. |#
  398. ;; =========================================================
  399. ;; RUNNING
  400. ;; =========================================================
  401. #|
  402. (define shuffled-data-set (shuffle data-set))
  403. (define small-data-set
  404. (data-range shuffled-data-set
  405. 0
  406. (exact-floor (/ (data-length shuffled-data-set)
  407. 5))))
  408. (collect-garbage)
  409. (collect-garbage)
  410. (collect-garbage)
  411. (time
  412. (for/list ([i (in-range 1)])
  413. (mean
  414. (evaluate-algorithm #:data-set (shuffle data-set)
  415. #:n-folds 10
  416. #:feature-column-indices (list 0 1 2 3)
  417. #:label-column-index 4
  418. #:max-depth 5
  419. #:min-data-points 24
  420. #:min-data-points-ratio 0.02
  421. #:min-impurity-split (expt 10 -7)
  422. #:stop-at-no-impurity-improvement true
  423. #:random-state 0))))
  424. (collect-garbage)
  425. (collect-garbage)
  426. (collect-garbage)
  427. #;(time
  428. (for/list ([i (in-range 1)])
  429. (define tree (fit #:train-data (shuffle data-set)
  430. #:feature-column-indices (list 0 1 2 3)
  431. #:label-column-index 4
  432. #:max-depth 5
  433. #:min-data-points 12
  434. #:min-data-points-ratio 0.02
  435. #:min-impurity-split (expt 10 -7)
  436. #:stop-at-no-impurity-improvement true))
  437. 'done))
  438. |#
  439. #|
  440. IMPROVEMENTS:
  441. - remove data from not leaf nodes by using struct setters
  442. - find the remaining randomness (if there is any) which is not determined by random-state keyword arguments yet (why am I not getting the same result every time?) - maybe shuffle needs to be parametrized with a random seed instead of merely setting the seed before calling shuffle?
  443. - return not only the predicted label, but also how sure we are about the prediction (percentage of data points in the leaf node, which has the predicted label)
  444. |#