svm.rb

module SvmToolkit

  # Extends the Java SVM class
  #
  # Available methods include:
  #
  # Svm.svm_train(problem, param)
  #
  # problem:: instance of Problem
  # param::   instance of Parameter
  #
  # Returns an instance of Model
  #
  # Svm.svm_cross_validation(problem, param, nr_fold, target)
  #
  # problem:: instance of Problem
  # param::   instance of Parameter
  # nr_fold:: number of folds
  # target::  resulting predictions in an array
  #
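  # A minimal usage sketch (illustrative only; +problem+ is assumed to be a
  # Problem built elsewhere, e.g. loaded from a dataset):
  #
  #   param = Parameter.new(
  #     :svm_type    => Parameter::C_SVC,
  #     :kernel_type => Parameter::RBF,
  #     :cost        => 1.0,
  #     :gamma       => 0.5
  #   )
  #   model = Svm.svm_train(problem, param)
  #
  #   # 5-fold cross validation, filling +predictions+ with one value per instance
  #   predictions = Java::double[problem.l].new   # +l+ (instance count) assumed from libsvm
  #   Svm.svm_cross_validation(problem, param, 5, predictions)
  #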
  class Svm

    # Perform cross validation search on given gamma/cost values,
    # using an RBF kernel,
    # returning the best performing model and optionally displaying
    # a contour map of performance.
    #
    # training_set::   instance of Problem, used for training
    # cross_valn_set:: instance of Problem, used for evaluating models
    # costs::          array of cost values to search across
    # gammas::         array of gamma values to search across
    # params::         Optional parameters include:
    #                  * :evaluator => Evaluator::OverallAccuracy, the name of the class
    #                    to use for computing performance
    #                  * :show_plot => false, whether to display contour plot
    #
    # Returns an instance of Model, the best performing model.
    #
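    # A sketch of a typical call (illustrative only; +train+ and +valn+ are
    # assumed to be Problem instances prepared elsewhere):
    #
    #   best_model = Svm.cross_validation_search(
    #     train, valn,
    #     [0.125, 0.5, 2.0, 8.0],   # costs to try
    #     [0.125, 0.5, 2.0, 8.0],   # gammas to try
    #     :evaluator => Evaluator::OverallAccuracy,
    #     :show_plot => false
    #   )
    #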
    def Svm.cross_validation_search(training_set, cross_valn_set,
                                    costs  = [-2,-1,0,1,2,3].collect {|i| 2**i},
                                    gammas = [-2,-1,0,1,2,3].collect {|i| 2**i},
                                    params = {})
      evaluator = params.fetch :evaluator, Evaluator::OverallAccuracy
      show_plot = params.fetch :show_plot, false

      # train one model per cost/gamma pair in parallel on a fork/join pool
      fjp  = ForkJoinPool.new
      task = CrossValidationSearch.new gammas, costs, training_set, cross_valn_set, evaluator
      results, best_model = fjp.invoke task

      if show_plot
        ContourDisplay.new(costs.collect  {|n| Math.log2(n)},
                           gammas.collect {|n| Math.log2(n)},
                           results)
      end

      return best_model
    end

    private

    # Set up the cross validation search across all cost/gamma pairs
    class CrossValidationSearch < RecursiveTask
      def initialize gammas, costs, training_set, cross_valn_set, evaluator
        super()
        @gammas = gammas
        @costs = costs
        @training_set = training_set
        @cross_valn_set = cross_valn_set
        @evaluator = evaluator
      end

      # perform actual computation, return results/best_model
      def compute
        tasks = []
        # create one task per gamma/cost pair
        @gammas.each do |gamma|
          @costs.each do |cost|
            tasks << SvmTrainer.new(@training_set, Parameter.new(
              :svm_type => Parameter::C_SVC,
              :kernel_type => Parameter::RBF,
              :cost => cost,
              :gamma => gamma
            ), @cross_valn_set, @evaluator)
          end
        end
        # set off all the tasks
        tasks.each do |task|
          task.fork
        end
        # collect the results, joining tasks in the same gamma/cost order
        # in which they were created
        results = []
        best_model = nil
        lowest_error = nil
        @gammas.each do |gamma|
          results_row = []
          @costs.each do |cost|
            task = tasks.shift
            model, result = task.join
            if result.better_than? lowest_error
              best_model = model
              lowest_error = result
            end
            puts "Result for cost = #{cost} gamma = #{gamma} is #{result.value}"
            results_row << result.value
          end
          results << results_row
        end

        return results, best_model
      end
    end

    # Represent a single training task for an SVM RBF model
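    #
    # Instances are forked onto the ForkJoinPool by CrossValidationSearch.
    # For illustration, a single trainer can also be run synchronously by
    # calling +compute+ directly (+train+/+valn+ are assumed Problem
    # instances, and the parameter values here are arbitrary):
    #
    #   trainer = SvmTrainer.new(train,
    #                            Parameter.new(:svm_type    => Parameter::C_SVC,
    #                                          :kernel_type => Parameter::RBF,
    #                                          :cost        => 1.0,
    #                                          :gamma       => 0.5),
    #                            valn, Evaluator::OverallAccuracy)
    #   model, result = trainer.compute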
    class SvmTrainer < RecursiveTask
      def initialize training_set, parameters, cross_valn_set, evaluator
        super()
        @training_set = training_set
        @parameters = parameters
        @cross_valn_set = cross_valn_set
        @evaluator = evaluator
      end

      # train a model on the training set, then score it on the validation set
      def compute
        model = Svm.svm_train @training_set, @parameters
        result = model.evaluate_dataset @cross_valn_set, :evaluator => @evaluator

        return model, result
      end
    end

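    # Swing frame displaying a contour plot of cross-validation performance
    # over the cost/gamma grid: +xs+ and +ys+ are the (log-scaled) cost and
    # gamma values, +zs+ the matching grid of evaluator scores. A rough
    # sketch of a standalone call (values illustrative only):
    #
    #   ContourDisplay.new([-1.0, 0.0, 1.0],
    #                      [-1.0, 0.0],
    #                      [[91.0, 92.5, 90.0],
    #                       [93.1, 94.0, 92.2]])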
    class ContourDisplay < javax.swing.JFrame
      def initialize(xs, ys, zs)
        super("Cross-Validation Performance")
        self.setSize(500, 400)

        # build the 2-D Java arrays of x/y coordinates and z values
        # required by ContourPlot
        cxs = Java::double[ys.size, xs.size].new
        cys = Java::double[ys.size, xs.size].new
        ys.size.times do |i|
          xs.size.times do |j|
            cxs[i][j] = xs[j]
            cys[i][j] = ys[i]
          end
        end

        czs = Java::double[ys.size, xs.size].new
        ys.size.times do |i|
          xs.size.times do |j|
            czs[i][j] = zs[i][j]
          end
        end
        plot = ContourPlot.new(
          cxs,
          cys,
          czs,
          10,
          false,
          "",
          "Cost (log-scale)",
          "Gamma (log-scale)",
          nil,
          nil
        )
        plot.colorizeContours(java.awt::Color.green, java.awt::Color.red)

        # mark each sampled cost/gamma point with a blue diamond
        symbol = DiamondSymbol.new
        symbol.border_color = java.awt::Color.blue
        symbol.fill_color = java.awt::Color.blue
        symbol.size = 4
        run = PlotRun.new
        ys.size.times do |i|
          xs.size.times do |j|
            run.add(PlotDatum.new(cxs[i][j], cys[i][j], false, symbol))
          end
        end
        plot.runs << run

        panel = PlotPanel.new(plot)
        panel.background = java.awt::Color.white
        add panel
        self.setDefaultCloseOperation(javax.swing.WindowConstants::DISPOSE_ON_CLOSE)
        self.visible = true
      end
    end
  end
end