123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567 |
- (use-modules
- ;; SRFI 64 for unit testing facilities
- (srfi srfi-64)
- ;; SRFI 8 for `receive` form
- (srfi srfi-8)
- ;; decision-tree - the code to be tested
- (decision-tree)
- ;; Utilities for testing
- (utils test)
- ;; Dependencies for testing the code to be tested
- (dataset)
- (data-point)
- (tree)
- (pruning))
;; Ten three-element sample rows.  The last column is 0 for the first five
;; rows and 1 for the last five -- presumably two feature columns followed
;; by a class label, like the rows used in the test groups of this file.
;; NOTE(review): nothing in the visible part of this file refers to
;; TEST-DATA.
(define TEST-DATA
  (list (vector 2.771244718 1.784783929 0)
        (vector 1.728571309 1.169761413 0)
        (vector 3.678319846 2.81281357 0)
        (vector 3.961043357 2.61995032 0)
        (vector 2.999208922 2.209014212 0)
        (vector 7.497545867 3.162953546 1)
        (vector 9.00220326 3.339047188 1)
        (vector 7.444542326 0.476683375 1)
        (vector 10.12493903 3.234550982 1)
        (vector 6.642287351 3.319983761 1)))
;; 10^-9 as an exact rational (1/1000000000).
;; NOTE(review): nothing in the visible part of this file refers to
;; PRECISION.
(define PRECISION
  (/ 1 (expt 10 9)))

;; Open the SRFI 64 test suite; it is closed by the matching
;; (test-end "pruning-test") at the end of the file.
(test-begin "pruning-test")
;; count-leaves: checks the reported leaf count for two hand-built trees.
;; Row data is produced by local thunks so that every node gets its own
;; freshly built list, and so that all construction happens inside the
;; tested expression of `test-equal`.
(test-group
 "count-leaves"

 ;; Tree 1: a single split on column 1 at 1.1 with two leaf children.
 (let* ((low-rows                       ; rows whose column 1 is 1.0
         (lambda ()
           (list #(2.3 1.0 1)
                 #(2.0 1.0 1)
                 #(2.3 1.0 1)
                 #(2.0 1.0 1)
                 #(2.4 1.0 1))))
        (high-rows                      ; rows whose column 1 is 1.1
         (lambda ()
           (list #(2.3 1.1 0)
                 #(2.0 1.1 0))))
        (build-stump
         (lambda ()
           (make-node (append (high-rows) (low-rows))
                      1
                      1.1
                      (make-leaf-node (low-rows))
                      (make-leaf-node (high-rows))))))
   (test-equal 2
     (count-leaves (build-stump))))

 ;; Tree 2: the root splits on column 1 at 1.1; its first child splits
 ;; again on column 2 at 4.0, which gives three leaves in total.
 (let* ((inner-low-rows                 ; column 1 is 1.0, column 2 is 3.0
         (lambda ()
           (list #(2.0 1.0 3.0 1)
                 #(2.3 1.0 3.0 1)
                 #(2.0 1.0 3.0 1)
                 #(2.4 1.0 3.0 1))))
        (inner-high-rows                ; column 1 is 1.0, column 2 is 4.0
         (lambda ()
           (list #(2.3 1.0 4.0 0))))
        (outer-high-rows                ; column 1 is 1.1
         (lambda ()
           (list #(2.3 1.1 3.0 0)
                 #(2.0 1.1 3.0 0))))
        (inner-rows
         (lambda ()
           (append (inner-high-rows) (inner-low-rows))))
        (build-inner-node
         (lambda ()
           (make-node (inner-rows)
                      2
                      4.0
                      (make-leaf-node (inner-low-rows))
                      (make-leaf-node (inner-high-rows)))))
        (build-tree
         (lambda ()
           (make-node (append (outer-high-rows) (inner-rows))
                      1
                      1.1
                      (build-inner-node)
                      (make-leaf-node (outer-high-rows))))))
   (test-equal 3
     (count-leaves (build-tree)))))
;; get-last-split-nodes: for a tree whose root has one split node child
;; (itself with two leaf children) and one leaf child, the expected result
;; is a one-element list holding a node equal to that inner split node.
;; The inner node is built twice (once for the expected value, once inside
;; the tree), so the comparison is structural, not by identity.
(test-group
 "get-last-split-nodes"
 (let* ((inner-low-rows                 ; column 1 is 1.0, column 2 is 3.0
         (lambda ()
           (list #(2.0 1.0 3.0 1)
                 #(2.3 1.0 3.0 1)
                 #(2.0 1.0 3.0 1)
                 #(2.4 1.0 3.0 1))))
        (inner-high-rows                ; column 1 is 1.0, column 2 is 4.0
         (lambda ()
           (list #(2.3 1.0 4.0 0))))
        (outer-high-rows                ; column 1 is 1.1
         (lambda ()
           (list #(2.3 1.1 3.0 0)
                 #(2.0 1.1 3.0 0))))
        (inner-rows
         (lambda ()
           (append (inner-high-rows) (inner-low-rows))))
        ;; Split on column 2 at 4.0; both children are leaves.
        (build-inner-node
         (lambda ()
           (make-node (inner-rows)
                      2
                      4.0
                      (make-leaf-node (inner-low-rows))
                      (make-leaf-node (inner-high-rows)))))
        ;; Root split on column 1 at 1.1.
        (build-tree
         (lambda ()
           (make-node (append (outer-high-rows) (inner-rows))
                      1
                      1.1
                      (build-inner-node)
                      (make-leaf-node (outer-high-rows))))))
   (test-equal
       (list (build-inner-node))
     (get-last-split-nodes (build-tree)))))
;; prune-node-from-tree: pruning a split node out of a tree is expected to
;; yield the same tree with that split node replaced by the result of
;; `make-leaf-node-from-split-node` applied to it.
;; `split-node` is built separately from the equal node embedded in `tree`.
(test-group
 "prune-node-from-tree"
 (let* ((root-low-rows                  ; column 0 below 2.0
         (lambda ()
           (list #(1.0 1.0 0)
                 #(1.2 1.0 0)
                 #(1.1 1.0 0)
                 #(1.4 1.0 0)
                 #(1.2 1.0 0)
                 #(1.2 1.0 0))))
        (sub-high-rows                  ; column 1 is 1.1
         (lambda ()
           (list #(2.3 1.1 0)
                 #(2.0 1.1 0))))
        (sub-low-rows                   ; column 1 is 1.0
         (lambda ()
           (list #(2.3 1.0 1)
                 #(2.0 1.0 1)
                 #(2.3 1.0 1)
                 #(2.0 1.0 1)
                 #(2.4 1.0 1))))
        (sub-rows
         (lambda ()
           (append (sub-high-rows) (sub-low-rows))))
        (all-rows
         (lambda ()
           (append (root-low-rows) (sub-rows))))
        ;; Split on column 1 at 1.1 with two leaf children.
        (build-split-node
         (lambda ()
           (make-node (sub-rows)
                      1
                      1.1
                      (make-leaf-node (sub-low-rows))
                      (make-leaf-node (sub-high-rows)))))
        ;; Root split on column 0 at 2.0; its second child is the split node.
        (tree
         (make-node (all-rows)
                    0
                    2.0
                    (make-leaf-node (root-low-rows))
                    (build-split-node)))
        (split-node (build-split-node))
        (transformed-node (make-leaf-node-from-split-node split-node)))
   (test-equal
       (make-node (all-rows)
                  0
                  2.0
                  (make-leaf-node (root-low-rows))
                  transformed-node)
     (prune-node-from-tree tree split-node))))
;; select-better-tree: given an unpruned tree, a pruned variant, a pruning
;; set and an accuracy tolerance, the pruned tree is expected back for a
;; tolerance of 0.06 and the unpruned tree for a tolerance of 0.04.
(test-group
 "select-better-tree"
 (let* ((bulk-rows                      ; column 0 below 3.0, all label 0
         (lambda ()
           (list #(1.0 0.10 0)
                 #(1.1 0.11 0)
                 #(1.2 0.12 0)
                 #(1.3 0.13 0)
                 #(1.4 0.14 0)
                 #(1.5 0.15 0)
                 #(1.6 0.16 0)
                 #(1.7 0.17 0)
                 #(1.8 0.18 0)
                 #(1.9 0.19 0)
                 #(2.0 0.20 0)
                 #(2.1 0.21 0)
                 #(2.2 0.22 0)
                 #(2.3 0.23 0)
                 #(2.4 0.24 0)
                 #(2.5 0.25 0)
                 #(2.6 0.26 0)
                 #(2.7 0.27 0))))
        (edge-rows                      ; column 0 is 3.0, labels differ
         (lambda ()
           (list #(3.0 0.10 0)
                 #(3.0 0.20 1))))
        (all-rows
         (lambda ()
           (append (bulk-rows) (edge-rows))))
        ;; Unpruned tree: root split on column 0 at 3.0; the second child
        ;; is the split node (column 1 at 0.2) that pruning removes.
        (tree
         (make-node (all-rows)
                    0
                    3.0
                    (make-leaf-node (bulk-rows))
                    (make-node (edge-rows)
                               1
                               0.2
                               (make-leaf-node (list #(3.0 0.10 0)))
                               (make-leaf-node (list #(3.0 0.20 1))))))
        ;; Pruned tree: same root, but the second child is one leaf that
        ;; holds both edge rows.
        (pruned-tree
         (make-node (all-rows)
                    0
                    3.0
                    (make-leaf-node (bulk-rows))
                    (make-leaf-node (edge-rows))))
        ;; Every pruning row is a tree row shifted by +0.01 in column 0 and
        ;; +0.001 in column 1, so no row crosses a split value.
        (pruning-set
         (list #(1.01 0.101 0)
               #(1.11 0.111 0)
               #(1.21 0.121 0)
               #(1.31 0.131 0)
               #(1.41 0.141 0)
               #(1.51 0.151 0)
               #(1.61 0.161 0)
               #(1.71 0.171 0)
               #(1.81 0.181 0)
               #(1.91 0.191 0)
               #(2.01 0.201 0)
               #(2.11 0.211 0)
               #(2.21 0.221 0)
               #(2.31 0.231 0)
               #(2.41 0.241 0)
               #(2.51 0.251 0)
               #(2.61 0.261 0)
               #(2.71 0.271 0)
               #(3.01 0.101 0)
               #(3.01 0.201 1)))
        (feature-column-indices (list 0 1))
        (label-column-index 2)
        ;; 6% tolerance: intended to forgive one misclassified point out of
        ;; the 20 pruning points (5%).
        (lenient-tolerance 0.06)
        ;; 4% tolerance: intended to be too strict for that 5% loss.
        (strict-tolerance 0.04))

   ;; Within the lenient tolerance the pruned tree should be selected.
   (test-equal
       pruned-tree
     (select-better-tree tree
                         pruned-tree
                         pruning-set
                         feature-column-indices
                         label-column-index
                         lenient-tolerance))

   ;; With the stricter tolerance the unpruned tree should be kept.
   (test-equal
       tree
     (select-better-tree tree
                         pruned-tree
                         pruning-set
                         feature-column-indices
                         label-column-index
                         strict-tolerance))))
;; prune-with-pruning-set: runs the same tree and pruning set through the
;; pruning procedure with two tolerances.  With #:tolerance 0.06 the whole
;; second subtree of the root is expected to collapse into one leaf; with
;; #:tolerance 0.04 the tree is expected to come back unchanged.
;; Each scenario builds its own tree and pruning set via the thunks below.
(test-group
 "prune-with-pruning-set"
 (let* ((label-0-rows                   ; column 0 below 2.0
         (lambda ()
           (list #(1.0 0.10 0)
                 #(1.1 0.11 0)
                 #(1.2 0.12 0)
                 #(1.3 0.13 0)
                 #(1.4 0.14 0)
                 #(1.5 0.15 0)
                 #(1.6 0.16 0)
                 #(1.7 0.17 0)
                 #(1.8 0.18 0)
                 #(1.9 0.19 0))))
        (label-1-rows                   ; column 0 from 2.0 to 2.7
         (lambda ()
           (list #(2.0 0.10 1)
                 #(2.1 0.11 1)
                 #(2.2 0.12 1)
                 #(2.3 0.13 1)
                 #(2.4 0.14 1)
                 #(2.5 0.15 1)
                 #(2.6 0.16 1)
                 #(2.7 0.17 1))))
        (edge-rows                      ; column 0 is 3.0, labels differ
         (lambda ()
           (list #(3.0 0.00 1)
                 #(3.0 0.01 0))))
        (upper-rows
         (lambda ()
           (append (label-1-rows) (edge-rows))))
        (all-rows
         (lambda ()
           (append (label-0-rows) (upper-rows))))
        ;; Root split on column 0 at 2.0.  Its second child splits on
        ;; column 1 at 0.1, and that node's first child splits the two edge
        ;; rows on column 1 at 0.01.
        (build-tree
         (lambda ()
           (make-node (all-rows)
                      0
                      2.0
                      (make-leaf-node (label-0-rows))
                      (make-node (upper-rows)
                                 1
                                 0.1
                                 (make-node (edge-rows)
                                            1
                                            0.01
                                            (make-leaf-node (list #(3.0 0.00 1)))
                                            (make-leaf-node (list #(3.0 0.01 0))))
                                 (make-leaf-node (label-1-rows))))))
        ;; The tree rows shifted by +0.01 in column 0 and +0.001 in column 1.
        (build-pruning-set
         (lambda ()
           (list #(1.01 0.101 0)
                 #(1.11 0.111 0)
                 #(1.21 0.121 0)
                 #(1.31 0.131 0)
                 #(1.41 0.141 0)
                 #(1.51 0.151 0)
                 #(1.61 0.161 0)
                 #(1.71 0.171 0)
                 #(1.81 0.181 0)
                 #(1.91 0.191 0)
                 #(2.01 0.101 1)
                 #(2.11 0.111 1)
                 #(2.21 0.121 1)
                 #(2.31 0.131 1)
                 #(2.41 0.141 1)
                 #(2.51 0.151 1)
                 #(2.61 0.161 1)
                 #(2.71 0.171 1)
                 #(3.01 0.001 1)
                 #(3.01 0.011 0)))))

   ;; Scenario 1: tolerance 0.06 -- the second subtree becomes one leaf.
   (let ((tree (build-tree))
         (pruned-tree
          (make-node (all-rows)
                     0
                     2.0
                     (make-leaf-node (label-0-rows))
                     (make-leaf-node (upper-rows))))
         (pruning-set (build-pruning-set)))
     (test-equal
         pruned-tree
       (prune-with-pruning-set tree
                               pruning-set
                               (list 0 1)
                               2
                               #:tolerance 0.06)))

   ;; Scenario 2: tolerance 0.04 -- the tree is returned unpruned.
   (let ((tree (build-tree))
         (pruning-set (build-pruning-set)))
     (test-equal
         tree
       (prune-with-pruning-set tree
                               pruning-set
                               (list 0 1)
                               2
                               #:tolerance 0.04)))))
;; TODO: there is no test yet for traverse-collect-last-split-nodes.

;; Close the "pruning-test" suite opened by the matching test-begin.
(test-end "pruning-test")
|