# svm-demo (8.7 KB)
  #!/usr/bin/env jruby
  #
  # A loose equivalent of the svm_toy applet which is
  # distributed with libsvm.
  #
  require "java"
  require "svm-toolkit"
  include SvmToolkit

  # java.awt classes: layout managers, colours and listener interfaces.
  %w[
    BorderLayout Color
    Dimension GridLayout
    event.ActionListener event.MouseListener
  ].each { |awt_name| import "java.awt.#{awt_name}" }

  # javax.swing widgets, models, constants and borders.
  %w[
    Box BoxLayout JButton JComboBox JFrame JLabel
    JOptionPane JPanel JScrollPane JSpinner JTextField
    SpinnerNumberModel WindowConstants border.TitledBorder
  ].each { |swing_name| import "javax.swing.#{swing_name}" }
  # Drawing panel for the demo. The user clicks to place labelled points;
  # after training, the panel shows the SVM's predicted class for every pixel
  # as a coloured background, highlights the support vectors in yellow, and
  # draws the points on top.
  class Display < JPanel
    # A point placed by the user: pixel coordinates plus the colour (class)
    # that was selected when it was placed.
    Point = Struct.new(:x, :y, :colour)

    # Forwards mouse clicks to the owning Display; all other mouse events are
    # ignored.
    class MyMouseListener
      include MouseListener

      def initialize(parent)
        @parent = parent
      end

      def mouseEntered(e); end

      def mouseExited(e); end

      def mousePressed(e); end

      def mouseReleased(e); end

      def mouseClicked(e)
        @parent.clicked(e.x, e.y)
      end
    end

    # Colour (and hence class label) given to the next point the user places.
    attr_accessor :colour

    # Size of the drawing area, in pixels.
    Width = 800
    Height = 600

    # Background colours for predicted label 0 (green points) and for any
    # other prediction (blue points); allocated once rather than per pixel.
    ZeroClassColour = Color.new(100, 200, 100)
    OtherClassColour = Color.new(100, 100, 200)

    def initialize
      super()
      self.preferred_size = Dimension.new(Width, Height)
      add_mouse_listener MyMouseListener.new(self)
      @points = []
      @support_vectors = []
      @colour = Color.blue
    end

    # Paints the last rendered decision regions (or a plain light-grey
    # background before any training), then a yellow disc behind each
    # support vector, then every point in its own colour.
    def paint(g)
      super
      if @buffer.nil?
        g.background = Color.lightGray
        g.clear_rect(0, 0, Width, Height)
      else
        g.draw_image(@buffer, 0, 0, self)
      end
      g.color = Color.yellow
      @support_vectors.each do |point|
        g.fill_oval(point.x - 7, point.y - 7, 14, 14)
      end
      @points.each do |point|
        g.color = point.colour
        g.fill_oval(point.x - 3, point.y - 3, 6, 6)
      end
    end

    # Removes all points, support vectors and the rendered background.
    def clear
      @points.clear
      @support_vectors.clear
      @buffer = nil
      repaint
    end

    # Adds a point in the current colour at (x, y). Clicks outside the
    # Width x Height drawing area are ignored.
    def clicked(x, y)
      return unless x < Width && y < Height

      @points << Point.new(x, y, @colour)
      repaint
    end

    # Background colour for a pixel whose predicted label is +prediction+:
    # green for label 0, blue otherwise (matching the labels used by #train).
    def background_colour(prediction)
      prediction.zero? ? ZeroClassColour : OtherClassColour
    end

    # Trains a C-SVC on the placed points, then repaints the panel with the
    # resulting decision regions and support vectors. Blue points are
    # labelled 1, all others 0. Does nothing if no points have been placed.
    #
    # kernel - kernel type constant from Parameter (LINEAR, RBF, POLY, SIGMOID)
    # cost   - the C parameter
    # gamma  - kernel gamma
    # degree - polynomial degree
    def train(kernel, cost, gamma, degree)
      return if @points.empty?

      param = Parameter.new(
        :svm_type => Parameter::C_SVC,
        :kernel_type => kernel,
        :cost => cost,
        :gamma => gamma,
        :degree => degree
      )
      model = Svm.svm_train(build_problem, param)
      @buffer = render_decision_regions(model)
      # NOTE(review): assumes support_vector_indices yields 0-based positions
      # in the training data (and hence in @points); libsvm's own sv_indices
      # are 1-based -- confirm against svm-toolkit.
      @support_vectors = model.support_vector_indices.map { |index| @points[index] }
      repaint
    end

    private

    # Builds the training Problem from @points: label 1 for blue points and 0
    # otherwise, with coordinates divided by Width / Height.
    def build_problem
      labels = @points.map { |point| point.colour == Color.blue ? 1 : 0 }
      instances = @points.map { |point| [point.x / Width.to_f, point.y / Height.to_f] }
      Problem.from_array(instances, labels)
    end

    # Predicts the label of every pixel of the drawing area with +model+ and
    # paints it both into a new off-screen image (returned) and directly onto
    # the panel, so the user can watch progress. A red line one column ahead
    # of the scan marks how far it has got.
    def render_decision_regions(model)
      buffer = self.create_image(Width, Height)
      buffer_gc = buffer.graphics
      window_gc = self.graphics
      begin
        instance = Node[2].new
        instance[0] = Node.new(0, 0)
        instance[1] = Node.new(1, 0)
        Width.times do |i|
          # Progress marker across the whole panel; skipped only for the last
          # column, whose successor lies outside the drawing area.
          if i < Width - 1
            buffer_gc.color = Color.red
            buffer_gc.draw_line(i + 1, 0, i + 1, Height - 1)
            window_gc.color = Color.red
            window_gc.draw_line(i + 1, 0, i + 1, Height - 1)
          end
          instance[0].value = i / Width.to_f
          Height.times do |j|
            instance[1].value = j / Height.to_f
            colour = background_colour(Svm.svm_predict(model, instance))
            buffer_gc.color = colour
            buffer_gc.draw_line(i, j, i, j)
            window_gc.color = colour
            window_gc.draw_line(i, j, i, j)
          end
        end
      ensure
        # Graphics contexts we obtained ourselves must be released explicitly.
        buffer_gc.dispose
        window_gc.dispose if window_gc
      end
      buffer
    end
  end
  # Main application window: class selector and "clear" button along the top,
  # the drawing panel (in a scroll pane) in the centre, training options on
  # the right and a short help text along the bottom.
  class DemoFrame < JFrame
    # Sets the display's colour for new points from the class combo box.
    class LabelListener
      include ActionListener

      def initialize(display, box)
        @display = display
        @box = box
      end

      def actionPerformed(e)
        @display.colour = @box.selected_item == "blue" ? Color.blue : Color.green
      end
    end

    # Enables just the parameter inputs used by the selected kernel. In
    # libsvm: linear is u'v; polynomial is (gamma*u'v + coef0)^degree; RBF is
    # exp(-gamma*|u-v|^2); sigmoid is tanh(gamma*u'v + coef0).
    class KernelChoiceListener
      include ActionListener

      def initialize(kernel_choice, gamma_choice, degree_choice)
        @kernel_choice = kernel_choice
        @gamma_choice = gamma_choice
        @degree_choice = degree_choice
      end

      def actionPerformed(e)
        case @kernel_choice.selected_item
        when "linear"
          @gamma_choice.enabled = false
          @degree_choice.enabled = false
        when "polynomial"
          # gamma scales u'v inside the polynomial kernel, so it is a live
          # parameter here as well as degree.
          @gamma_choice.enabled = true
          @degree_choice.enabled = true
        when "RBF", "sigmoid"
          @gamma_choice.enabled = true
          @degree_choice.enabled = false
        end
      end
    end

    # SwingWorker that runs +task+ on a background thread, then runs the
    # optional +on_done+ callback on the Event Dispatch Thread. SwingWorker
    # calls done() whether the task returned normally or raised.
    class MySwingWorker < javax.swing.SwingWorker
      attr_accessor :task, :on_done

      def doInBackground
        @task.call
      end

      def done
        @on_done.call if @on_done
      end
    end

    def initialize
      super("Support-Vector Machines: Demonstration")
      self.setSize(700, 400)
      @display_panel = Display.new
      add(JScrollPane.new(@display_panel))
      add(createLabelButtons, BorderLayout::NORTH)
      add(createTrainButtons, BorderLayout::EAST)
      add(createHelpLine, BorderLayout::SOUTH)
      self.setDefaultCloseOperation(WindowConstants::DISPOSE_ON_CLOSE)
    end

    # Usage hint shown along the bottom of the window.
    def createHelpLine
      JLabel.new(<<-END)
        <html><body>
        Select a class colour and click on main panel to define instances.<br>
        Choose kernel type and parameter settings for training.
        </body></html>
      END
    end

    # Top bar: class (colour) selector, "clear" button, and the status label
    # @message used to report training progress.
    def createLabelButtons
      combo_box = JComboBox.new
      %w[blue green].each { |item| combo_box.add_item item }
      combo_box.add_action_listener LabelListener.new(@display_panel, combo_box)

      clear_button = JButton.new "clear"
      clear_button.add_action_listener { @display_panel.clear }

      @message = JLabel.new

      pane = JPanel.new
      pane.add JLabel.new("Class:")
      pane.add combo_box
      pane.add clear_button

      panel = JPanel.new
      panel.layout = BorderLayout.new
      panel.add(pane, BorderLayout::WEST)
      panel.add @message
      panel
    end

    # Right-hand panel with the training options (kernel, cost, gamma,
    # degree) and the Train button. The button validates the inputs, then
    # trains the display in the background.
    def createTrainButtons
      kernel_choice = JComboBox.new
      %w[linear RBF polynomial sigmoid].each { |item| kernel_choice.add_item item }
      cost_choice = JTextField.new(10)
      cost_choice.text = "1.0"
      cost_choice.setMaximumSize(cost_choice.getPreferredSize)
      gamma_choice = JTextField.new(10)
      gamma_choice.text = "1.0"
      gamma_choice.setMaximumSize(gamma_choice.getPreferredSize)
      gamma_choice.enabled = false
      degree_choice = JSpinner.new(SpinnerNumberModel.new(1, 0, 30, 1))
      degree_choice.enabled = false
      kernel_choice.add_action_listener KernelChoiceListener.new(kernel_choice, gamma_choice, degree_choice)

      run_button = JButton.new "Train"
      run_button.add_action_listener do
        kernel =
          case kernel_choice.selected_item
          when "linear" then Parameter::LINEAR
          when "RBF" then Parameter::RBF
          when "polynomial" then Parameter::POLY
          when "sigmoid" then Parameter::SIGMOID
          end
        # Use `next`, not `return`, to abandon the click: this block runs long
        # after createTrainButtons has returned, so `return` would raise
        # LocalJumpError.
        cost = parse_float_field(cost_choice, "cost")
        next if cost.nil?
        gamma = parse_float_field(gamma_choice, "gamma")
        next if gamma.nil?
        degree = degree_choice.model.number
        train_in_background(run_button) do
          @display_panel.train(kernel, cost, gamma, degree)
        end
      end

      panel = JPanel.new
      panel.border = TitledBorder.new("Training options")
      panel.layout = GridLayout.new(5, 2, 10, 10)
      panel.add JLabel.new("Kernel type:", JLabel::RIGHT)
      panel.add kernel_choice
      panel.add JLabel.new("Cost:", JLabel::RIGHT)
      panel.add cost_choice
      panel.add JLabel.new("Gamma:", JLabel::RIGHT)
      panel.add gamma_choice
      panel.add JLabel.new("Degree:", JLabel::RIGHT)
      panel.add degree_choice
      panel.add JLabel.new ""
      panel.add run_button
      pane = JPanel.new
      pane.add panel
      pane
    end

    private

    # Parses +field+'s text as a Float. If it is not a number, shows an error
    # dialog naming the +name+d value and returns nil.
    def parse_float_field(field, name)
      Float(field.text)
    rescue ArgumentError
      JOptionPane.show_message_dialog(self,
                                      "#{name.capitalize} value #{field.text} is not a number",
                                      "Error in #{name} value",
                                      JOptionPane::ERROR_MESSAGE)
      nil
    end

    # Runs +work+ on a MySwingWorker. Before the worker starts, the Train
    # button is disabled and a wait message is shown, on the calling thread
    # (the EDT when invoked from a listener), so repeated clicks cannot start
    # a second run. Both are restored on the EDT when the work finishes, even
    # if it raised.
    def train_in_background(run_button, &work)
      run_button.enabled = false
      @message.text = "Training and updating display: Please wait"
      worker = MySwingWorker.new
      worker.task = work
      worker.on_done = lambda do
        @message.text = ""
        run_button.enabled = true
      end
      worker.execute
    end
  end
  # Use the Nimbus look and feel when the JVM provides it. This is purely
  # cosmetic, so any failure to install it leaves the default in place.
  # Only Java/Ruby errors are rescued, never SystemExit or Interrupt.
  nimbus = javax.swing::UIManager.getInstalledLookAndFeels.find { |info| info.name == "Nimbus" }
  if nimbus
    begin
      javax.swing::UIManager.setLookAndFeel(nimbus.className)
    rescue java.lang.Exception, StandardError
      # Best effort: keep the default look and feel.
    end
  end

  # Swing components should be created and shown on the Event Dispatch Thread.
  javax.swing::SwingUtilities.invoke_later { DemoFrame.new.visible = true }