123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- require 'confusion_matrix'
- require 'minitest/autorun'
# Minitest suite for ConfusionMatrix.
#
# Throughout, cm.add_for(actual, predicted, n = 1) records observations: the
# first argument is the true label and the second the predicted label (in
# test_two_classes_2, add_for(:pos, :neg) is counted as a false negative
# for :pos).
#
# Inherits from Minitest::Test rather than the legacy MiniTest::Test alias,
# which newer minitest releases no longer define unless MT_COMPAT is set.
class TestConfusionMatrix < Minitest::Test
  # A fresh matrix has no observations, so every count is zero.
  def test_empty_case
    cm = ConfusionMatrix.new
    # assert_equal, not assert: assert(0, x) treats 0 as the (always truthy)
    # condition and x as the failure message, so it can never fail.
    assert_equal(0, cm.total)
    assert_equal(0, cm.true_positive(:none))
    assert_equal(0, cm.false_negative(:none))
    assert_equal(0, cm.false_positive(:none))
    assert_equal(0, cm.true_negative(:none))
    assert_in_delta(0, cm.true_rate(:none))
  end

  # Two-label case exercising raw counts, per-label rates and overall scores.
  def test_two_classes
    cm = ConfusionMatrix.new
    10.times { cm.add_for(:pos, :pos) }
    5.times { cm.add_for(:pos, :neg) }
    20.times { cm.add_for(:neg, :neg) }
    5.times { cm.add_for(:neg, :pos) }

    # Labels inferred from add_for come back sorted, not in insertion order.
    assert_equal(%i[neg pos], cm.labels)
    assert_equal(10, cm.count_for(:pos, :pos))
    assert_equal(5, cm.count_for(:pos, :neg))
    assert_equal(20, cm.count_for(:neg, :neg))
    assert_equal(5, cm.count_for(:neg, :pos))
    assert_equal(40, cm.total)

    # With two labels, one label's positives are the other label's negatives.
    assert_equal(10, cm.true_positive(:pos))
    assert_equal(5, cm.false_negative(:pos))
    assert_equal(5, cm.false_positive(:pos))
    assert_equal(20, cm.true_negative(:pos))
    assert_equal(20, cm.true_positive(:neg))
    assert_equal(5, cm.false_negative(:neg))
    assert_equal(5, cm.false_positive(:neg))
    assert_equal(10, cm.true_negative(:neg))

    assert_in_delta(0.6667, cm.true_rate(:pos))
    assert_in_delta(0.8, cm.true_rate(:neg))
    assert_in_delta(0.2, cm.false_rate(:pos))
    assert_in_delta(0.3333, cm.false_rate(:neg))
    assert_in_delta(0.6667, cm.precision(:pos))
    assert_in_delta(0.8, cm.precision(:neg))
    assert_in_delta(0.6667, cm.recall(:pos))
    assert_in_delta(0.8, cm.recall(:neg))
    assert_in_delta(0.6667, cm.sensitivity(:pos))
    assert_in_delta(0.8, cm.sensitivity(:neg))
    assert_in_delta(0.75, cm.overall_accuracy)
    assert_in_delta(0.6667, cm.f_measure(:pos))
    assert_in_delta(0.8, cm.f_measure(:neg))
    # sqrt(true_rate(:pos) * true_rate(:neg)) = sqrt(0.6667 * 0.8)
    assert_in_delta(0.7303, cm.geometric_mean)
  end

  # Example from:
  # https://www.datatechnotes.com/2019/02/accuracy-metrics-in-classification.html
  def test_two_classes_2
    cm = ConfusionMatrix.new
    5.times { cm.add_for(:pos, :pos) }
    1.times { cm.add_for(:pos, :neg) }
    3.times { cm.add_for(:neg, :neg) }
    2.times { cm.add_for(:neg, :pos) }

    assert_equal(11, cm.total)
    assert_equal(5, cm.true_positive(:pos))
    assert_equal(1, cm.false_negative(:pos))
    assert_equal(2, cm.false_positive(:pos))
    assert_equal(3, cm.true_negative(:pos))
    assert_in_delta(0.7142, cm.precision(:pos))
    assert_in_delta(0.8333, cm.recall(:pos))
    assert_in_delta(0.7272, cm.overall_accuracy)
    assert_in_delta(0.7692, cm.f_measure(:pos))
    assert_in_delta(0.8333, cm.sensitivity(:pos))
    assert_in_delta(0.6, cm.specificity(:pos))
    assert_in_delta(0.4407, cm.kappa(:pos))
    assert_in_delta(0.5454, cm.prevalence(:pos))
  end

  # Examples from:
  # https://standardwisdom.com/softwarejournal/2011/12/matthews-correlation-coefficient-how-well-does-it-do/
  #
  # Builds a :pos/:neg matrix from its four cell counts and checks the
  # derived :pos metrics (default assert_in_delta tolerance, 0.001).
  #
  # @param true_pos [Integer] observations recorded as add_for(:pos, :pos)
  # @param false_neg [Integer] observations recorded as add_for(:pos, :neg)
  # @param true_neg [Integer] observations recorded as add_for(:neg, :neg)
  # @param false_pos [Integer] observations recorded as add_for(:neg, :pos)
  # @param expected_mcc [Float] expected matthews_correlation(:pos)
  # @param expected_precision [Float] expected precision(:pos)
  # @param expected_recall [Float] expected recall(:pos)
  # @param expected_f_measure [Float] expected f_measure(:pos)
  # @param expected_kappa [Float] expected kappa(:pos)
  def two_class_case(true_pos, false_neg, true_neg, false_pos,
                     expected_mcc, expected_precision, expected_recall,
                     expected_f_measure, expected_kappa)
    cm = ConfusionMatrix.new
    true_pos.times { cm.add_for(:pos, :pos) }
    false_neg.times { cm.add_for(:pos, :neg) }
    true_neg.times { cm.add_for(:neg, :neg) }
    false_pos.times { cm.add_for(:neg, :pos) }
    assert_in_delta(expected_mcc, cm.matthews_correlation(:pos))
    assert_in_delta(expected_precision, cm.precision(:pos))
    assert_in_delta(expected_recall, cm.recall(:pos))
    assert_in_delta(expected_f_measure, cm.f_measure(:pos))
    assert_in_delta(expected_kappa, cm.kappa(:pos))
  end

  # Arguments: tp, fn, tn, fp, then expected MCC, precision, recall, F, kappa.
  def test_two_classes_3
    two_class_case(100, 0, 900, 0, 1.0, 1.0, 1.0, 1.0, 1.0)
    two_class_case(65, 35, 825, 75, 0.490, 0.4643, 0.65, 0.542, 0.4811)
    two_class_case(50, 50, 700, 200, 0.192, 0.2, 0.5, 0.286, 0.1666)
  end

  # Three-label case: checks per-label counts when each label is treated as
  # "positive" against the rest.
  def test_three_classes
    cm = ConfusionMatrix.new
    10.times { cm.add_for(:red, :red) }
    7.times { cm.add_for(:red, :blue) }
    5.times { cm.add_for(:red, :green) }
    20.times { cm.add_for(:blue, :red) }
    5.times { cm.add_for(:blue, :blue) }
    15.times { cm.add_for(:blue, :green) }
    30.times { cm.add_for(:green, :red) }
    12.times { cm.add_for(:green, :blue) }
    8.times { cm.add_for(:green, :green) }

    assert_equal(%i[blue green red], cm.labels)
    assert_equal(112, cm.total)

    # Note: the expected true_negative values are the sum of the OTHER
    # labels' correct predictions only (e.g. red: 5 blue/blue + 8 green/green
    # = 13), not every cell outside red's row and column (which would be 40).
    assert_equal(10, cm.true_positive(:red))
    assert_equal(12, cm.false_negative(:red))
    assert_equal(50, cm.false_positive(:red))
    assert_equal(13, cm.true_negative(:red))
    assert_equal(5, cm.true_positive(:blue))
    assert_equal(35, cm.false_negative(:blue))
    assert_equal(19, cm.false_positive(:blue))
    assert_equal(18, cm.true_negative(:blue))
    assert_equal(8, cm.true_positive(:green))
    assert_equal(42, cm.false_negative(:green))
    assert_equal(20, cm.false_positive(:green))
    assert_equal(15, cm.true_negative(:green))
  end

  # add_for's optional third argument records that many observations at once.
  def test_add_for_n
    cm = ConfusionMatrix.new
    cm.add_for(:pos, :pos, 3)
    cm.add_for(:pos, :neg)
    cm.add_for(:neg, :pos, 2)
    cm.add_for(:neg, :neg, 1)
    assert_equal(7, cm.total)
    assert_equal(3, cm.count_for(:pos, :pos))

    # The count must be a positive integer.
    assert_raises(ArgumentError) { cm.add_for(:pos, :pos, 0) }
    assert_raises(ArgumentError) { cm.add_for(:pos, :pos, -3) }
    assert_raises(ArgumentError) { cm.add_for(:pos, :pos, nil) }
  end

  # Labels passed to the constructor fix the label set and its order.
  def test_use_labels
    # The constructor needs at least two distinct labels.
    assert_raises(ArgumentError) { ConfusionMatrix.new(:pos) }
    assert_raises(ArgumentError) { ConfusionMatrix.new(:pos, :pos) }

    cm = ConfusionMatrix.new(:pos, :neg)
    # Explicit labels keep constructor order (no sorting).
    assert_equal(%i[pos neg], cm.labels)
    # Labels outside the declared set are rejected.
    assert_raises(ArgumentError) { cm.add_for(:pos, :nothing) }

    cm.add_for(:pos, :neg, 3)
    cm.add_for(:neg, :pos, 2)
    assert_equal(2, cm.false_negative(:neg))
    assert_equal(3, cm.false_negative(:pos))
    # With no argument the result matches :pos; presumably it defaults to
    # the first declared label.
    assert_equal(3, cm.false_negative)
    assert_raises(ArgumentError) { cm.false_negative(:nothing) }
    assert_raises(ArgumentError) { cm.false_negative(nil) }
  end
end
|