module SvmToolkit
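  # ForkJoinPool and RecursiveTask are the java.util.concurrent classes;
  # ContourPlot, PlotPanel, PlotRun, PlotDatum and DiamondSymbol come from
  # the Java plotting library used by the toolkit. All are assumed to be
  # imported (via java_import) elsewhere in the gem.
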
  # Extends the Java SVM class.
  #
  # Available methods include:
  #
  # Svm.svm_train(problem, param)
  #
  # problem:: instance of Problem
  # param:: instance of Parameter
  #
  # Returns an instance of Model.
  #
  # Svm.svm_cross_validation(problem, param, nr_fold, target)
  #
  # problem:: instance of Problem
  # param:: instance of Parameter
  # nr_fold:: number of folds
  # target:: resulting predictions in an array
  #
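  # A minimal usage sketch of training (problem is assumed to be an
  # already-built instance of Problem):
  #
  #   param = Parameter.new(:svm_type => Parameter::C_SVC,
  #                         :kernel_type => Parameter::RBF,
  #                         :cost => 1.0,
  #                         :gamma => 0.5)
  #   model = Svm.svm_train(problem, param)
  #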
  class Svm

    # Perform a cross-validation search across the given cost/gamma values,
    # using an RBF kernel, returning the best-performing model and
    # optionally displaying a contour map of performance.
    #
    # training_set:: instance of Problem, used for training
    # cross_valn_set:: instance of Problem, used for evaluating models
    # costs:: array of cost values to search across
    # gammas:: array of gamma values to search across
    # params:: optional parameters include:
    #          * :evaluator => Evaluator::OverallAccuracy, the name of the
    #            class to use for computing performance
    #          * :show_plot => false, whether to display contour plot
    #
    # Returns an instance of Model, the best performing model.
    #
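    # A minimal usage sketch (training_set and cross_valn_set are assumed
    # to be existing Problem instances; the cost/gamma grids here are
    # purely illustrative):
    #
    #   best_model = Svm.cross_validation_search(
    #     training_set, cross_valn_set,
    #     [0.125, 0.5, 2.0, 8.0],   # costs
    #     [0.01, 0.1, 1.0],         # gammas
    #     :show_plot => true)
    #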
    def Svm.cross_validation_search(training_set, cross_valn_set,
                                    costs = [-2, -1, 0, 1, 2, 3].collect { |i| 2**i },
                                    gammas = [-2, -1, 0, 1, 2, 3].collect { |i| 2**i },
                                    params = {})
      evaluator = params.fetch :evaluator, Evaluator::OverallAccuracy
      show_plot = params.fetch :show_plot, false
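
      # Run the search as a fork/join task, so each cost/gamma pair can be
      # trained and evaluated in parallel.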
      fjp = ForkJoinPool.new
      task = CrossValidationSearch.new gammas, costs, training_set, cross_valn_set, evaluator
      results, best_model = fjp.invoke task

      if show_plot
        ContourDisplay.new(costs.collect { |n| Math.log2(n) },
                           gammas.collect { |n| Math.log2(n) },
                           results)
      end

      return best_model
    end

    private

    # Set up the cross-validation search across the cost/gamma pairs.
    class CrossValidationSearch < RecursiveTask
      def initialize gammas, costs, training_set, cross_valn_set, evaluator
        super()
        @gammas = gammas
        @costs = costs
        @training_set = training_set
        @cross_valn_set = cross_valn_set
        @evaluator = evaluator
      end

      # Perform the actual computation: fork one training task per
      # gamma/cost pair, then join them all, returning [results, best_model].
      def compute
        tasks = []
        # create one task per gamma/cost pair
        @gammas.each do |gamma|
          @costs.each do |cost|
            tasks << SvmTrainer.new(@training_set, Parameter.new(
              :svm_type => Parameter::C_SVC,
              :kernel_type => Parameter::RBF,
              :cost => cost,
              :gamma => gamma
            ), @cross_valn_set, @evaluator)
          end
        end
        # set off all the tasks
        tasks.each do |task|
          task.fork
        end
        # collect the results
        results = []
        best_model = nil
        lowest_error = nil
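
        # join the tasks in the same order they were created, keeping the
        # best result seen so far (the evaluation result's better_than?
        # method decides which of two results wins)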
        @gammas.each do |gamma|
          results_row = []
          @costs.each do |cost|
            task = tasks.shift
            model, result = task.join
            if result.better_than? lowest_error
              best_model = model
              lowest_error = result
            end
            puts "Result for cost = #{cost} gamma = #{gamma} is #{result.value}"
            results_row << result.value
          end
          results << results_row
        end

        return results, best_model
      end
    end

    # Represent a single training task for an SVM RBF model
    class SvmTrainer < RecursiveTask
      def initialize training_set, parameters, cross_valn_set, evaluator
        super()
        @training_set = training_set
        @parameters = parameters
        @cross_valn_set = cross_valn_set
        @evaluator = evaluator
      end
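
      # train a model with the given parameters and evaluate it on the
      # cross-validation set, returning [model, result]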
      def compute
        model = Svm.svm_train @training_set, @parameters
        result = model.evaluate_dataset @cross_valn_set, :evaluator => @evaluator
        return model, result
      end
    end
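
    # A Swing frame displaying a contour plot of cross-validation
    # performance across the cost/gamma grid (both axes on a log scale).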
    class ContourDisplay < javax.swing.JFrame
      def initialize(xs, ys, zs)
        super("Cross-Validation Performance")
        self.setSize(500, 400)
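
        # build matching 2D grids of x (cost) and y (gamma) coordinates,
        # as Java double arrays, one point per sampled cost/gamma pair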
        cxs = Java::double[][ys.size].new
        cys = Java::double[][ys.size].new
        ys.size.times do |i|
          cxs[i] = Java::double[xs.size].new
          cys[i] = Java::double[xs.size].new
          xs.size.times do |j|
            cxs[i][j] = xs[j]
            cys[i][j] = ys[i]
          end
        end
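
        # copy the performance results into a Java 2D array of heights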
        czs = Java::double[][ys.size].new
        ys.size.times do |i|
          czs[i] = Java::double[xs.size].new
          xs.size.times do |j|
            czs[i][j] = zs[i][j]
          end
        end
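
        # construct the contour plot over the cost/gamma grid, with
        # contours coloured from green through to red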
        plot = ContourPlot.new(
          cxs,
          cys,
          czs,
          10,
          false,
          "",
          "Cost (log-scale)",
          "Gamma (log-scale)",
          nil,
          nil
        )
        plot.colorizeContours(java.awt::Color.green, java.awt::Color.red)
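
        # mark each sampled cost/gamma point with a small blue diamond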
        symbol = DiamondSymbol.new
        symbol.border_color = java.awt::Color.blue
        symbol.fill_color = java.awt::Color.blue
        symbol.size = 4
        run = PlotRun.new
        ys.size.times do |i|
          xs.size.times do |j|
            run.add(PlotDatum.new(cxs[i][j], cys[i][j], false, symbol))
          end
        end
        plot.runs << run
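
        # place the plot on a white panel inside this frame and show it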
        panel = PlotPanel.new(plot)
        panel.background = java.awt::Color.white
        add panel

        self.setDefaultCloseOperation(javax.swing.WindowConstants::DISPOSE_ON_CLOSE)
        self.visible = true
      end
    end
  end
end