;; tree.scm -- record types and helper procedures for decision-tree nodes.

  ;; Module declaration for (tree), followed by the modules it depends on.
  (define-module (tree)
    ;; The procedures exported here are generated by the record-type macro
    ;; (define-immutable-record-type), so they cannot be defined with
    ;; define-public and have to be listed explicitly.
    #:export (Split
              make-split
              split?
              split-feature-index
              split-value
              set-split-value
              split-subsets
              set-split-subsets
              split-cost
              set-split-cost
              Node
              make-node
              node?
              node-data
              set-node-data
              node-split-feature-index
              set-node-split-feature-index
              node-split-value
              set-node-split-value
              node-left
              set-node-left
              node-right
              set-node-right))

  (use-modules
   ;; SRFI-9 for standardized records
   (srfi srfi-9)
   ;; GNU extension: immutable records with functional setters
   ;; (not part of SRFI-9 itself)
   (srfi srfi-9 gnu)
   ;; project-local modules
   (utils display-utils)
   (utils string-utils)
   (dataset))
  ;; ===============
  ;; DATA STRUCTURES
  ;; ===============

  ;; A Split record is a plain container for all information about one split.
  ;; It is not a tree node and does not hold node data.
  ;;
  ;; Fields:
  ;;   index   -- read via split-feature-index; deliberately has no setter.
  ;;              If the feature used for the split changed, the whole record
  ;;              would be invalid, so a new Split should be created instead.
  ;;   value   -- split-value / set-split-value
  ;;   subsets -- split-subsets / set-split-subsets
  ;;   cost    -- split-cost / set-split-cost
  ;;
  ;; The record is immutable: each set-* procedure is a functional setter that
  ;; returns a new record and leaves its argument untouched.
  (define-immutable-record-type Split
    ;; constructor
    (make-split index value subsets cost)
    ;; predicate
    split?
    ;; accessors and functional setters
    (index split-feature-index)
    (value split-value set-split-value)
    (subsets split-subsets set-split-subsets)
    (cost split-cost set-split-cost))
  ;; A Node record is one node of the decision tree. It stores the node's
  ;; data, the index of the feature at which it splits, the value used for
  ;; that split, and references to its two children.
  ;;
  ;; Fields:
  ;;   data                -- node-data / set-node-data
  ;;   split-feature-index -- node-split-feature-index /
  ;;                          set-node-split-feature-index
  ;;   split-value         -- node-split-value / set-node-split-value
  ;;   left                -- node-left / set-node-left
  ;;   right               -- node-right / set-node-right
  ;;
  ;; The record is immutable: each set-* procedure is a functional setter that
  ;; returns a new record and leaves its argument untouched.
  (define-immutable-record-type Node
    ;; constructor
    (make-node data split-feature-index split-value left right)
    ;; predicate
    node?
    ;; accessors and functional setters
    (data node-data set-node-data)
    (split-feature-index node-split-feature-index set-node-split-feature-index)
    (split-value node-split-value set-node-split-value)
    (left node-left set-node-left)
    (right node-right set-node-right))
  ;; Construct a leaf Node that holds DATA.
  ;; A leaf carries no split information: both split-feature-index and
  ;; split-value are the symbol 'none, and both child slots are set to
  ;; empty-dataset (presumably provided by the (dataset) module -- verify).
  (define-public (make-leaf-node data)
    (make-node data
               'none          ; split-feature-index
               'none          ; split-value
               empty-dataset  ; left
               empty-dataset)) ; right
  ;; Turn SPLIT-NODE into a leaf: returns a new leaf node holding the same data
  ;; as SPLIT-NODE. The split information and children of SPLIT-NODE are not
  ;; carried over.
  (define-public (make-leaf-node-from-split-node split-node)
    (make-leaf-node (node-data split-node)))
  ;; Return true when NODE is a leaf, i.e. when both its left and its right
  ;; child slot satisfy dataset-empty?.
  (define-public (leaf-node? node)
    (and (dataset-empty? (node-left node))
         (dataset-empty? (node-right node))))
  ;; Return true when NODE is a split node directly above the leaves: NODE
  ;; itself is not a leaf, but both of its children are.
  ;; NOTE(review): for a non-leaf NODE both children are passed to leaf-node?,
  ;; so both are assumed to be Node records -- confirm that a node with exactly
  ;; one empty child slot cannot occur.
  (define-public (last-split-node? node)
    (and (not (leaf-node? node))
         (leaf-node? (node-left node))
         (leaf-node? (node-right node))))
  ;; Display a string representation of a learned decision tree.
  ;;
  ;; TREE is the root Node. LABEL-COLUMN-INDEX is handed to
  ;; dataset-majority-prediction together with a leaf's data in order to get
  ;; the value printed for that leaf; that value is passed to number->string,
  ;; so it has to be a number.
  ;;
  ;; Output, one line per node, indented by one space per tree level:
  ;;   leaf node:   [<majority prediction>]
  ;;   split node:  [feature:<feature index> < <split value>]
  ;; A split node's line is followed by its left subtree and then its right
  ;; subtree. The finished string is written with displayln.
  (define-public (print-tree tree label-column-index)
    ;; Render SUBTREE, located DEPTH levels below the root, as a string.
    (define (tree->string subtree depth)
      (let ([indent (n-times-string " " depth)])
        (if (leaf-node? subtree)
            (string-append indent
                           "["
                           (number->string
                            (dataset-majority-prediction (node-data subtree)
                                                         label-column-index))
                           "]\n")
            (string-append indent
                           "[feature:"
                           (number->string (node-split-feature-index subtree))
                           " < "
                           (number->string (node-split-value subtree))
                           "]\n"
                           (tree->string (node-left subtree) (+ depth 1))
                           (tree->string (node-right subtree) (+ depth 1))))))
    (displayln (tree->string tree 0)))