;; decision-tree.scm
(define-module (decision-tree))
(use-modules
 ;; SRFI-1 for list procedures
 ((srfi srfi-1) #:prefix srfi1:)
 ;; SRFI-8 for `receive' form
 (srfi srfi-8)
 ;; R6RS exception handling using conditions
 (rnrs exceptions)
 (rnrs conditions)
 ;; project utility modules
 (utils csv)
 (utils display)
 (utils math)
 (utils string)
 (utils list)
 ;; project domain modules
 (dataset)
 (data-point)
 (tree)
 (metrics)
 (prediction)
 (split-quality-measure)
 ;; custom parallelism module
 (parallelism))
;; =======================
;; DECISION TREE ALGORITHM
;; =======================
  26. (define-public split-data
  27. (lambda (data index value)
  28. (receive (part1 part2)
  29. (dataset-partition (lambda (data-point)
  30. (< (data-point-get-col data-point index) value))
  31. data)
  32. (list part1 part2))))
  33. (define-public select-min-cost-split
  34. (lambda (. splits)
  35. (cond
  36. [(null? splits)
  37. (raise
  38. (condition
  39. (make-error)
  40. (make-message-condition
  41. "cannot get minimum cost split given no splits")
  42. (make-irritants-condition splits)
  43. (make-who-condition 'splits)))]
  44. [else
  45. (let iter ([remaining-splits (cdr splits)]
  46. [prev-min-split (car splits)])
  47. (cond
  48. [(null? remaining-splits) prev-min-split]
  49. [else
  50. (let ([next-split (car remaining-splits)])
  51. (if (< (split-cost prev-min-split) (split-cost next-split))
  52. (iter (cdr remaining-splits) prev-min-split)
  53. (iter (cdr remaining-splits) next-split)))]))])))
  54. (define-public get-best-split-for-column
  55. (lambda* (data
  56. label-column-index
  57. column-index
  58. #:key
  59. (split-quality-proc gini-index))
  60. "Calculate the best split value for the column of the data at the given
  61. index. This is achieved by going through all values in the column and
  62. calculating a split for each value and finding the one with the minimum cost."
  63. ;; FUTURE TODO: Allow for a heuristic, which selects a few split values at
  64. ;; random, or in other ways. Then check only the costs of those splits and
  65. ;; find the split of minimum cost in those few splits. The selected split
  66. ;; values do not necessarily have to be values of the split feature of any
  67. ;; data point in the data set. They could for example also be values
  68. ;; dividing the range of values of the split feature perfectly. Such a
  69. ;; heuristic might not result in a perfect tree, but would be much faster
  70. ;; then trying all the values the split feature takes on in the data set.
  71. (let ([initial-placeholder-split
  72. ;; The initial split is a dummy split, which has the worst cost
  73. ;; possible: Positively infinite cost.
  74. (make-split 0 +inf.0 (list '() '()) +inf.0)])
  75. ;; TODO: Parallelism: This is a place, where parallelism could be made use
  76. ;; of. Instead of going through all the split values of the column
  77. ;; sequentially, the split values can be processed in parallel. However,
  78. ;; it might be too much overhead to calculate the split for each split
  79. ;; value in a separate calculation unit. One might want to specify an
  80. ;; additional argument, which defines for how many split values each
  81. ;; calculation unit calculates the result and keep the overhead
  82. ;; configurable..
  83. (let iter-col-vals ([column-data (dataset-get-col data column-index)]
  84. [previous-best-split initial-placeholder-split])
  85. (cond
  86. [(dataset-column-empty? column-data) previous-best-split]
  87. [else
  88. (let* ([current-value (dataset-column-first column-data)]
  89. [current-subsets (split-data data
  90. column-index
  91. current-value)]
  92. [current-cost (split-quality-proc current-subsets label-column-index)])
  93. (iter-col-vals
  94. (dataset-column-rest column-data)
  95. (select-min-cost-split
  96. previous-best-split
  97. ;; FUTURE TODO: Here we are creating a Split record, which might
  98. ;; not be needed and thrown away after this iteration. An
  99. ;; optimization might be to not even create it, if the current
  100. ;; cost is higher than the cost of the previously best
  101. ;; split. However, always handling multiple values bloates the
  102. ;; code a little and the current implementation seems more
  103. ;; readable.
  104. (make-split column-index
  105. current-value
  106. current-subsets
  107. current-cost))))])))))
  108. (define-public get-best-split
  109. (lambda* (data
  110. feature-column-indices
  111. label-column-index
  112. #:key
  113. (split-quality-proc gini-index))
  114. (let ([max-col-index (- (data-point-length (dataset-first data)) 1)]
  115. [start-column-index 0]
  116. [initial-placeholder-split (make-split 0 +inf.0 (list '() '()) +inf.0)])
  117. (apply select-min-cost-split
  118. ;; NOTE: parallelism
  119. (run-in-parallel
  120. (lambda (column-index)
  121. (get-best-split-for-column data
  122. label-column-index
  123. column-index
  124. #:split-quality-proc split-quality-proc))
  125. feature-column-indices)))))
(define-public fit
  (lambda* (#:key
            train-data
            (feature-column-indices '())
            label-column-index
            (max-depth 6)
            (min-data-points 12)
            (min-data-points-ratio 0.02)
            (min-impurity-split (expt 10 -7))
            (stop-at-no-impurity-improvement #t))
    "Train a decision tree on TRAIN-DATA.  FEATURE-COLUMN-INDICES names
the columns considered for splitting; LABEL-COLUMN-INDEX names the
column holding the class label.  The remaining keyword arguments are
early-stopping criteria.  Returns the root node of the fitted tree."
    (define all-data-length (dataset-length train-data))
    ;; NOTE(review): current-depth is never referenced below;
    ;; recursive-split threads the depth through its own parameter.
    (define current-depth 1)
    #|
    STOP CRITERIA:
    - only one class in a subset (cannot be split any further and does not need to be split)
    - maximum tree depth reached
    - minimum number of data points in a subset
    - minimum ratio of data points in this subset
    |#
    ;; #t when every label in SUBSET's label column is numerically equal.
    (define all-same-label?
      (lambda (subset)
        (displayln "checking for stop condition: all-same-label?")
        ;; FUTURE TODO: Do no longer assume, that the label column is always an
        ;; integer or a number.
        (column-uniform? (dataset-get-col subset label-column-index) =)))
    ;; #t when SUBSET is at or below the configured minimum size, or has
    ;; fewer than 2 points (unsplittable regardless of configuration).
    (define insufficient-data-points-for-split?
      (lambda (subset)
        (displayln "checking for stop condition: insufficient-data-points-for-split?")
        (let ([number-of-data-points (dataset-length subset)])
          (or (<= number-of-data-points min-data-points)
              (< number-of-data-points 2)))))
    ;; #t once the tree has grown to MAX-DEPTH levels.
    (define max-depth-reached?
      (lambda (current-depth)
        (displayln "checking for stop condition: max-depth-reached?")
        (>= current-depth max-depth)))
    ;; #t when SUBSET holds at most MIN-DATA-POINTS-RATIO of all training
    ;; data points.
    (define insufficient-data-points-ratio-for-split?
      (lambda (subset)
        (displayln "checking for stop condition: insufficient-data-points-ratio-for-split?")
        (<= (/ (dataset-length subset) all-data-length) min-data-points-ratio)))
    ;; #t when the new split is no purer than the previous one and the
    ;; caller asked to stop in that case.
    (define no-improvement?
      (lambda (previous-split-impurity split-impurity)
        (displayln "checking for stop condition: no-improvement?")
        (and (<= previous-split-impurity split-impurity)
             stop-at-no-impurity-improvement)))
    ;; #t when the impurity is already below MIN-IMPURITY-SPLIT.
    (define insufficient-impurity?
      (lambda (impurity)
        (displayln "checking for stop condition: insufficient-impurity?")
        (< impurity min-impurity-split)))
    #|
    Here we do the recursive splitting.
    |#
    ;; Build the (sub)tree for SUBSET at CURRENT-DEPTH, given the
    ;; impurity of the split that produced SUBSET.  Returns a leaf node
    ;; when a stop criterion fires, otherwise an inner node with two
    ;; recursively built children.
    (define recursive-split
      (lambda (subset current-depth previous-split-impurity)
        (display "recursive split on depth: ") (displayln current-depth)
        ;; Before splitting further, we check for stopping early conditions.
        ;; TODO: Refactor this part. This cond form is way to big. Think of
        ;; something clever. TODO: Parallelism: This might be a place to use
        ;; parallelism at, to check for the stopping criteria in
        ;; parallel. However, I think they might not take that long to calculate
        ;; anyway and the question is, whether the overhead is worth it.
        (displayln "will check for stop conditions now")
        (cond
         [(max-depth-reached? current-depth)
          (displayln "STOPPING CONDITION: maximum depth")
          (displayln (string-append "INFO: still got "
                                    (number->string (dataset-length subset))
                                    " data points"))
          (make-leaf-node subset)]
         [(insufficient-data-points-for-split? subset)
          (displayln "STOPPING CONDITION: insuficient number of data points")
          (displayln (string-append "INFO: still got "
                                    (number->string (dataset-length subset))
                                    " data points"))
          (make-leaf-node subset)]
         [(insufficient-data-points-ratio-for-split? subset)
          (displayln "STOPPING CONDITION: insuficient ratio of data points")
          (displayln (string-append "INFO: still got "
                                    (number->string (dataset-length subset))
                                    " data points"))
          (make-leaf-node subset)]
         [(all-same-label? subset)
          (displayln "STOPPING CONDITION: all same label")
          (displayln (string-append "INFO: still got "
                                    (number->string (dataset-length subset))
                                    " data points"))
          (make-leaf-node subset)]
         [else
          ;; No size/depth/purity criterion fired — search for the best
          ;; split before deciding between leaf and inner node.
          (displayln (string-append "INFO: CONTINUING SPLITT: still got "
                                    (number->string (dataset-length subset))
                                    " data points"))
          ;; (display "input data for searching best split:") (displayln subset)
          (let* ([best-split
                  (get-best-split subset
                                  feature-column-indices
                                  label-column-index
                                  #:split-quality-proc gini-index)])
            (cond
             [(no-improvement? previous-split-impurity (split-cost best-split))
              (displayln (string-append "STOPPING CONDITION: "
                                        "no improvement in impurity: previously: "
                                        (number->string previous-split-impurity) " "
                                        "now: "
                                        (number->string (split-cost best-split))))
              (make-leaf-node subset)]
             [(insufficient-impurity? previous-split-impurity)
              (displayln "STOPPING CONDITION: not enough impurity for splitting further")
              (make-leaf-node subset)]
             [else
              ;; Here are the recursive calls. This is not tail recursive, but
              ;; since the data structure itself is recursive and we only have
              ;; as many procedure calls as there are branches in the tree, it
              ;; is OK to not be tail recursive here.
              (let ([subsets
                     ;; NOTE: parallelism — the two child subtrees are
                     ;; built concurrently.
                     (run-in-parallel (lambda (subset)
                                        (recursive-split subset
                                                         (+ current-depth 1)
                                                         (split-cost best-split)))
                                      (list (car (split-subsets best-split))
                                            (cadr (split-subsets best-split))))])
                (make-node subset
                           (split-feature-index best-split)
                           (split-value best-split)
                           (car subsets)
                           (cadr subsets)))]))])))
    ;; Kick off the recursion at depth 1 with the worst possible
    ;; previous impurity (1.0), so the first split is always attempted.
    (recursive-split train-data 1 1.0)))
  252. (define-public cross-validation-split
  253. (lambda* (dataset n-folds #:key (random-seed #f))
  254. (let* ([shuffled-dataset (shuffle-dataset dataset #:seed random-seed)]
  255. [number-of-data-points (dataset-length shuffled-dataset)]
  256. [fold-size
  257. (exact-floor (/ number-of-data-points n-folds))])
  258. (split-into-chunks-of-size-n shuffled-dataset
  259. (exact-ceiling
  260. (/ number-of-data-points n-folds))))))
  261. (define-public leave-one-out-k-folds
  262. (lambda (folds left-out-fold)
  263. (define leave-one-out-filter-procedure
  264. (lambda (fold)
  265. (not (equal? fold left-out-fold))))
  266. (filter leave-one-out-filter-procedure
  267. folds)))
;; Evaluates the algorithm using a cross-validation split with n folds.
(define-public evaluate-algorithm
  (lambda* (#:key
            dataset
            n-folds
            feature-column-indices
            label-column-index
            (max-depth 6)
            (min-data-points 12)
            (min-data-points-ratio 0.02)
            (min-impurity-split (expt 10 -7))
            (stop-at-no-impurity-improvement #t)
            (random-seed #f))
    "Calculate a list of accuracy values, one value for each fold of a
cross-validation split."
    (let ([folds
           (cross-validation-split dataset
                                   n-folds
                                   #:random-seed random-seed)])
      ;; NOTE: parallelism — each fold is evaluated concurrently.
      (run-in-parallel
       (lambda (fold)
         (let* ([train-set
                 ;; Concatenate all folds except the held-out one.
                 (fold-right append
                             empty-dataset
                             (leave-one-out-k-folds folds fold))]
                [test-set
                 ;; Strip the label column from the held-out fold, so the
                 ;; tree never sees the answer at prediction time.
                 (map (lambda (data-point)
                        (data-point-take-features data-point
                                                  label-column-index))
                      fold)]
                [actual-labels (dataset-get-col fold label-column-index)]
                [tree
                 ;; Fit a tree on the training folds with the given
                 ;; hyperparameters.
                 (fit #:train-data train-set
                      #:feature-column-indices feature-column-indices
                      #:label-column-index label-column-index
                      #:max-depth max-depth
                      #:min-data-points min-data-points
                      #:min-data-points-ratio min-data-points-ratio
                      #:min-impurity-split min-impurity-split
                      #:stop-at-no-impurity-improvement stop-at-no-impurity-improvement)]
                [predicted-labels
                 (predict-dataset tree test-set label-column-index)])
           ;; Score the held-out fold against the tree's predictions.
           (accuracy-metric actual-labels predicted-labels)))
       folds))))