problem.rb 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. # This file is part of svm_toolkit.
  2. #
  3. # Author:: Peter Lane
  4. # Copyright:: Copyright 2011-13, Peter Lane.
  5. # License:: MIT License
  6. #
  7. module SvmToolkit
  8. class Problem
  9. #
  10. # Support constructing a problem from arrays of double values.
  11. # * Input
  12. # [+instances+] an array of instances, each instance being an array of doubles.
  13. # [+labels+] an array of doubles, forming the labels for each instance.
  14. #
  15. # An ArgumentError exception is raised if all the following conditions are not met:
  16. # * the number of instances should equal the number of labels,
  17. # * there must be at least one instance, and
  18. # * every instance must have the same number of features.
  19. #
  20. def Problem.from_array(instances, labels)
  21. unless instances.size == labels.size
  22. raise ArgumentError.new "Number of instances must equal number of labels"
  23. end
  24. unless instances.size > 0
  25. raise ArgumentError.new "There must be at least one instance."
  26. end
  27. unless instances.collect {|i| i.size}.min == instances.collect {|i| i.size}.max
  28. raise ArgumentError.new "All instances must have the same size"
  29. end
  30. problem = Problem.new
  31. problem.l = labels.size
  32. # -- add in the training data
  33. problem.x = Node[instances.size, instances[0].size].new
  34. instances.each_with_index do |instance, i|
  35. instance.each_with_index do |v, j|
  36. problem.x[i][j] = Node.new(j, v)
  37. end
  38. end
  39. # -- add in the labels
  40. problem.y = Java::double[labels.size].new
  41. labels.each_with_index do |v, i|
  42. problem.y[i] = v
  43. end
  44. return problem
  45. end
  46. # To select SvmLight input file format
  47. SvmLight = 0
  48. # To select Csv input file format
  49. Csv = 1
  50. # To select ARFF input file format
  51. Arff = 2
  52. #
  53. # Read in a problem definition from a file.
  54. # Input:
  55. # * +filename+, the name of the file
  56. # * +format+, either Svm::SvmLight (default), Svm::Csv or Svm::Arff
  57. # Raises ArgumentError if there is any error in format.
  58. #
  59. def Problem.from_file(filename, format = SvmLight)
  60. case format
  61. when SvmLight
  62. return Problem.from_file_svmlight filename
  63. when Csv
  64. return Problem.from_file_csv filename
  65. when Arff
  66. return Problem.from_file_arff filename
  67. end
  68. end
  69. #
  70. # Read in a problem definition in svmlight format from given
  71. # filename.
  72. # Raises ArgumentError if there is any error in format.
  73. #
  74. def Problem.from_file_svmlight filename
  75. instances = []
  76. labels = []
  77. max_index = 0
  78. IO.foreach(filename) do |line|
  79. tokens = line.split(" ")
  80. labels << tokens[0].to_f
  81. instance = []
  82. tokens[1..-1].each do |feature|
  83. index, value = feature.split(":")
  84. instance << Node.new(index.to_i, value.to_f)
  85. max_index = [index.to_i, max_index].max
  86. end
  87. instances << instance
  88. end
  89. max_index += 1 # to allow for 0 position
  90. unless instances.size == labels.size
  91. raise ArgumentError.new "Number of labels read differs from number of instances"
  92. end
  93. # now create a Problem definition
  94. problem = Problem.new
  95. problem.l = instances.size
  96. # -- add in the training data
  97. problem.x = Node[instances.size, max_index].new
  98. # -- fill with blank nodes
  99. instances.size.times do |i|
  100. max_index.times do |j|
  101. problem.x[i][j] = Node.new(i, 0)
  102. end
  103. end
  104. # -- add known values
  105. instances.each_with_index do |instance, i|
  106. instance.each do |node|
  107. problem.x[i][node.index] = node
  108. end
  109. end
  110. # -- add in the labels
  111. problem.y = Java::double[labels.size].new
  112. labels.each_with_index do |v, i|
  113. problem.y[i] = v
  114. end
  115. return problem
  116. end
  117. #
  118. # Read in a problem definition in csv format from given
  119. # filename.
  120. # Raises ArgumentError if there is any error in format.
  121. #
  122. def Problem.from_file_csv filename
  123. instances = []
  124. labels = []
  125. max_index = 0
  126. IO.foreach(filename) do |line|
  127. tokens = line.split(",")
  128. labels << tokens[0].to_f
  129. instance = []
  130. tokens[1..-1].each_with_index do |value, index|
  131. instance << Node.new(index, value.to_f)
  132. end
  133. max_index = [tokens.size, max_index].max
  134. instances << instance
  135. end
  136. max_index += 1 # to allow for 0 position
  137. unless instances.size == labels.size
  138. raise ArgumentError.new "Number of labels read differs from number of instances"
  139. end
  140. # now create a Problem definition
  141. problem = Problem.new
  142. problem.l = instances.size
  143. # -- add in the training data
  144. problem.x = Node[instances.size, max_index].new
  145. # -- fill with blank nodes
  146. instances.size.times do |i|
  147. max_index.times do |j|
  148. problem.x[i][j] = Node.new(i, 0)
  149. end
  150. end
  151. # -- add known values
  152. instances.each_with_index do |instance, i|
  153. instance.each do |node|
  154. problem.x[i][node.index] = node
  155. end
  156. end
  157. # -- add in the labels
  158. problem.y = Java::double[labels.size].new
  159. labels.each_with_index do |v, i|
  160. problem.y[i] = v
  161. end
  162. return problem
  163. end
  164. #
  165. # Read in a problem definition in arff format from given
  166. # filename.
  167. # Assumes all values are numbers (non-numbers converted to 0.0),
  168. # and that the class is the last field.
  169. # Raises ArgumentError if there is any error in format.
  170. #
  171. def Problem.from_file_arff filename
  172. instances = []
  173. labels = []
  174. max_index = 0
  175. found_data = false
  176. IO.foreach(filename) do |line|
  177. unless found_data
  178. puts "Ignoring", line
  179. found_data = line.downcase.strip == "@data"
  180. next # repeat the loop
  181. end
  182. tokens = line.split(",")
  183. labels << tokens.last.to_f
  184. instance = []
  185. tokens[1...-1].each_with_index do |value, index|
  186. instance << Node.new(index, value.to_f)
  187. end
  188. max_index = [tokens.size, max_index].max
  189. instances << instance
  190. end
  191. max_index += 1 # to allow for 0 position
  192. unless instances.size == labels.size
  193. raise ArgumentError.new "Number of labels read differs from number of instances"
  194. end
  195. # now create a Problem definition
  196. problem = Problem.new
  197. problem.l = instances.size
  198. # -- add in the training data
  199. problem.x = Node[instances.size, max_index].new
  200. # -- fill with blank nodes
  201. instances.size.times do |i|
  202. max_index.times do |j|
  203. problem.x[i][j] = Node.new(i, 0)
  204. end
  205. end
  206. # -- add known values
  207. instances.each_with_index do |instance, i|
  208. instance.each do |node|
  209. problem.x[i][node.index] = node
  210. end
  211. end
  212. # -- add in the labels
  213. problem.y = Java::double[labels.size].new
  214. labels.each_with_index do |v, i|
  215. problem.y[i] = v
  216. end
  217. return problem
  218. end
  219. # Return the number of instances
  220. def size
  221. self.l
  222. end
  223. # Rescale values within problem to be in range min_value to max_value
  224. #
  225. # For SVM models, it is recommended all features be in range [0,1] or [-1,1]
  226. def rescale(min_value = 0.0, max_value = 1.0)
  227. return if self.l.zero?
  228. x[0].size.times do |i|
  229. rescale_column(i, min_value, max_value)
  230. end
  231. end
  232. # Create a new problem by combining the instances in this problem with
  233. # those in the given problem.
  234. def merge problem
  235. unless self.x[0].size == problem.x[0].size
  236. raise ArgumentError.new "Cannot merge two problems with different numbers of features"
  237. end
  238. num_features = self.x[0].size
  239. num_instances = size + problem.size
  240. new_problem = Problem.new
  241. new_problem.l = num_instances
  242. new_problem.x = Node[num_instances, num_features].new
  243. new_problem.y = Java::double[num_instances].new
  244. # fill out the features
  245. num_instances.times do |i|
  246. num_features.times do |j|
  247. if i < size
  248. new_problem.x[i][j] = self.x[i][j]
  249. else
  250. new_problem.x[i][j] = problem.x[i-size][j]
  251. end
  252. end
  253. end
  254. # fill out the labels
  255. num_instances.times do |i|
  256. if i < size
  257. new_problem.y[i] = self.y[i]
  258. else
  259. new_problem.y[i] = problem.y[i-size]
  260. end
  261. end
  262. return new_problem
  263. end
  264. # Rescale values within problem for given column index,
  265. # to be in range min_value to max_value
  266. private
  267. def rescale_column(col, min_value, max_value)
  268. # -- first locate the column's range
  269. current_min = x[0][col].value
  270. current_max = x[0][col].value
  271. self.l.times do |index|
  272. if x[index][col].value < current_min
  273. current_min = x[index][col].value
  274. end
  275. if x[index][col].value > current_max
  276. current_max = x[index][col].value
  277. end
  278. end
  279. # -- then update each value
  280. self.l.times do |index|
  281. x[index][col].value = ((max_value - min_value) * (x[index][col].value - current_min) / (current_max - current_min)) + min_value
  282. end
  283. end
  284. end
  285. end