# This class holds the confusion matrix information.
# It is designed to be called incrementally, as results are obtained
# from the classifier model.
#
# At any point, statistics may be obtained by calling the relevant methods.
#
# A two-class example is:
#
#     Classified  Classified |
#     Positive    Negative   | Actual
#     ------------------------+------------
#     a           b           | Positive
#     c           d           | Negative
#
# Statistical methods will be described with reference to this example.
#
class ConfusionMatrix
  # Creates a new, empty instance of a confusion matrix.
  #
  # @param labels [Array<String, Symbol>] if provided, makes the matrix
  #        use the first label as a default label, and also check
  #        all operations use one of the pre-defined labels.
  # @raise [ArgumentError] if there are not at least two unique labels, when provided.
  def initialize(*labels)
    @matrix = {}
    @labels = labels.uniq
    if @labels.size == 1
      raise ArgumentError, "If labels are provided, there must be at least two."
    end

    # Preset every (actual, predicted) cell to zero so counts and statistics
    # are well-defined even before any results have been added.
    @labels.each do |actual|
      @matrix[actual] = @labels.each_with_object({}) do |predicted, row|
        row[predicted] = 0
      end
    end
  end

  # Returns a list of labels used in the matrix.
  # With pre-defined labels, they are returned in the order given;
  # otherwise the labels observed so far are returned sorted.
  #
  #     cm = ConfusionMatrix.new
  #     cm.add_for(:pos, :neg)
  #     cm.labels # => [:neg, :pos]
  #
  # @return [Array<String, Symbol>] labels used in the matrix.
  def labels
    return @labels if @labels.size >= 2 # pre-defined labels win

    observed = @matrix.flat_map do |actual, predictions|
      [actual] + predictions.keys
    end
    observed.uniq.sort
  end

  # Return the count for (actual, prediction) pair.
  #
  #     cm = ConfusionMatrix.new
  #     cm.add_for(:pos, :neg)
  #     cm.count_for(:pos, :neg) # => 1
  #
  # @param actual [String, Symbol] is actual class of the instance,
  #        which we expect the classifier to predict
  # @param prediction [String, Symbol] is the predicted class of the instance,
  #        as output from the classifier
  # @return [Integer] number of observations of (actual, prediction) pair
  # @raise [ArgumentError] if +actual+ or +prediction+ are not one of any
  #        pre-defined labels in matrix
  def count_for(actual, prediction)
    validate_label actual, prediction
    @matrix.fetch(actual, {}).fetch(prediction, 0)
  end

  # Adds one result to the matrix for a given (actual, prediction) pair of labels.
  # If the matrix was given a pre-defined list of labels on construction, then
  # these given labels must be from the pre-defined list.
  # If no pre-defined list of labels was used in constructing the matrix, then
  # labels will be added to matrix.
  #
  # Class labels may be any hashable value, though ideally they are strings or symbols.
  #
  # @param actual [String, Symbol] is actual class of the instance,
  #        which we expect the classifier to predict
  # @param prediction [String, Symbol] is the predicted class of the instance,
  #        as output from the classifier
  # @param n [Integer] number of observations to add, must be positive
  # @raise [ArgumentError] if +n+ is not a positive Integer
  # @raise [ArgumentError] if +actual+ or +prediction+ are not one of any
  #        pre-defined labels in matrix
  def add_for(actual, prediction, n = 1)
    validate_label actual, prediction
    # Validate n BEFORE touching the matrix: checking afterwards (as the
    # original code did) left stray zero-count rows/cells behind whenever
    # an invalid n was rejected.
    unless n.is_a?(Integer) && n.positive?
      raise ArgumentError, "add_for requires n to be a positive Integer, but got #{n}"
    end

    row = (@matrix[actual] ||= {})
    row[prediction] = row.fetch(prediction, 0) + n
  end

  # Returns the number of instances of the given class label which
  # are incorrectly classified.
  #
  #     false_negative(:positive) = b
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Integer] number of false negatives for +label+
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def false_negative(label = @labels.first)
    validate_label label
    predictions = @matrix.fetch(label, {})
    # Everything in this row except the diagonal cell is a miss.
    predictions.sum { |predicted, count| predicted == label ? 0 : count }
  end

  # Returns the number of instances incorrectly classified with the given
  # class label.
  #
  #     false_positive(:positive) = c
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Integer] number of false positives for +label+
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def false_positive(label = @labels.first)
    validate_label label
    @matrix.sum do |actual, predictions|
      actual == label ? 0 : predictions.fetch(label, 0)
    end
  end

  # The false rate for a given class label is the proportion of instances
  # incorrectly classified as that label, out of all those instances
  # not originally of that label.
  #
  #     false_rate(:positive) = c/(c+d)
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of false rate
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def false_rate(label = @labels.first)
    validate_label label
    fp = false_positive(label)
    tn = true_negative(label)
    divide(fp, fp + tn)
  end

  # The F-measure for a given label is the harmonic mean of the precision
  # and recall for that label.
  #
  #     F = 2*(precision*recall)/(precision+recall)
  #
  # Returns 0.0 when both precision and recall are zero (safe divide),
  # rather than NaN.
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of F-measure
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def f_measure(label = @labels.first)
    validate_label label
    p = precision(label)
    r = recall(label)
    divide(2 * p * r, p + r)
  end

  # The geometric mean is the nth-root of the product of the true_rate for
  # each label.
  #
  #     a1 = a/(a+b)
  #     a2 = d/(c+d)
  #     geometric_mean = Math.sqrt(a1*a2)
  #
  # @return [Float] value of geometric mean
  def geometric_mean
    product = @matrix.keys.inject(1) { |acc, label| acc * true_rate(label) }
    product**(1.0 / @matrix.size)
  end

  # The Kappa statistic compares the observed accuracy with an expected
  # (chance) accuracy: kappa = (observed - expected)/(1 - expected).
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of Cohen's Kappa Statistic
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def kappa(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)
    total = tp + fn + fp + tn
    observed_accuracy = divide(tp + tn, total)
    random_accuracy = divide((tn + fp) * (tn + fn) + (fn + tp) * (fp + tp), total * total)
    divide(observed_accuracy - random_accuracy, 1 - random_accuracy)
  end

  # Matthews Correlation Coefficient is a measure of the quality of binary
  # classifications.
  #
  #     matthews_correlation(:positive) = (a*d - c*b) / sqrt((a+c)(a+b)(d+c)(d+b))
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of Matthews Correlation Coefficient
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def matthews_correlation(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)
    divide(tp * tn - fp * fn, Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
  end

  # The overall accuracy is the proportion of instances which are
  # correctly labelled.
  #
  #     overall_accuracy = (a+d)/(a+b+c+d)
  #
  # @return [Float] value of overall accuracy
  def overall_accuracy
    total_correct = @matrix.keys.sum { |label| true_positive(label) }
    divide(total_correct, total)
  end

  # The precision for a given class label is the proportion of instances
  # classified as that class which are correct.
  #
  #     precision(:positive) = a/(a+c)
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of precision
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def precision(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fp = false_positive(label)
    divide(tp, tp + fp)
  end

  # The prevalence for a given class label is the proportion of instances
  # which were actually of that label, out of the total.
  #
  #     prevalence(:positive) = (a+b)/(a+b+c+d)
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of prevalence
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def prevalence(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)
    divide(tp + fn, tp + fn + fp + tn)
  end

  # The recall is another name for the true rate.
  #
  # @see true_rate
  # @param (see #true_rate)
  # @return (see #true_rate)
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def recall(label = @labels.first)
    validate_label label
    true_rate(label)
  end

  # Sensitivity is another name for the true rate.
  #
  # @see true_rate
  # @param (see #true_rate)
  # @return (see #true_rate)
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def sensitivity(label = @labels.first)
    validate_label label
    true_rate(label)
  end

  # The specificity for a given class label is 1 - false_rate(label)
  #
  # In two-class case, specificity = 1 - false_positive_rate
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of specificity
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def specificity(label = @labels.first)
    validate_label label
    1 - false_rate(label)
  end

  # Returns the table in a string format, representing the entries as a
  # printable table.
  #
  # @return [String] representation as a printable table.
  def to_s
    ls = labels
    result = ""
    title_line = "Predicted "
    label_line = ""
    ls.each { |l| label_line << "#{l} " }
    # Pad the shorter of the two header lines so they line up.
    label_line << " " while label_line.size < title_line.size
    title_line << " " while title_line.size < label_line.size
    result << title_line << "|\n" << label_line << "| Actual\n"
    result << "-" * title_line.size << "+-------\n"
    ls.each do |l|
      count_line = ""
      ls.each_with_index do |m, i|
        # Right-justify each count under its column label.
        count_line << "#{count_for(l, m)}".rjust(ls[i].size) << " "
      end
      result << count_line.ljust(title_line.size) << "| #{l}\n"
    end
    result
  end

  # Returns the total number of instances referenced in the matrix.
  #
  #     total = a+b+c+d
  #
  # @return [Integer] total number of instances referenced in the matrix.
  def total
    @matrix.sum { |_actual, predictions| predictions.values.sum }
  end

  # Returns the number of instances NOT of the given class label which
  # are correctly classified.
  #
  #     true_negative(:positive) = d
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Integer] number of instances not of given label which are correctly classified
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def true_negative(label = @labels.first)
    validate_label label
    # Sum the diagonal cells of every row other than label's.
    @matrix.sum do |actual, predictions|
      actual == label ? 0 : predictions.fetch(actual, 0)
    end
  end

  # Returns the number of instances of the given class label which are
  # correctly classified.
  #
  #     true_positive(:positive) = a
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Integer] number of instances of given label which are correctly classified
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def true_positive(label = @labels.first)
    validate_label label
    @matrix.fetch(label, {}).fetch(label, 0)
  end

  # The true rate for a given class label is the proportion of instances of
  # that class which are correctly classified.
  #
  #     true_rate(:positive) = a/(a+b)
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] proportion of instances which are correctly classified
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def true_rate(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    divide(tp, tp + fn)
  end

  private

  # A form of "safe divide".
  # Checks if divisor is zero, and returns 0.0 if so.
  # This avoids a run-time error, and ensures floating-point division.
  def divide(x, y)
    return 0.0 if y.zero?

    x.to_f / y
  end

  # Checks each given label is non-nil and in @labels; a matrix built
  # without pre-defined labels accepts any label.
  # @raise [ArgumentError] if a label is not in the pre-defined list
  def validate_label(*labels)
    return true if @labels.empty?

    labels.each do |label|
      unless label && @labels.include?(label)
        raise ArgumentError, "Given label (#{label}) is not in predefined list (#{@labels.join(',')})"
      end
    end
  end
end