# confusion_matrix.rb

  1. # This class holds the confusion matrix information.
  2. # It is designed to be called incrementally, as results are obtained
  3. # from the classifier model.
  4. #
  5. # At any point, statistics may be obtained by calling the relevant methods.
  6. #
  7. # A two-class example is:
  8. #
  9. # Classified Classified |
  10. # Positive Negative | Actual
  11. # ------------------------------+------------
  12. # a b | Positive
  13. # c d | Negative
  14. #
  15. # Statistical methods will be described with reference to this example.
  16. #
  17. class ConfusionMatrix
  18. # Creates a new, empty instance of a confusion matrix.
  19. #
  20. # @param labels [<String, Symbol>, ...] if provided, makes the matrix
  21. # use the first label as a default label, and also check
  22. # all operations use one of the pre-defined labels.
  23. # @raise [ArgumentError] if there are not at least two unique labels, when provided.
  24. def initialize(*labels)
  25. @matrix = {}
  26. @labels = labels.uniq
  27. if @labels.size == 1
  28. raise ArgumentError.new("If labels are provided, there must be at least two.")
  29. else # preset the matrix Hash
  30. @labels.each do |actual|
  31. @matrix[actual] = {}
  32. @labels.each do |predicted|
  33. @matrix[actual][predicted] = 0
  34. end
  35. end
  36. end
  37. end
  38. # Returns a list of labels used in the matrix.
  39. #
  40. # cm = ConfusionMatrix.new
  41. # cm.add_for(:pos, :neg)
  42. # cm.labels # => [:neg, :pos]
  43. #
  44. # @return [Array<String>] labels used in the matrix.
  45. def labels
  46. if @labels.size >= 2 # if we defined some labels, return them
  47. @labels
  48. else
  49. result = []
  50. @matrix.each_pair do |key, predictions|
  51. result << key
  52. predictions.each_key do |key|
  53. result << key
  54. end
  55. end
  56. result.uniq.sort
  57. end
  58. end
  59. # Return the count for (actual,prediction) pair.
  60. #
  61. # cm = ConfusionMatrix.new
  62. # cm.add_for(:pos, :neg)
  63. # cm.count_for(:pos, :neg) # => 1
  64. #
  65. # @param actual [String, Symbol] is actual class of the instance,
  66. # which we expect the classifier to predict
  67. # @param prediction [String, Symbol] is the predicted class of the instance,
  68. # as output from the classifier
  69. # @return [Integer] number of observations of (actual, prediction) pair
  70. # @raise [ArgumentError] if +actual+ or +predicted+ are not one of any
  71. # pre-defined labels in matrix
  72. def count_for(actual, prediction)
  73. validate_label actual, prediction
  74. predictions = @matrix.fetch(actual, {})
  75. predictions.fetch(prediction, 0)
  76. end
  77. # Adds one result to the matrix for a given (actual, prediction) pair of labels.
  78. # If the matrix was given a pre-defined list of labels on construction, then
  79. # these given labels must be from the pre-defined list.
  80. # If no pre-defined list of labels was used in constructing the matrix, then
  81. # labels will be added to matrix.
  82. #
  83. # Class labels may be any hashable value, though ideally they are strings or symbols.
  84. #
  85. # @param actual [String, Symbol] is actual class of the instance,
  86. # which we expect the classifier to predict
  87. # @param prediction [String, Symbol] is the predicted class of the instance,
  88. # as output from the classifier
  89. # @param n [Integer] number of observations to add
  90. # @raise [ArgumentError] if +n+ is not an Integer
  91. # @raise [ArgumentError] if +actual+ or +predicted+ are not one of any
  92. # pre-defined labels in matrix
  93. def add_for(actual, prediction, n = 1)
  94. validate_label actual, prediction
  95. if !@matrix.has_key?(actual)
  96. @matrix[actual] = {}
  97. end
  98. predictions = @matrix[actual]
  99. if !predictions.has_key?(prediction)
  100. predictions[prediction] = 0
  101. end
  102. unless n.class == Integer and n.positive?
  103. raise ArgumentError.new("add_for requires n to be a positive Integer, but got #{n}")
  104. end
  105. @matrix[actual][prediction] += n
  106. end
  107. # Returns the number of instances of the given class label which
  108. # are incorrectly classified.
  109. #
  110. # false_negative(:positive) = b
  111. #
  112. # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  113. # @return [Float] value of false negative
  114. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  115. def false_negative(label = @labels.first)
  116. validate_label label
  117. predictions = @matrix.fetch(label, {})
  118. total = 0
  119. predictions.each_pair do |key, count|
  120. if key != label
  121. total += count
  122. end
  123. end
  124. total
  125. end
  126. # Returns the number of instances incorrectly classified with the given
  127. # class label.
  128. #
  129. # false_positive(:positive) = c
  130. #
  131. # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  132. # @return [Float] value of false positive
  133. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  134. def false_positive(label = @labels.first)
  135. validate_label label
  136. total = 0
  137. @matrix.each_pair do |key, predictions|
  138. if key != label
  139. total += predictions.fetch(label, 0)
  140. end
  141. end
  142. total
  143. end
  144. # The false rate for a given class label is the proportion of instances
  145. # incorrectly classified as that label, out of all those instances
  146. # not originally of that label.
  147. #
  148. # false_rate(:positive) = c/(c+d)
  149. #
  150. # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  151. # @return [Float] value of false rate
  152. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  153. def false_rate(label = @labels.first)
  154. validate_label label
  155. fp = false_positive(label)
  156. tn = true_negative(label)
  157. divide(fp, fp+tn)
  158. end
  159. # The F-measure for a given label is the harmonic mean of the precision
  160. # and recall for that label.
  161. #
  162. # F = 2*(precision*recall)/(precision+recall)
  163. #
  164. # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  165. # @return [Float] value of F-measure
  166. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  167. def f_measure(label = @labels.first)
  168. validate_label label
  169. 2*precision(label)*recall(label)/(precision(label) + recall(label))
  170. end
  171. # The geometric mean is the nth-root of the product of the true_rate for
  172. # each label.
  173. #
  174. # a1 = a/(a+b)
  175. # a2 = d/(c+d)
  176. # geometric_mean = Math.sqrt(a1*a2)
  177. #
  178. # @return [Float] value of geometric mean
  179. def geometric_mean
  180. product = 1
  181. @matrix.each_key do |key|
  182. product *= true_rate(key)
  183. end
  184. product**(1.0/@matrix.size)
  185. end
  186. # The Kappa statistic compares the observed accuracy with an expected
  187. # accuracy.
  188. #
  189. # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  190. # @return [Float] value of Cohen's Kappa Statistic
  191. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  192. def kappa(label = @labels.first)
  193. validate_label label
  194. tp = true_positive(label)
  195. fn = false_negative(label)
  196. fp = false_positive(label)
  197. tn = true_negative(label)
  198. total = tp+fn+fp+tn
  199. total_accuracy = divide(tp+tn, tp+tn+fp+fn)
  200. random_accuracy = divide((tn+fp)*(tn+fn) + (fn+tp)*(fp+tp), total*total)
  201. divide(total_accuracy - random_accuracy, 1 - random_accuracy)
  202. end
  203. # Matthews Correlation Coefficient is a measure of the quality of binary
  204. # classifications.
  205. #
  206. # mathews_correlation(:positive) = (a*d - c*b) / sqrt((a+c)(a+b)(d+c)(d+b))
  207. #
  208. # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  209. # @return [Float] value of Matthews Correlation Coefficient
  210. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  211. def matthews_correlation(label = @labels.first)
  212. validate_label label
  213. tp = true_positive(label)
  214. fn = false_negative(label)
  215. fp = false_positive(label)
  216. tn = true_negative(label)
  217. divide(tp*tn - fp*fn, Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)))
  218. end
  219. # The overall accuracy is the proportion of instances which are
  220. # correctly labelled.
  221. #
  222. # overall_accuracy = (a+d)/(a+b+c+d)
  223. #
  224. # @return [Float] value of overall accuracy
  225. def overall_accuracy
  226. total_correct = 0
  227. @matrix.each_pair do |key, predictions|
  228. total_correct += true_positive(key)
  229. end
  230. divide(total_correct, total)
  231. end
  232. # The precision for a given class label is the proportion of instances
  233. # classified as that class which are correct.
  234. #
  235. # precision(:positive) = a/(a+c)
  236. #
  237. # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  238. # @return [Float] value of precision
  239. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  240. def precision(label = @labels.first)
  241. validate_label label
  242. tp = true_positive(label)
  243. fp = false_positive(label)
  244. divide(tp, tp+fp)
  245. end
  246. # The prevalence for a given class label is the proportion of instances
  247. # which were classified as of that label, out of the total.
  248. #
  249. # prevalence(:positive) = (a+c)/(a+b+c+d)
  250. #
  251. # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  252. # @return [Float] value of prevalence
  253. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  254. def prevalence(label = @labels.first)
  255. validate_label label
  256. tp = true_positive(label)
  257. fn = false_negative(label)
  258. fp = false_positive(label)
  259. tn = true_negative(label)
  260. total = tp+fn+fp+tn
  261. divide(tp+fn, total)
  262. end
  263. # The recall is another name for the true rate.
  264. #
  265. # @see true_rate
  266. # @param (see #true_rate)
  267. # @return (see #true_rate)
  268. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  269. def recall(label = @labels.first)
  270. validate_label label
  271. true_rate(label)
  272. end
  273. # Sensitivity is another name for the true rate.
  274. #
  275. # @see true_rate
  276. # @param (see #true_rate)
  277. # @return (see #true_rate)
  278. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  279. def sensitivity(label = @labels.first)
  280. validate_label label
  281. true_rate(label)
  282. end
  283. # The specificity for a given class label is 1 - false_rate(label)
  284. #
  285. # In two-class case, specificity = 1 - false_positive_rate
  286. #
  287. # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  288. # @return [Float] value of specificity
  289. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  290. def specificity(label = @labels.first)
  291. validate_label label
  292. 1-false_rate(label)
  293. end
  294. # Returns the table in a string format, representing the entries as a
  295. # printable table.
  296. #
  297. # @return [String] representation as a printable table.
  298. def to_s
  299. ls = labels
  300. result = ""
  301. title_line = "Predicted "
  302. label_line = ""
  303. ls.each { |l| label_line << "#{l} " }
  304. label_line << " " while label_line.size < title_line.size
  305. title_line << " " while title_line.size < label_line.size
  306. result << title_line << "|\n" << label_line << "| Actual\n"
  307. result << "-"*title_line.size << "+-------\n"
  308. ls.each do |l|
  309. count_line = ""
  310. ls.each_with_index do |m, i|
  311. count_line << "#{count_for(l, m)}".rjust(labels[i].size) << " "
  312. end
  313. result << count_line.ljust(title_line.size) << "| #{l}\n"
  314. end
  315. result
  316. end
  317. # Returns the total number of instances referenced in the matrix.
  318. #
  319. # total = a+b+c+d
  320. #
  321. # @return [Integer] total number of instances referenced in the matrix.
  322. def total
  323. total = 0
  324. @matrix.each_value do |predictions|
  325. predictions.each_value do |count|
  326. total += count
  327. end
  328. end
  329. total
  330. end
  331. # Returns the number of instances NOT of the given class label which
  332. # are correctly classified.
  333. #
  334. # true_negative(:positive) = d
  335. #
  336. # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  337. # @return [Integer] number of instances not of given label which are correctly classified
  338. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  339. def true_negative(label = @labels.first)
  340. validate_label label
  341. total = 0
  342. @matrix.each_pair do |key, predictions|
  343. if key != label
  344. total += predictions.fetch(key, 0)
  345. end
  346. end
  347. total
  348. end
  349. # Returns the number of instances of the given class label which are
  350. # correctly classified.
  351. #
  352. # true_positive(:positive) = a
  353. #
  354. # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  355. # @return [Integer] number of instances of given label which are correctly classified
  356. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  357. def true_positive(label = @labels.first)
  358. validate_label label
  359. predictions = @matrix.fetch(label, {})
  360. predictions.fetch(label, 0)
  361. end
  362. # The true rate for a given class label is the proportion of instances of
  363. # that class which are correctly classified.
  364. #
  365. # true_rate(:positive) = a/(a+b)
  366. #
  367. # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  368. # @return [Float] proportion of instances which are correctly classified
  369. # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  370. def true_rate(label = @labels.first)
  371. validate_label label
  372. tp = true_positive(label)
  373. fn = false_negative(label)
  374. divide(tp, tp+fn)
  375. end
  376. private
  377. # A form of "safe divide".
  378. # Checks if divisor is zero, and returns 0.0 if so.
  379. # This avoids a run-time error.
  380. # Also, ensures floating point division is done.
  381. def divide(x, y)
  382. if y.zero?
  383. 0.0
  384. else
  385. x.to_f/y
  386. end
  387. end
  388. # Checks if given label(s) is non-nil and in @labels, or if @labels is empty
  389. # Raises ArgumentError if not
  390. def validate_label *labels
  391. return true if @labels.empty?
  392. labels.each do |label|
  393. unless label and @labels.include?(label)
  394. raise ArgumentError.new("Given label (#{label}) is not in predefined list (#{@labels.join(',')})")
  395. end
  396. end
  397. end
  398. end