pruning.scm 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. (define-module (pruning))
  2. (use-modules
  3. ((srfi srfi-1) #:prefix srfi1:)
  4. (tree)
  5. (metrics)
  6. (dataset)
  7. (prediction)
  8. (utils list-utils)
  9. (utils display-utils))
  10. (define-public count-leaves
  11. (lambda (tree)
  12. (cond [(leaf-node? tree) 1]
  13. [else (+ (count-leaves (node-left tree))
  14. (count-leaves (node-right tree)))])))
  15. (define-public traverse-collect-last-split-nodes
  16. (lambda (subtree)
  17. (cond
  18. [(leaf-node? subtree) empty-dataset]
  19. [(last-split-node? subtree) (list subtree)]
  20. [(leaf-node? (node-left subtree))
  21. (traverse-collect-last-split-nodes (node-right subtree))]
  22. [(leaf-node? (node-right subtree))
  23. (traverse-collect-last-split-nodes (node-left subtree))]
  24. [else
  25. (append (traverse-collect-last-split-nodes (node-left subtree))
  26. (traverse-collect-last-split-nodes (node-right subtree)))])))
  27. (define-public get-last-split-nodes
  28. (lambda (tree)
  29. (flatten (traverse-collect-last-split-nodes tree))))
  30. (define-public select-better-tree
  31. (lambda (tree
  32. pruned-tree
  33. pruning-set
  34. feature-column-indices
  35. label-column-index
  36. accuracy-tolerance)
  37. "Prune the tree so that the accuracy of the tree is best for the given
  38. pruning set."
  39. (let ([actual-labels
  40. (dataset-get-col pruning-set label-column-index)]
  41. [tree-predicted-labels
  42. (predict-dataset tree pruning-set label-column-index)]
  43. [pruned-tree-predicted-labels
  44. (predict-dataset pruned-tree pruning-set label-column-index)])
  45. (let ([tree-accuracy
  46. (accuracy-metric actual-labels tree-predicted-labels)]
  47. [pruned-tree-accuracy
  48. (accuracy-metric actual-labels pruned-tree-predicted-labels)])
  49. #;(displayln (string-append "accuracy tree: " (number->string tree-accuracy)))
  50. #;(displayln (string-append "accuracy pruned-tree: " (number->string pruned-tree-accuracy)))
  51. (cond
  52. [(< (abs (- tree-accuracy pruned-tree-accuracy)) accuracy-tolerance)
  53. pruned-tree]
  54. [else tree])))))
  55. (define-public prune-node-from-tree
  56. (lambda (tree split-node)
  57. (cond [(leaf-node? tree) tree]
  58. [(equal? tree split-node)
  59. (make-leaf-node-from-split-node tree)]
  60. [else
  61. (make-node
  62. ;; copy all data
  63. (node-data tree)
  64. (node-split-feature-index tree)
  65. (node-split-value tree)
  66. ;; prune subtrees
  67. ;; FUTURE TODO: This is up for multicore optimization. Each subtree
  68. ;; pruning can run as a separate job.
  69. (prune-node-from-tree (node-left tree) split-node)
  70. (prune-node-from-tree (node-right tree) split-node))])))
  71. (define-public prune-with-pruning-set
  72. (lambda* (tree
  73. pruning-set
  74. feature-column-indices
  75. label-column-index
  76. #:key
  77. (tolerance 0.0))
  78. (define iter-split-nodes
  79. (lambda (tree remaining-split-nodes)
  80. (cond [(null? remaining-split-nodes) tree]
  81. [else
  82. #;(displayln "REMAINING-SPLIT-NODES:")
  83. #;(displayln remaining-split-nodes)
  84. (iter-split-nodes
  85. (select-better-tree tree
  86. (prune-node-from-tree tree
  87. (srfi1:first remaining-split-nodes))
  88. pruning-set
  89. feature-column-indices
  90. label-column-index
  91. tolerance)
  92. (cdr remaining-split-nodes))])))
  93. (define iter-trees
  94. (lambda (tree tree-leaves#)
  95. (let* ([pruned-tree (iter-split-nodes tree (get-last-split-nodes tree))]
  96. [pruned-tree-leaves# (count-leaves pruned-tree)])
  97. ;;(displayln "tree: ") (displayln tree)
  98. ;;(displayln "pruned tree: ") (displayln pruned-tree)
  99. (cond
  100. ;; in the previous call to iter-split-nodes leaves were removed
  101. ;; by pruning the tree. This means that all last split nodes cannot
  102. ;; be removed and thus we finished the pruning process.
  103. [(= pruned-tree-leaves# tree-leaves#)
  104. (displayln "STOPPING CONDITION (PRUNING): pruning further would decrease accuracy beyong tolerance")
  105. tree]
  106. ;; in the last call to iter-split-nodes leaves were removed,
  107. ;; so there is at least one new last split node and we need
  108. ;; to try to prune that
  109. [else
  110. (displayln "CONTINUING PRUNING: tree lost nodes in previous iteration of pruning")
  111. (iter-trees pruned-tree pruned-tree-leaves#)]))))
  112. (iter-trees tree (count-leaves tree))))