|
- (use-modules
- ;; SRFI 64 for unit testing facilities
- (srfi srfi-64)
- ;; SRFI 8 for `receive` form
- (srfi srfi-8)
- ;; utils - the code to be tested
- (decision-tree)
- ;; Utilities for testing
- (utils test)
- ;; Dependencies for testing the code to be tested
- (dataset)
- (data-point)
- (tree)
- (pruning))
;; NOTE(review): TEST-DATA and PRECISION are not referenced by any test in
;; this file -- either remove them or add tests that use them.

;; Ten sample rows, each holding two numeric features followed by a class
;; label (0 or 1).
(define TEST-DATA
  (list #(2.771244718 1.784783929 0)
        #(1.728571309 1.169761413 0)
        #(3.678319846 2.81281357 0)
        #(3.961043357 2.61995032 0)
        #(2.999208922 2.209014212 0)
        #(7.497545867 3.162953546 1)
        #(9.00220326 3.339047188 1)
        #(7.444542326 0.476683375 1)
        #(10.12493903 3.234550982 1)
        #(6.642287351 3.319983761 1)))

;; Exact rational 1/10^9; presumably meant as a numeric comparison tolerance.
(define PRECISION (/ 1 (expt 10 9)))

;; Open the SRFI 64 suite; it is closed by (test-end "pruning-test") at the
;; end of this file.
(test-begin "pruning-test")
;; count-leaves should report how many leaf nodes a tree contains.
(test-group "count-leaves"
  ;; A single split whose two children are both leaves: 2 leaves.
  (test-equal 2
    (let* ([label-1-leaf (make-leaf-node (list #(2.3 1.0 1)
                                               #(2.0 1.0 1)
                                               #(2.3 1.0 1)
                                               #(2.0 1.0 1)
                                               #(2.4 1.0 1)))]
           [label-0-leaf (make-leaf-node (list #(2.3 1.1 0)
                                               #(2.0 1.1 0)))]
           [root (make-node (list #(2.3 1.1 0)
                                  #(2.0 1.1 0)
                                  #(2.3 1.0 1)
                                  #(2.0 1.0 1)
                                  #(2.3 1.0 1)
                                  #(2.0 1.0 1)
                                  #(2.4 1.0 1))
                            1
                            1.1
                            label-1-leaf
                            label-0-leaf)])
      (count-leaves root)))
  ;; One of the root's children is itself a split node with two leaf
  ;; children, so the whole tree has 2 + 1 = 3 leaves.
  (test-equal 3
    (let* ([inner-split (make-node (list #(2.3 1.0 4.0 0)
                                         #(2.0 1.0 3.0 1)
                                         #(2.3 1.0 3.0 1)
                                         #(2.0 1.0 3.0 1)
                                         #(2.4 1.0 3.0 1))
                                   2
                                   4.0
                                   (make-leaf-node (list #(2.0 1.0 3.0 1)
                                                         #(2.3 1.0 3.0 1)
                                                         #(2.0 1.0 3.0 1)
                                                         #(2.4 1.0 3.0 1)))
                                   (make-leaf-node (list #(2.3 1.0 4.0 0))))]
           [root (make-node (list #(2.3 1.1 3.0 0)
                                  #(2.0 1.1 3.0 0)
                                  #(2.3 1.0 4.0 0)
                                  #(2.0 1.0 3.0 1)
                                  #(2.3 1.0 3.0 1)
                                  #(2.0 1.0 3.0 1)
                                  #(2.4 1.0 3.0 1))
                            1
                            1.1
                            inner-split
                            (make-leaf-node (list #(2.3 1.1 3.0 0)
                                                  #(2.0 1.1 3.0 0))))])
      (count-leaves root))))
;; get-last-split-nodes: for this tree only the innermost split (both of
;; whose children are leaves) is expected back; the root, which has a split
;; child, is not.
(test-group "get-last-split-nodes"
  ;; Builds a fresh copy on every call, so the expected value and the node
  ;; embedded in the input tree are distinct (but equal?) objects.
  (let ([make-innermost-split
         (lambda ()
           (make-node (list #(2.3 1.0 4.0 0)
                            #(2.0 1.0 3.0 1)
                            #(2.3 1.0 3.0 1)
                            #(2.0 1.0 3.0 1)
                            #(2.4 1.0 3.0 1))
                      2
                      4.0
                      (make-leaf-node (list #(2.0 1.0 3.0 1)
                                            #(2.3 1.0 3.0 1)
                                            #(2.0 1.0 3.0 1)
                                            #(2.4 1.0 3.0 1)))
                      (make-leaf-node (list #(2.3 1.0 4.0 0)))))])
    (test-equal (list (make-innermost-split))
      (get-last-split-nodes
       (make-node (list #(2.3 1.1 3.0 0)
                        #(2.0 1.1 3.0 0)
                        #(2.3 1.0 4.0 0)
                        #(2.0 1.0 3.0 1)
                        #(2.3 1.0 3.0 1)
                        #(2.0 1.0 3.0 1)
                        #(2.4 1.0 3.0 1))
                  1
                  1.1
                  (make-innermost-split)
                  (make-leaf-node (list #(2.3 1.1 3.0 0)
                                        #(2.0 1.1 3.0 0))))))))
;; prune-node-from-tree: replacing a split node inside a tree should yield
;; the same tree with that split swapped for the leaf produced by
;; make-leaf-node-from-split-node.
(test-group "prune-node-from-tree"
  (let* (;; Each make-* helper returns freshly built values, so the tree
         ;; and the split node passed separately never share structure.
         ;; Only equal?-based matching can locate the node.
         [make-left-rows (lambda ()
                           (list #(1.0 1.0 0)
                                 #(1.2 1.0 0)
                                 #(1.1 1.0 0)
                                 #(1.4 1.0 0)
                                 #(1.2 1.0 0)
                                 #(1.2 1.0 0)))]
         [make-split-rows (lambda ()
                            (list #(2.3 1.1 0)
                                  #(2.0 1.1 0)
                                  #(2.3 1.0 1)
                                  #(2.0 1.0 1)
                                  #(2.3 1.0 1)
                                  #(2.0 1.0 1)
                                  #(2.4 1.0 1)))]
         [make-split-node
          (lambda ()
            (make-node (make-split-rows)
                       1
                       1.1
                       (make-leaf-node (list #(2.3 1.0 1)
                                             #(2.0 1.0 1)
                                             #(2.3 1.0 1)
                                             #(2.0 1.0 1)
                                             #(2.4 1.0 1)))
                       (make-leaf-node (list #(2.3 1.1 0)
                                             #(2.0 1.1 0)))))]
         ;; Root split whose second child is supplied by the caller.
         [make-tree-with
          (lambda (second-child)
            (make-node (append (make-left-rows) (make-split-rows))
                       0
                       2.0
                       (make-leaf-node (make-left-rows))
                       second-child))]
         [tree (make-tree-with (make-split-node))]
         [split-node (make-split-node)]
         [transformed-node (make-leaf-node-from-split-node split-node)])
    (test-equal (make-tree-with transformed-node)
      (prune-node-from-tree tree split-node))))
;; select-better-tree: as exercised below, the pruned candidate should be
;; chosen when the accuracy it loses on the pruning set is within the
;; tolerance, and the original tree should be kept otherwise.
(test-group "select-better-tree"
  (let* (;; Fresh row lists on every call, so the trees never share structure.
         [make-rows-below-3 (lambda ()
                              (list #(1.0 0.10 0)
                                    #(1.1 0.11 0)
                                    #(1.2 0.12 0)
                                    #(1.3 0.13 0)
                                    #(1.4 0.14 0)
                                    #(1.5 0.15 0)
                                    #(1.6 0.16 0)
                                    #(1.7 0.17 0)
                                    #(1.8 0.18 0)
                                    #(1.9 0.19 0)
                                    #(2.0 0.20 0)
                                    #(2.1 0.21 0)
                                    #(2.2 0.22 0)
                                    #(2.3 0.23 0)
                                    #(2.4 0.24 0)
                                    #(2.5 0.25 0)
                                    #(2.6 0.26 0)
                                    #(2.7 0.27 0)))]
         [make-rows-at-3 (lambda ()
                           (list #(3.0 0.10 0)
                                 #(3.0 0.20 1)))]
         [make-all-rows (lambda ()
                          (append (make-rows-below-3) (make-rows-at-3)))]
         [tree (make-node (make-all-rows)
                          0
                          3.0
                          (make-leaf-node (make-rows-below-3))
                          ;; the split node which will be pruned away
                          (make-node (make-rows-at-3)
                                     1
                                     0.2
                                     (make-leaf-node (list #(3.0 0.10 0)))
                                     (make-leaf-node (list #(3.0 0.20 1)))))]
         [pruned-tree (make-node (make-all-rows)
                                 0
                                 3.0
                                 (make-leaf-node (make-rows-below-3))
                                 ;; the pruned node, collapsed into one leaf
                                 (make-leaf-node (make-rows-at-3)))]
         ;; Every pruning row is a training row shifted by +0.01 in the
         ;; first feature and +0.001 in the second, so that no value
         ;; crosses a decision boundary.
         [pruning-set (list #(1.01 0.101 0)
                            #(1.11 0.111 0)
                            #(1.21 0.121 0)
                            #(1.31 0.131 0)
                            #(1.41 0.141 0)
                            #(1.51 0.151 0)
                            #(1.61 0.161 0)
                            #(1.71 0.171 0)
                            #(1.81 0.181 0)
                            #(1.91 0.191 0)
                            #(2.01 0.201 0)
                            #(2.11 0.211 0)
                            #(2.21 0.221 0)
                            #(2.31 0.231 0)
                            #(2.41 0.241 0)
                            #(2.51 0.251 0)
                            #(2.61 0.261 0)
                            #(2.71 0.271 0)
                            #(3.01 0.101 0)
                            #(3.01 0.201 1))]
         [feature-column-indices (list 0 1)]
         [label-column-index 2]
         ;; Pruning costs 1 of the 20 pruning rows (5% accuracy), which
         ;; lies between these two tolerances.
         [lenient-tolerance 0.06]
         [strict-tolerance 0.04])
    ;; The 5% accuracy loss is within the 6% tolerance, so the pruned tree
    ;; should be selected.
    (test-equal pruned-tree
      (select-better-tree tree
                          pruned-tree
                          pruning-set
                          feature-column-indices
                          label-column-index
                          lenient-tolerance))
    ;; The 5% accuracy loss exceeds the 4% tolerance, so the original tree
    ;; should be kept.
    (test-equal tree
      (select-better-tree tree
                          pruned-tree
                          pruning-set
                          feature-column-indices
                          label-column-index
                          strict-tolerance))))
;; prune-with-pruning-set: with a lenient tolerance the noisy right subtree
;; should collapse into one leaf; with a strict tolerance the tree should be
;; returned unchanged.
(test-group "prune-with-pruning-set"
  (let* (;; Every make-* helper builds fresh values, so the two tests below
         ;; never share (and cannot contaminate) each other's fixtures.
         [make-label-0-rows (lambda ()
                              (list #(1.0 0.10 0)
                                    #(1.1 0.11 0)
                                    #(1.2 0.12 0)
                                    #(1.3 0.13 0)
                                    #(1.4 0.14 0)
                                    #(1.5 0.15 0)
                                    #(1.6 0.16 0)
                                    #(1.7 0.17 0)
                                    #(1.8 0.18 0)
                                    #(1.9 0.19 0)))]
         [make-label-1-rows (lambda ()
                              (list #(2.0 0.10 1)
                                    #(2.1 0.11 1)
                                    #(2.2 0.12 1)
                                    #(2.3 0.13 1)
                                    #(2.4 0.14 1)
                                    #(2.5 0.15 1)
                                    #(2.6 0.16 1)
                                    #(2.7 0.17 1)))]
         ;; Two rows whose labels differ; the unpruned tree isolates them
         ;; in their own split.
         [make-noisy-rows (lambda ()
                            (list #(3.0 0.00 1)
                                  #(3.0 0.01 0)))]
         [make-right-rows (lambda ()
                            (append (make-label-1-rows) (make-noisy-rows)))]
         [make-all-rows (lambda ()
                          (append (make-label-0-rows) (make-right-rows)))]
         [make-tree
          (lambda ()
            (make-node (make-all-rows)
                       0
                       2.0
                       (make-leaf-node (make-label-0-rows))
                       (make-node (make-right-rows)
                                  1
                                  0.1
                                  (make-node (make-noisy-rows)
                                             1
                                             0.01
                                             (make-leaf-node (list #(3.0 0.00 1)))
                                             (make-leaf-node (list #(3.0 0.01 0))))
                                  (make-leaf-node (make-label-1-rows)))))]
         ;; Expected result of pruning: the whole right subtree becomes a leaf.
         [make-pruned-tree
          (lambda ()
            (make-node (make-all-rows)
                       0
                       2.0
                       (make-leaf-node (make-label-0-rows))
                       (make-leaf-node (make-right-rows))))]
         ;; Training rows shifted by +0.01 / +0.001 so no value crosses a
         ;; decision boundary.
         [make-pruning-set (lambda ()
                             (list #(1.01 0.101 0)
                                   #(1.11 0.111 0)
                                   #(1.21 0.121 0)
                                   #(1.31 0.131 0)
                                   #(1.41 0.141 0)
                                   #(1.51 0.151 0)
                                   #(1.61 0.161 0)
                                   #(1.71 0.171 0)
                                   #(1.81 0.181 0)
                                   #(1.91 0.191 0)
                                   #(2.01 0.101 1)
                                   #(2.11 0.111 1)
                                   #(2.21 0.121 1)
                                   #(2.31 0.131 1)
                                   #(2.41 0.141 1)
                                   #(2.51 0.151 1)
                                   #(2.61 0.161 1)
                                   #(2.71 0.171 1)
                                   #(3.01 0.001 1)
                                   #(3.01 0.011 0)))]
         [feature-column-indices (list 0 1)]
         [label-column-index 2])
    ;; NOTE(review): collapsing the right subtree presumably misclassifies
    ;; #(3.01 0.011 0), i.e. 1 of 20 pruning rows (5%), which lies between
    ;; the two tolerances used below.
    (test-equal (make-pruned-tree)
      (prune-with-pruning-set (make-tree)
                              (make-pruning-set)
                              feature-column-indices
                              label-column-index
                              #:tolerance 0.06))
    (let ([tree (make-tree)])
      (test-equal tree
        (prune-with-pruning-set tree
                                (make-pruning-set)
                                feature-column-indices
                                label-column-index
                                #:tolerance 0.04)))))

;; TODO: missing test for traverse-collect-last-split-nodes.
(test-end "pruning-test")
|