list-of-list-solution.rkt 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. #lang racket
  2. #|
  3. Attribution:
  4. This implementation of decision trees in Racket was written by Daniel Prager and
  5. was originally shared at:
  6. https://groups.google.com/forum/#!topic/racket-users/cPuTr8lrXCs
  7. With permission it was added to the project.
  8. |#
  9. (define (string->data s [sep " "])
  10. (for/list ([line (in-list (string-split s #rx"\r?\n"))])
  11. (map string->number (string-split line sep))))
  12. (define banknote-data
  13. (string->data (file->string "data_banknote_authentication.csv") ","))
  14. (define test-data
  15. (string->data
  16. "2.771244718 1.784783929 0
  17. 1.728571309 1.169761413 0
  18. 3.678319846 2.81281357 0
  19. 3.961043357 2.61995032 0
  20. 2.999208922 2.209014212 0
  21. 7.497545867 3.162953546 1
  22. 9.00220326 3.339047188 1
  23. 7.444542326 0.476683375 1
  24. 10.12493903 3.234550982 1
  25. 6.642287351 3.319983761 1"))
  26. (define (make-split rows index value)
  27. (define-values (left right)
  28. (for/fold ([left null] [right null])
  29. ([row (in-list rows)])
  30. (if (< (list-ref row index) value)
  31. (values (cons row left) right)
  32. (values left (cons row right)))))
  33. (list left right))
  34. (define (gini-coefficient splits)
  35. (for/sum ([split (in-list splits)])
  36. (define n (* 1.0 (length split)))
  37. (define (g v) (* (/ v n) (- 1.0 (/ v n))))
  38. (if (zero? n)
  39. 0
  40. (let ([m (for/sum ([row (in-list split)] #:when (zero? (last row)))
  41. 1)])
  42. (+ (g m) (g (- n m)))))))
  43. (define (get-split rows)
  44. (define-values (best index value _)
  45. (for*/fold ([best null] [i -1] [v -1] [score 999])
  46. ([index (in-range (sub1 (length (first rows))))]
  47. [row (in-list rows)])
  48. (let* ([value (list-ref row index)]
  49. [s (make-split rows index value)]
  50. [gini (gini-coefficient s)])
  51. (if (< gini score)
  52. (values s index value gini)
  53. (values best i v score)))))
  54. (list index value best))
  55. (define (to-terminal group)
  56. (define zeros (count (λ (row) (zero? (last row))) group))
  57. (if (> zeros (- (length group) zeros)) 0 1))
  58. (define (split node max-depth min-size depth)
  59. (match-define (list index value (list left right)) node)
  60. (define (split-if-small branch)
  61. (if (<= (length branch) min-size)
  62. (to-terminal branch)
  63. (split (get-split branch) max-depth min-size (add1 depth))))
  64. (cond [(null? left) (to-terminal right)]
  65. [(null? right) (to-terminal left)]
  66. [(>= depth max-depth) (list index value
  67. (to-terminal left) (to-terminal right))]
  68. [else (list index value
  69. (split-if-small left) (split-if-small right))]))
  70. (define (build-tree rows max-depth min-size)
  71. (split (get-split rows) max-depth min-size 1))
  72. (define (predict node row)
  73. (if (list? node)
  74. (match-let ([(list index value left right) node])
  75. (predict (if (< (list-ref row index) value)
  76. left
  77. right)
  78. row))
  79. node))
  80. (define (check-model model validation-set)
  81. (/ (count (λ (row) (= (predict model row) (last row)))
  82. validation-set)
  83. (length validation-set)
  84. 1.0))
  85. ;(define test-model (build-tree test-data 1 1))
  86. ;(for/list ([row (in-list test-data)])
  87. ; (list row (predict test-model row)))
  88. (define data (shuffle banknote-data))
  89. (define model (time (build-tree (take data 274) 5 10)))
  90. model
  91. (check-model model (drop data 274))
  92. (random-seed 12345)
  93. (define data2 (shuffle banknote-data))
  94. (time
  95. (void
  96. (build-tree (take data2 274) 5 10)))
  97. (time
  98. (for ([i (in-range 20)])
  99. (build-tree (take data2 274) 5 10)))