123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360 |
- (use-modules
-
- (srfi srfi-64)
-
- (srfi srfi-8)
-
- (decision-tree)
-
- (utils test)
-
- (dataset)
- (metrics)
- (pruning)
- (prediction)
- (data-point)
- (tree))
;; Shared fixture: ten data points with feature columns 0 and 1 and a class
;; label in column 2 (the first five rows are labelled 0, the last five 1).
(define TEST-DATA
  (list #(2.771244718 1.784783929 0)
        #(1.728571309 1.169761413 0)
        #(3.678319846 2.81281357 0)
        #(3.961043357 2.61995032 0)
        #(2.999208922 2.209014212 0)
        #(7.497545867 3.162953546 1)
        #(9.00220326 3.339047188 1)
        #(7.444542326 0.476683375 1)
        #(10.12493903 3.234550982 1)
        #(6.642287351 3.319983761 1)))

;; Exact rational tolerance 1/1000000000, i.e. 10^-9.
;; NOTE(review): no test visible in this file references PRECISION --
;; confirm whether it is still needed.
(define PRECISION (/ 1 (expt 10 9)))

;; Open the top-level suite; the matching test-end closes it at the end of
;; the file.
(test-begin "decision-tree-test")
(test-group
 "split-data"

 ;; split-data-1: threshold 1.5 on column 0; the test expects the rows
 ;; below the threshold first and the remaining rows second.
 (let ([rows (list #(1.0 1.0 1.0 1.0 0)
                   #(1.2 1.0 1.0 1.0 0)
                   #(1.4 1.0 1.0 1.0 0)
                   #(1.6 1.0 1.0 1.0 0)
                   #(1.8 1.0 1.0 1.0 0)
                   #(2.0 1.0 1.0 1.0 0))])
   (test-equal "split-data-1"
     (list (list #(1.0 1.0 1.0 1.0 0)
                 #(1.2 1.0 1.0 1.0 0)
                 #(1.4 1.0 1.0 1.0 0))
           (list #(1.6 1.0 1.0 1.0 0)
                 #(1.8 1.0 1.0 1.0 0)
                 #(2.0 1.0 1.0 1.0 0)))
     (split-data rows 0 1.5)))

 ;; split-data-2: threshold 2.5 on column 1; the rows are interleaved, and
 ;; the test expects input order to be kept on each side.
 (let ([rows (list #(1.0 1.0 1.0 1.0 0)
                   #(1.2 4.0 1.0 1.0 0)
                   #(1.4 1.0 1.0 1.0 0)
                   #(1.6 3.0 1.0 1.0 0)
                   #(1.8 1.0 1.0 1.0 0)
                   #(2.0 2.0 1.0 1.0 0))])
   (test-equal "split-data-2"
     (list (list #(1.0 1.0 1.0 1.0 0)
                 #(1.4 1.0 1.0 1.0 0)
                 #(1.8 1.0 1.0 1.0 0)
                 #(2.0 2.0 1.0 1.0 0))
           (list #(1.2 4.0 1.0 1.0 0)
                 #(1.6 3.0 1.0 1.0 0)))
     (split-data rows 1 2.5))))
(test-group
 "get-best-split"
 ;; Uses the shared TEST-DATA fixture rather than a verbatim inline copy of
 ;; it, so the two cannot drift apart.
 (let ([feature-column-indices (list 0 1)]
       [label-column-index 2])
   ;; Expected: a split on feature 0 at 6.642287351 that separates the five
   ;; label-0 rows from the five label-1 rows.  The trailing 0.0 is
   ;; presumably the split's impurity score -- TODO confirm against make-split.
   (test-equal "get-best-split-1"
     (make-split 0
                 6.642287351
                 (list (list #(2.771244718 1.784783929 0)
                             #(1.728571309 1.169761413 0)
                             #(3.678319846 2.81281357 0)
                             #(3.961043357 2.61995032 0)
                             #(2.999208922 2.209014212 0))
                       (list #(7.497545867 3.162953546 1)
                             #(9.00220326 3.339047188 1)
                             #(7.444542326 0.476683375 1)
                             #(10.12493903 3.234550982 1)
                             #(6.642287351 3.319983761 1)))
                 0.0)
     (get-best-split TEST-DATA
                     feature-column-indices
                     label-column-index))))
(test-group
 "fit"

 ;; fit-1: label-0 rows have column 0 in [1.0, 1.4] and label-1 rows have
 ;; column 0 in [2.0, 2.4], so one split separates the classes and the
 ;; expected tree is a single node whose two children are leaves.
 (let ([test-data (list #(1.0 1.0 0)
                        #(1.2 1.0 0)
                        #(1.1 1.0 0)
                        #(1.4 1.0 0)
                        #(1.2 1.0 0)
                        #(1.2 1.0 0)
                        #(2.3 1.0 1)
                        #(2.0 1.0 1)
                        #(2.3 1.0 1)
                        #(2.0 1.0 1)
                        #(2.3 1.0 1)
                        #(2.0 1.0 1)
                        #(2.4 1.0 1))]
       [feature-column-indices (list 0 1)]
       [label-column-index 2])
   (test-equal "fit-1"
     (let ([best-split (get-best-split test-data
                                       feature-column-indices
                                       label-column-index)])
       (make-node test-data
                  (split-feature-index best-split)
                  (split-value best-split)
                  (make-leaf-node (list #(1.0 1.0 0)
                                        #(1.2 1.0 0)
                                        #(1.1 1.0 0)
                                        #(1.4 1.0 0)
                                        #(1.2 1.0 0)
                                        #(1.2 1.0 0)))
                  (make-leaf-node (list #(2.3 1.0 1)
                                        #(2.0 1.0 1)
                                        #(2.3 1.0 1)
                                        #(2.0 1.0 1)
                                        #(2.3 1.0 1)
                                        #(2.0 1.0 1)
                                        #(2.4 1.0 1)))))
     (fit #:train-data test-data
          #:feature-column-indices feature-column-indices
          #:label-column-index label-column-index
          #:max-depth 2
          #:min-data-points 4
          #:min-data-points-ratio 0.02)))

 ;; fit-2: two label-0 rows with column 1 = 1.1 now sit among the rows whose
 ;; column 0 is >= 2.0.  With max-depth 3 the expected tree therefore splits
 ;; its second child again.
 (let* ([test-data (list #(1.0 1.0 0)
                         #(1.2 1.0 0)
                         #(1.1 1.0 0)
                         #(1.4 1.0 0)
                         #(1.2 1.0 0)
                         #(1.2 1.0 0)
                         #(2.3 1.1 0)
                         #(2.0 1.1 0)
                         #(2.3 1.0 1)
                         #(2.0 1.0 1)
                         #(2.3 1.0 1)
                         #(2.0 1.0 1)
                         #(2.4 1.0 1))]
        [feature-column-indices (list 0 1)]
        [label-column-index 2]
        [best-split (get-best-split test-data
                                    feature-column-indices
                                    label-column-index)])
   (test-equal "fit-2"
     (make-node test-data
                (split-feature-index best-split)
                (split-value best-split)
                (make-leaf-node (list #(1.0 1.0 0)
                                      #(1.2 1.0 0)
                                      #(1.1 1.0 0)
                                      #(1.4 1.0 0)
                                      #(1.2 1.0 0)
                                      #(1.2 1.0 0)))
                (let* ([subset (list #(2.3 1.1 0)
                                     #(2.0 1.1 0)
                                     #(2.3 1.0 1)
                                     #(2.0 1.0 1)
                                     #(2.3 1.0 1)
                                     #(2.0 1.0 1)
                                     #(2.4 1.0 1))]
                       ;; Distinct name so the root's best-split is not shadowed.
                       [subset-split (get-best-split subset
                                                     feature-column-indices
                                                     label-column-index)])
                  (make-node subset
                             (split-feature-index subset-split)
                             (split-value subset-split)
                             (make-leaf-node (list #(2.3 1.0 1)
                                                   #(2.0 1.0 1)
                                                   #(2.3 1.0 1)
                                                   #(2.0 1.0 1)
                                                   #(2.4 1.0 1)))
                             (make-leaf-node (list #(2.3 1.1 0)
                                                   #(2.0 1.1 0))))))
     (fit #:train-data test-data
          #:feature-column-indices feature-column-indices
          #:label-column-index label-column-index
          #:max-depth 3
          #:min-data-points 2
          #:min-data-points-ratio 0.02)))

 ;; fit-3: fitting fit-2's seven-row subset on its own, with the same
 ;; parameters, should give the subtree fit-2 expects as its second child.
 (let* ([test-data (list #(2.3 1.1 0)
                         #(2.0 1.1 0)
                         #(2.3 1.0 1)
                         #(2.0 1.0 1)
                         #(2.3 1.0 1)
                         #(2.0 1.0 1)
                         #(2.4 1.0 1))]
        [feature-column-indices (list 0 1)]
        [label-column-index 2]
        [best-split (get-best-split test-data
                                    feature-column-indices
                                    label-column-index)])
   (test-equal "fit-3"
     (make-node test-data
                (split-feature-index best-split)
                (split-value best-split)
                (make-leaf-node (list #(2.3 1.0 1)
                                      #(2.0 1.0 1)
                                      #(2.3 1.0 1)
                                      #(2.0 1.0 1)
                                      #(2.4 1.0 1)))
                (make-leaf-node (list #(2.3 1.1 0)
                                      #(2.0 1.1 0))))
     (fit #:train-data test-data
          #:feature-column-indices feature-column-indices
          #:label-column-index label-column-index
          #:max-depth 3
          #:min-data-points 2
          #:min-data-points-ratio 0.02))))
(test-group
 "column-uniform?"
 ;; Six rows whose column 2 is 0 in every row.
 (let ([label-zero-rows (list #(1.0 1.0 0)
                              #(1.2 1.0 0)
                              #(1.1 1.0 0)
                              #(1.4 1.0 0)
                              #(1.2 1.0 0)
                              #(1.2 1.0 0))])
   ;; empty-dataset (from the dataset module) is passed as the column here.
   (test-assert "column-uniform? of empty column should be true"
     (column-uniform? empty-dataset =))
   (test-assert "column-uniform? of uniform column should result in true -- 1"
     (column-uniform? (list 1 1 1) =))
   ;; Extracting column 2 from the rows above yields all zeros.
   (test-assert "column-uniform? of uniform column should result in true -- 2"
     (column-uniform? (dataset-get-col label-zero-rows 2) =))
   (test-assert "column-uniform? of non-uniform column should result in false"
     (not (column-uniform? (list 1 2 3) =)))))
(test-group
 "dataset-partition"
 (let ([label-zero? (lambda (data-point)
                      (= (data-point-get-col data-point 2) 0))]
       [rows (list #(2.3 1.1 1)
                   #(2.0 1.1 1)
                   #(2.3 1.0 0)
                   #(2.0 1.0 0)
                   #(2.3 1.0 0)
                   #(2.0 1.0 0)
                   #(2.4 1.0 0))])
   ;; The test expects the rows satisfying the predicate (column 2 equal to
   ;; 0) first and the remaining rows second, each keeping input order.
   ;; dataset-partition returns two values; they are collected into a list.
   (test-equal "dataset-partition should split at given value of specified column"
     (list (list #(2.3 1.0 0)
                 #(2.0 1.0 0)
                 #(2.3 1.0 0)
                 #(2.0 1.0 0)
                 #(2.4 1.0 0))
           (list #(2.3 1.1 1)
                 #(2.0 1.1 1)))
     (call-with-values
         (lambda () (dataset-partition label-zero? rows))
       (lambda (matching not-matching)
         (list matching not-matching))))))
(test-group
 "cross-validation-split"
 ;; The expected folds are pinned for random seed 12345; the test relies on
 ;; the seeded split being reproducible.
 (let ([indices '(0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19)]
       [fold-count 4])
   (test-equal
     '((6 19 13 0 10)
       (2 16 3 17 4)
       (11 8 7 14 1)
       (5 12 18 9 15))
     (cross-validation-split indices
                             fold-count
                             #:random-seed 12345))))
(test-group
 "leave-one-out-k-folds"
 ;; Every fold here is four identical rows #(n n).  fold-of builds a fresh
 ;; list of fresh vectors on each call, so no two arguments share structure.
 (let ([fold-of (lambda (n)
                  (list (vector n n)
                        (vector n n)
                        (vector n n)
                        (vector n n)))])
   ;; Holding out the third fold should leave folds 1, 2 and 4, in order.
   (test-equal
     (list (fold-of 1) (fold-of 2) (fold-of 4))
     (leave-one-out-k-folds (list (fold-of 1)
                                  (fold-of 2)
                                  (fold-of 3)
                                  (fold-of 4))
                            (fold-of 3)))))
(test-group
 "evaluate-algorithm"
 ;; Four folds are requested over TEST-DATA; the test expects a result list
 ;; of length 4 (presumably one score per fold -- confirm against
 ;; evaluate-algorithm).
 (test-equal "evaluate-algorithm-1"
   4
   (length
    (evaluate-algorithm
     #:dataset TEST-DATA
     #:n-folds 4
     #:feature-column-indices (list 0 1)
     #:label-column-index 2
     #:max-depth 3
     #:min-data-points 4
     #:min-data-points-ratio 0.02
     #:min-impurity-split (expt 10 -7)
     #:stop-at-no-impurity-improvement #t
     #:random-seed 0))))

;; Close the suite opened by (test-begin "decision-tree-test").
(test-end "decision-tree-test")
|