model.rb 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  module SvmToolkit
    # Extends the Java Model class with some additional methods.
    #
    class Model
      # Evaluate model on given data set (an instance of Problem).
      #
      # Every instance 0...data.l is classified with #predict, and the pair
      # (true label, prediction) is handed to the evaluator's +add_result+.
      #
      # Optional parameters include:
      # * :evaluator => Evaluator::OverallAccuracy, the class used to compute
      #   performance; it is instantiated with no arguments
      # * :print_results => false, whether to print the result for each instance
      #
      # Returns the evaluator instance holding the accumulated results
      # (not a count of errors).
      def evaluate_dataset(data, params = {})
        # Block form: the default evaluator class is only looked up when the
        # caller did not supply one.
        evaluator = params.fetch(:evaluator) { Evaluator::OverallAccuracy }
        print_results = params.fetch(:print_results, false)
        performance = evaluator.new
        data.l.times do |i|
          pred = predict(data, i)
          performance.add_result(data.y[i], pred)
          puts "Instance #{i}, Prediction: #{pred}, True label: #{data.y[i]}" if print_results
        end
        performance
      end

      # Return the value of w squared for the hyperplane, read from +w_2+.
      # A single value is returned on its own; otherwise all values are
      # returned as an Array.
      def w_squared
        values = w_2
        values.size == 1 ? values[0] : values.to_a
      end

      # Return an array of indices of the training instances used as
      # support vectors. Returns an empty array when +sv_indices+ is nil.
      def support_vector_indices
        indices = sv_indices
        return [] if indices.nil?

        # Copy element by element so the caller always gets a fresh Ruby
        # Array, whatever type backs +sv_indices+.
        Array.new(indices.size) { |i| indices[i] }
      end

      # Return the SVM problem type for this model.
      def svm_type
        param.svm_type
      end

      # Return the kernel type for this model.
      def kernel_type
        param.kernel_type
      end

      # Return the value of the degree parameter.
      def degree
        param.degree
      end

      # Return the value of the gamma parameter.
      def gamma
        param.gamma
      end

      # Return the value of the cost parameter.
      def cost
        param.cost
      end

      # Return the number of classes handled by this model.
      def number_classes
        nr_class
      end

      # Save model to given filename.
      # Raises IOError if the underlying Java call throws an IOException;
      # the Java exception is kept as the IOError's +cause+.
      def save(filename)
        Svm.svm_save_model(filename, self)
      rescue java.io.IOException
        raise IOError, "Error in saving SVM model to file"
      end

      # Load model from given filename.
      # Raises IOError on any error: when the underlying Java call throws an
      # IOException (kept as +cause+), or when no model is returned.
      def self.load(filename)
        model =
          begin
            Svm.svm_load_model(filename)
          rescue java.io.IOException
            raise IOError, "Error in loading SVM model from file"
          end
        # LIBSVM's Java svm_load_model returns null rather than throwing when it
        # cannot parse the model header; report that as an error as documented.
        raise IOError, "Error in loading SVM model from file" if model.nil?

        model
      end

      # Predict the class of given instance number in given problem.
      def predict(problem, instance_number)
        Svm.svm_predict(self, problem.x[instance_number])
      end

      # Return the values of given instance number of given problem against
      # each decision boundary (the distance of the instance from each boundary).
      #
      # One value is computed per pair of classes, i.e.
      # number_classes * (number_classes - 1) / 2 values. With exactly one
      # boundary the value is returned on its own; otherwise an Array is returned.
      def predict_values(problem, instance_number)
        n = number_classes
        dist = Array.new(n * (n - 1) / 2, 0).to_java(:double)
        Svm.svm_predict_values(self, problem.x[instance_number], dist)
        dist.size == 1 ? dist[0] : dist.to_a
      end
    end
  end