matrix_test.rb 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. require 'confusion_matrix'
  2. require 'minitest/autorun'
  3. class TestConfusionMatrix < MiniTest::Test
  4. def test_empty_case
  5. cm = ConfusionMatrix.new
  6. assert(0, cm.total)
  7. assert(0, cm.true_positive(:none))
  8. assert(0, cm.false_negative(:none))
  9. assert(0, cm.false_positive(:none))
  10. assert(0, cm.true_negative(:none))
  11. assert_in_delta(0, cm.true_rate(:none))
  12. end
  13. def test_two_classes
  14. cm = ConfusionMatrix.new
  15. 10.times { cm.add_for(:pos, :pos) }
  16. 5.times { cm.add_for(:pos, :neg) }
  17. 20.times { cm.add_for(:neg, :neg) }
  18. 5.times { cm.add_for(:neg, :pos) }
  19. assert_equal([:neg, :pos], cm.labels)
  20. assert_equal(10, cm.count_for(:pos, :pos))
  21. assert_equal(5, cm.count_for(:pos, :neg))
  22. assert_equal(20, cm.count_for(:neg, :neg))
  23. assert_equal(5, cm.count_for(:neg, :pos))
  24. assert_equal(40, cm.total)
  25. assert_equal(10, cm.true_positive(:pos))
  26. assert_equal(5, cm.false_negative(:pos))
  27. assert_equal(5, cm.false_positive(:pos))
  28. assert_equal(20, cm.true_negative(:pos))
  29. assert_equal(20, cm.true_positive(:neg))
  30. assert_equal(5, cm.false_negative(:neg))
  31. assert_equal(5, cm.false_positive(:neg))
  32. assert_equal(10, cm.true_negative(:neg))
  33. assert_in_delta(0.6667, cm.true_rate(:pos))
  34. assert_in_delta(0.8, cm.true_rate(:neg))
  35. assert_in_delta(0.2, cm.false_rate(:pos))
  36. assert_in_delta(0.3333, cm.false_rate(:neg))
  37. assert_in_delta(0.6667, cm.precision(:pos))
  38. assert_in_delta(0.8, cm.precision(:neg))
  39. assert_in_delta(0.6667, cm.recall(:pos))
  40. assert_in_delta(0.8, cm.recall(:neg))
  41. assert_in_delta(0.6667, cm.sensitivity(:pos))
  42. assert_in_delta(0.8, cm.sensitivity(:neg))
  43. assert_in_delta(0.75, cm.overall_accuracy)
  44. assert_in_delta(0.6667, cm.f_measure(:pos))
  45. assert_in_delta(0.8, cm.f_measure(:neg))
  46. assert_in_delta(0.7303, cm.geometric_mean)
  47. end
  48. # Example from:
  49. # https://www.datatechnotes.com/2019/02/accuracy-metrics-in-classification.html
  50. def test_two_classes_2
  51. cm = ConfusionMatrix.new
  52. 5.times { cm.add_for(:pos, :pos) }
  53. 1.times { cm.add_for(:pos, :neg) }
  54. 3.times { cm.add_for(:neg, :neg) }
  55. 2.times { cm.add_for(:neg, :pos) }
  56. assert_equal(11, cm.total)
  57. assert_equal(5, cm.true_positive(:pos))
  58. assert_equal(1, cm.false_negative(:pos))
  59. assert_equal(2, cm.false_positive(:pos))
  60. assert_equal(3, cm.true_negative(:pos))
  61. assert_in_delta(0.7142, cm.precision(:pos))
  62. assert_in_delta(0.8333, cm.recall(:pos))
  63. assert_in_delta(0.7272, cm.overall_accuracy)
  64. assert_in_delta(0.7692, cm.f_measure(:pos))
  65. assert_in_delta(0.8333, cm.sensitivity(:pos))
  66. assert_in_delta(0.6, cm.specificity(:pos))
  67. assert_in_delta(0.4407, cm.kappa(:pos))
  68. assert_in_delta(0.5454, cm.prevalence(:pos))
  69. end
  70. # Examples from:
  71. # https://standardwisdom.com/softwarejournal/2011/12/matthews-correlation-coefficient-how-well-does-it-do/
  72. def two_class_case(a,b,c,d,e,f,g,h,i)
  73. cm = ConfusionMatrix.new
  74. a.times { cm.add_for(:pos, :pos) }
  75. b.times { cm.add_for(:pos, :neg) }
  76. c.times { cm.add_for(:neg, :neg) }
  77. d.times { cm.add_for(:neg, :pos) }
  78. assert_in_delta(e, cm.matthews_correlation(:pos))
  79. assert_in_delta(f, cm.precision(:pos))
  80. assert_in_delta(g, cm.recall(:pos))
  81. assert_in_delta(h, cm.f_measure(:pos))
  82. assert_in_delta(i, cm.kappa(:pos))
  83. end
  84. def test_two_classes_3
  85. two_class_case(100, 0, 900, 0, 1.0, 1.0, 1.0, 1.0, 1.0)
  86. two_class_case(65, 35, 825, 75, 0.490, 0.4643, 0.65, 0.542, 0.4811)
  87. two_class_case(50, 50, 700, 200, 0.192, 0.2, 0.5, 0.286, 0.1666)
  88. end
  89. def test_three_classes
  90. cm = ConfusionMatrix.new
  91. 10.times { cm.add_for(:red, :red) }
  92. 7.times { cm.add_for(:red, :blue) }
  93. 5.times { cm.add_for(:red, :green) }
  94. 20.times { cm.add_for(:blue, :red) }
  95. 5.times { cm.add_for(:blue, :blue) }
  96. 15.times { cm.add_for(:blue, :green) }
  97. 30.times { cm.add_for(:green, :red) }
  98. 12.times { cm.add_for(:green, :blue) }
  99. 8.times { cm.add_for(:green, :green) }
  100. assert_equal([:blue, :green, :red], cm.labels)
  101. assert_equal(112, cm.total)
  102. assert_equal(10, cm.true_positive(:red))
  103. assert_equal(12, cm.false_negative(:red))
  104. assert_equal(50, cm.false_positive(:red))
  105. assert_equal(13, cm.true_negative(:red))
  106. assert_equal(5, cm.true_positive(:blue))
  107. assert_equal(35, cm.false_negative(:blue))
  108. assert_equal(19, cm.false_positive(:blue))
  109. assert_equal(18, cm.true_negative(:blue))
  110. assert_equal(8, cm.true_positive(:green))
  111. assert_equal(42, cm.false_negative(:green))
  112. assert_equal(20, cm.false_positive(:green))
  113. assert_equal(15, cm.true_negative(:green))
  114. end
  115. def test_add_for_n
  116. cm = ConfusionMatrix.new
  117. cm.add_for(:pos, :pos, 3)
  118. cm.add_for(:pos, :neg)
  119. cm.add_for(:neg, :pos, 2)
  120. cm.add_for(:neg, :neg, 1)
  121. assert_equal(7, cm.total)
  122. assert_equal(3, cm.count_for(:pos, :pos))
  123. # - check errors
  124. assert_raises(ArgumentError) { cm.add_for(:pos, :pos, 0) }
  125. assert_raises(ArgumentError) { cm.add_for(:pos, :pos, -3) }
  126. assert_raises(ArgumentError) { cm.add_for(:pos, :pos, nil) }
  127. end
  128. def test_use_labels
  129. # - check errors
  130. assert_raises(ArgumentError) { ConfusionMatrix.new(:pos) }
  131. assert_raises(ArgumentError) { ConfusionMatrix.new(:pos, :pos) }
  132. # - check created matrix
  133. cm = ConfusionMatrix.new(:pos, :neg)
  134. assert_equal([:pos, :neg], cm.labels)
  135. assert_raises(ArgumentError) { cm.add_for(:pos, :nothing) }
  136. cm.add_for(:pos, :neg, 3)
  137. cm.add_for(:neg, :pos, 2)
  138. assert_equal(2, cm.false_negative(:neg))
  139. assert_equal(3, cm.false_negative(:pos))
  140. assert_equal(3, cm.false_negative())
  141. assert_raises(ArgumentError) { cm.false_negative(:nothing) }
  142. assert_raises(ArgumentError) { cm.false_negative(nil) }
  143. end
  144. end