problem.rb 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. module SvmToolkit
  2. # Extends the Java Problem class with some additional features.
  3. #
  4. class Problem
  5. # Support constructing a problem from arrays of double values.
  6. #
  7. # [+instances+]:: an array of instances, each instance being an array of doubles.
  8. # [+labels+]:: an array of doubles, forming the labels for each instance.
  9. #
  10. # An ArgumentError exception is raised if all the following conditions are not met:
  11. # * the number of instances should equal the number of labels,
  12. # * there must be at least one instance, and
  13. # * every instance must have the same number of features.
  14. #
  15. def Problem.from_array(instances, labels)
  16. unless instances.size == labels.size
  17. raise ArgumentError.new "Number of instances must equal number of labels"
  18. end
  19. unless instances.size > 0
  20. raise ArgumentError.new "There must be at least one instance."
  21. end
  22. unless instances.collect {|i| i.size}.min == instances.collect {|i| i.size}.max
  23. raise ArgumentError.new "All instances must have the same size"
  24. end
  25. problem = Problem.new
  26. problem.l = labels.size
  27. # -- add in the training data
  28. problem.x = Node[instances.size, instances[0].size].new
  29. instances.each_with_index do |instance, i|
  30. instance.each_with_index do |v, j|
  31. problem.x[i][j] = Node.new(j, v)
  32. end
  33. end
  34. # -- add in the labels
  35. problem.y = Java::double[labels.size].new
  36. labels.each_with_index do |v, i|
  37. problem.y[i] = v
  38. end
  39. return problem
  40. end
  41. # To select SvmLight input file format
  42. SvmLight = 0
  43. # To select Csv input file format
  44. Csv = 1
  45. # To select ARFF input file format
  46. Arff = 2
  47. #
  48. # Read in a problem definition from a file.
  49. # Input:
  50. # * +filename+, the name of the file
  51. # * +format+, either Svm::SvmLight (default), Svm::Csv or Svm::Arff
  52. # Raises ArgumentError if there is any error in format.
  53. #
  54. def Problem.from_file(filename, format = SvmLight)
  55. case format
  56. when SvmLight
  57. return Problem.from_file_svmlight filename
  58. when Csv
  59. return Problem.from_file_csv filename
  60. when Arff
  61. return Problem.from_file_arff filename
  62. end
  63. end
  64. #
  65. # Read in a problem definition in svmlight format from given
  66. # filename.
  67. # Raises ArgumentError if there is any error in format.
  68. #
  69. def Problem.from_file_svmlight filename
  70. instances = []
  71. labels = []
  72. max_index = 0
  73. IO.foreach(filename) do |line|
  74. tokens = line.split(" ")
  75. labels << tokens[0].to_f
  76. instance = []
  77. tokens[1..-1].each do |feature|
  78. index, value = feature.split(":")
  79. instance << Node.new(index.to_i, value.to_f)
  80. max_index = [index.to_i, max_index].max
  81. end
  82. instances << instance
  83. end
  84. max_index += 1 # to allow for 0 position
  85. unless instances.size == labels.size
  86. raise ArgumentError.new "Number of labels read differs from number of instances"
  87. end
  88. # now create a Problem definition
  89. problem = Problem.new
  90. problem.l = instances.size
  91. # -- add in the training data
  92. problem.x = Node[instances.size, max_index].new
  93. # -- fill with blank nodes
  94. instances.size.times do |i|
  95. max_index.times do |j|
  96. problem.x[i][j] = Node.new(i, 0)
  97. end
  98. end
  99. # -- add known values
  100. instances.each_with_index do |instance, i|
  101. instance.each do |node|
  102. problem.x[i][node.index] = node
  103. end
  104. end
  105. # -- add in the labels
  106. problem.y = Java::double[labels.size].new
  107. labels.each_with_index do |v, i|
  108. problem.y[i] = v
  109. end
  110. return problem
  111. end
  112. #
  113. # Read in a problem definition in csv format from given
  114. # filename.
  115. # Raises ArgumentError if there is any error in format.
  116. #
  117. def Problem.from_file_csv filename
  118. instances = []
  119. labels = []
  120. max_index = 0
  121. IO.foreach(filename) do |line|
  122. tokens = line.split(",")
  123. labels << tokens[0].to_f
  124. instance = []
  125. tokens[1..-1].each_with_index do |value, index|
  126. instance << Node.new(index, value.to_f)
  127. end
  128. max_index = [tokens.size, max_index].max
  129. instances << instance
  130. end
  131. max_index += 1 # to allow for 0 position
  132. unless instances.size == labels.size
  133. raise ArgumentError.new "Number of labels read differs from number of instances"
  134. end
  135. # now create a Problem definition
  136. problem = Problem.new
  137. problem.l = instances.size
  138. # -- add in the training data
  139. problem.x = Node[instances.size, max_index].new
  140. # -- fill with blank nodes
  141. instances.size.times do |i|
  142. max_index.times do |j|
  143. problem.x[i][j] = Node.new(i, 0)
  144. end
  145. end
  146. # -- add known values
  147. instances.each_with_index do |instance, i|
  148. instance.each do |node|
  149. problem.x[i][node.index] = node
  150. end
  151. end
  152. # -- add in the labels
  153. problem.y = Java::double[labels.size].new
  154. labels.each_with_index do |v, i|
  155. problem.y[i] = v
  156. end
  157. return problem
  158. end
  159. #
  160. # Read in a problem definition in arff format from given
  161. # filename.
  162. # Assumes all values are numbers (non-numbers converted to 0.0),
  163. # and that the class is the last field.
  164. # Raises ArgumentError if there is any error in format.
  165. #
  166. def Problem.from_file_arff filename
  167. instances = []
  168. labels = []
  169. max_index = 0
  170. found_data = false
  171. IO.foreach(filename) do |line|
  172. unless found_data
  173. puts "Ignoring", line
  174. found_data = line.downcase.strip == "@data"
  175. next # repeat the loop
  176. end
  177. tokens = line.split(",")
  178. labels << tokens.last.to_f
  179. instance = []
  180. tokens[1...-1].each_with_index do |value, index|
  181. instance << Node.new(index, value.to_f)
  182. end
  183. max_index = [tokens.size, max_index].max
  184. instances << instance
  185. end
  186. max_index += 1 # to allow for 0 position
  187. unless instances.size == labels.size
  188. raise ArgumentError.new "Number of labels read differs from number of instances"
  189. end
  190. # now create a Problem definition
  191. problem = Problem.new
  192. problem.l = instances.size
  193. # -- add in the training data
  194. problem.x = Node[instances.size, max_index].new
  195. # -- fill with blank nodes
  196. instances.size.times do |i|
  197. max_index.times do |j|
  198. problem.x[i][j] = Node.new(i, 0)
  199. end
  200. end
  201. # -- add known values
  202. instances.each_with_index do |instance, i|
  203. instance.each do |node|
  204. problem.x[i][node.index] = node
  205. end
  206. end
  207. # -- add in the labels
  208. problem.y = Java::double[labels.size].new
  209. labels.each_with_index do |v, i|
  210. problem.y[i] = v
  211. end
  212. return problem
  213. end
  214. # Return the number of instances
  215. def size
  216. self.l
  217. end
  218. # Rescale values within problem to be in range min_value to max_value
  219. #
  220. # For SVM models, it is recommended all features be in range [0,1] or [-1,1]
  221. def rescale(min_value = 0.0, max_value = 1.0)
  222. return if self.l.zero?
  223. x[0].size.times do |i|
  224. rescale_column(i, min_value, max_value)
  225. end
  226. end
  227. # Create a new problem by combining the instances in this problem with
  228. # those in the given problem.
  229. def merge problem
  230. unless self.x[0].size == problem.x[0].size
  231. raise ArgumentError.new "Cannot merge two problems with different numbers of features"
  232. end
  233. num_features = self.x[0].size
  234. num_instances = size + problem.size
  235. new_problem = Problem.new
  236. new_problem.l = num_instances
  237. new_problem.x = Node[num_instances, num_features].new
  238. new_problem.y = Java::double[num_instances].new
  239. # fill out the features
  240. num_instances.times do |i|
  241. num_features.times do |j|
  242. if i < size
  243. new_problem.x[i][j] = self.x[i][j]
  244. else
  245. new_problem.x[i][j] = problem.x[i-size][j]
  246. end
  247. end
  248. end
  249. # fill out the labels
  250. num_instances.times do |i|
  251. if i < size
  252. new_problem.y[i] = self.y[i]
  253. else
  254. new_problem.y[i] = problem.y[i-size]
  255. end
  256. end
  257. return new_problem
  258. end
  259. # Rescale values within problem for given column index,
  260. # to be in range min_value to max_value
  261. private
  262. def rescale_column(col, min_value, max_value)
  263. # -- first locate the column's range
  264. current_min = x[0][col].value
  265. current_max = x[0][col].value
  266. self.l.times do |index|
  267. if x[index][col].value < current_min
  268. current_min = x[index][col].value
  269. end
  270. if x[index][col].value > current_max
  271. current_max = x[index][col].value
  272. end
  273. end
  274. # -- then update each value
  275. self.l.times do |index|
  276. x[index][col].value = ((max_value - min_value) * (x[index][col].value - current_min) / (current_max - current_min)) + min_value
  277. end
  278. end
  279. end
  280. end