run.scm 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. (use-modules
  2. (utils csv)
  3. (decision-tree)
  4. (dataset)
  5. (tree)
  6. (utils string)
  7. (utils display)
  8. (prediction)
  9. ;; ice-9 format for format procedure
  10. (ice-9 format))
  ;; Path of the CSV file that holds the bank note authentication data,
  ;; given relative to the directory the script is started from.
  (define FILE-PATH "data_banknote_authentication.csv")
  ;; One entry per CSV column, in column order.  Every entry is a
  ;; one-element list holding the procedure that turns the raw string read
  ;; from the CSV into the value stored in the data set.
  ;;
  ;; Columns 0-3 are parsed with plain `string->number'.  Column 4 is
  ;; trimmed on both sides before parsing -- presumably because the last
  ;; field of a line can carry trailing whitespace or a line ending
  ;; (TODO confirm against the CSV reader).
  (define COLUMN-CONVERTERS
    (let ((parse-plain string->number)
          (parse-trimmed
           (lambda (raw)
             (string->number (string-trim-both raw)))))
      (list (list parse-plain)
            (list parse-plain)
            (list parse-plain)
            (list parse-plain)
            (list parse-trimmed))))
  ;; The full bank note data set: the result of `all-rows' for the CSV file
  ;; named by FILE-PATH, with COLUMN-CONVERTERS passed as the per-column
  ;; converters.  The path comes from the FILE-PATH constant instead of a
  ;; second copy of the literal, so the file name is defined in one place.
  ;; NOTE(review): `all-rows' is not defined in this file (presumably it
  ;; comes from (utils csv)); its exact return shape is not visible here.
  (define banking-dataset
    (all-rows FILE-PATH #:converters COLUMN-CONVERTERS))
  ;; Development leftover: a tiny hand-written data set for interactive
  ;; experiments or quick runs that should not wait for the whole CSV.
  ;; Each row is a vector of two numeric values followed by a 0/1 value in
  ;; the last position; the rows ending in 0 come first, then those ending
  ;; in 1.
  (define dev-dataset
    (let ((rows-labelled-0
           (list #(2.771244718 1.784783929 0)
                 #(1.728571309 1.169761413 0)
                 #(3.678319846 2.81281357 0)
                 #(3.961043357 2.61995032 0)
                 #(2.999208922 2.209014212 0)))
          (rows-labelled-1
           (list #(7.497545867 3.162953546 1)
                 #(9.00220326 3.339047188 1)
                 #(7.444542326 0.476683375 1)
                 #(10.12493903 3.234550982 1)
                 #(6.642287351 3.319983761 1))))
      (append rows-labelled-0 rows-labelled-1)))
  ;; Print a textual rendering of a learned decision tree.
  ;;
  ;; TREE is the root node and LABEL-COLUMN-INDEX is handed on to
  ;; `dataset-majority-prediction' for every leaf.
  ;;
  ;; Output format, one node per line, each line prefixed with the result
  ;; of (n-times-string " " depth), where the root has depth 0 and children
  ;; are one level deeper than their parent:
  ;;   leaf node:   [<majority prediction of the leaf's data>]
  ;;   split node:  [feature:<split feature index> < <split value>]
  ;; A split node is followed by its left subtree, then its right subtree.
  ;; The majority prediction is passed through `number->string', so it has
  ;; to be a number.
  (define-public (print-tree tree label-column-index)
    ;; Prefix for a node at DEPTH.
    (define (indentation depth)
      (n-times-string " " depth))

    ;; Build the string for NODE and everything below it.
    (define (render node depth)
      (if (leaf-node? node)
          (string-append
           (indentation depth)
           "["
           (number->string
            (dataset-majority-prediction (node-data node)
                                         label-column-index))
           "]\n")
          (let ((child-depth (+ depth 1)))
            (string-append
             (indentation depth)
             "[feature:"
             (number->string (node-split-feature-index node))
             " < "
             (number->string (node-split-value node))
             "]\n"
             (render (node-left node) child-depth)
             (render (node-right node) child-depth)))))

    (displayln (render tree 0)))
  ;; Script entry point: run `evaluate-algorithm' on the shuffled bank note
  ;; data (10 folds, features in columns 0-3, label in column 4), format
  ;; every returned number with three decimals and display the resulting
  ;; list of strings followed by a newline.
  (let* ((seed 12345)
         (shuffled-data (shuffle-dataset banking-dataset #:seed seed))
         (fold-results
          (evaluate-algorithm #:dataset shuffled-data
                              #:n-folds 10
                              #:feature-column-indices '(0 1 2 3)
                              #:label-column-index 4
                              #:max-depth 6
                              #:min-data-points 12
                              #:min-data-points-ratio 0.02
                              #:min-impurity-split (expt 10 -7)
                              #:stop-at-no-impurity-improvement #t
                              #:random-seed seed))
         (formatted-results
          (map (lambda (result) (format #f "~,3f\n" result))
               fold-results))
         (report (simple-format #f "~a\n" formatted-results)))
    (display report))
  81. ;; (define tree
  82. ;; (fit #:train-data (shuffle-dataset banking-dataset #:seed 12345)
  83. ;; #:feature-column-indices (list 0 1 2 3)
  84. ;; #:label-column-index 4
  85. ;; #:max-depth 5
  86. ;; #:min-data-points 12
  87. ;; #:min-data-points-ratio 0.02
  88. ;; #:min-impurity-split (expt 10 -7)
  89. ;; #:stop-at-no-impurity-improvement #t))
  90. ;; (print-tree tree 4)