|
@@ -9,6 +9,7 @@
|
|
|
|
|
|
use std::collections::HashMap;
|
|
|
use std::fmt;
|
|
|
+use std::ops;
|
|
|
|
|
|
/// A confusion matrix is used to record pairs of (actual class, predicted class)
|
|
|
/// as typically produced by a classification algorithm.
|
|
@@ -108,6 +109,7 @@ use std::fmt;
|
|
|
/// 5 10 | pos
|
|
|
/// ```
|
|
|
///
|
|
|
+#[derive(Clone,Debug,Default,Eq,PartialEq)]
|
|
|
pub struct ConfusionMatrix {
|
|
|
matrix: HashMap<String, HashMap<String, usize>>,
|
|
|
}
|
|
@@ -167,6 +169,40 @@ impl fmt::Display for ConfusionMatrix {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+impl ops::Index<(&str, &str)> for ConfusionMatrix {
|
|
|
+ type Output = usize;
|
|
|
+
|
|
|
+ /// Returns the count for an (actual, prediction) pair, or 0 if the pair
|
|
|
+ /// is not known.
|
|
|
+ ///
|
|
|
+ /// * `actual` - the actual class of the instance, which we are hoping
|
|
|
+ /// the classifier will predict.
|
|
|
+ /// * `prediction` - the predicted class for the instance, as output from
|
|
|
+ /// the classifier.
|
|
|
+ ///
|
|
|
+ /// # Example
|
|
|
+ /// (using `cm` from [`Self::add_for()`])
|
|
|
+ ///
|
|
|
+ /// ```
|
|
|
+ /// # let mut cm = confusion_matrix::new();
|
|
|
+ /// # for _ in 0..2 { cm.add_for("positive", "positive"); }
|
|
|
+ /// # for _ in 0..5 { cm.add_for("positive", "negative"); }
|
|
|
+ /// # for _ in 0..1 { cm.add_for("negative", "positive"); }
|
|
|
+ /// # for _ in 0..3 { cm.add_for("negative", "negative"); }
|
|
|
+ /// assert_eq!(2, cm[("positive", "positive")]);
|
|
|
+ /// assert_eq!(0, cm[("positive", "not_known")]);
|
|
|
+ /// ```
|
|
|
+ ///
|
|
|
+ fn index(&self, (actual, prediction): (&str, &str)) -> &usize {
|
|
|
+ if let Some(predictions) = self.matrix.get(actual) {
|
|
|
+ if let Some(count) = predictions.get(prediction) {
|
|
|
+ return count;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ &0
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
impl ConfusionMatrix {
|
|
|
/// Adds one result to the matrix.
|
|
|
///
|
|
@@ -681,10 +717,10 @@ mod tests {
|
|
|
}
|
|
|
|
|
|
assert_eq!(vec!["neg", "pos"], cm.labels());
|
|
|
- assert_eq!(10, cm.count_for("pos", "pos"));
|
|
|
- assert_eq!(5, cm.count_for("pos", "neg"));
|
|
|
- assert_eq!(20, cm.count_for("neg", "neg"));
|
|
|
- assert_eq!(5, cm.count_for("neg", "pos"));
|
|
|
+ assert_eq!(10, cm[("pos", "pos")]);
|
|
|
+ assert_eq!(5, cm[("pos", "neg")]);
|
|
|
+ assert_eq!(20, cm[("neg", "neg")]);
|
|
|
+ assert_eq!(5, cm[("neg", "pos")]);
|
|
|
|
|
|
assert_eq!(40, cm.total());
|
|
|
assert_eq!(10, cm.true_positive("pos"));
|
|
@@ -833,5 +869,39 @@ mod tests {
|
|
|
assert_eq!(20, cm.false_positive("green"));
|
|
|
assert_eq!(15, cm.true_negative("green"));
|
|
|
}
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn check_traits_clone_eq() {
|
|
|
+ let mut cm = new();
|
|
|
+ for _ in 0..3 {
|
|
|
+ cm.add_for("pos", "pos");
|
|
|
+ }
|
|
|
+ for _ in 0..2 {
|
|
|
+ cm.add_for("pos", "neg");
|
|
|
+ }
|
|
|
+ for _ in 0..4 {
|
|
|
+ cm.add_for("neg", "neg");
|
|
|
+ }
|
|
|
+ for _ in 0..5 {
|
|
|
+ cm.add_for("neg", "pos");
|
|
|
+ }
|
|
|
+ let mut cm_clone = cm.clone();
|
|
|
+
|
|
|
+ assert_eq!(cm.total(), cm_clone.total());
|
|
|
+ assert_eq!(14, cm.total());
|
|
|
+ assert_eq!(14, cm_clone.total());
|
|
|
+ assert_eq!(cm, cm_clone);
|
|
|
+
|
|
|
+ cm_clone.add_for("pos", "pos");
|
|
|
+ assert_eq!(14, cm.total());
|
|
|
+ assert_eq!(15, cm_clone.total());
|
|
|
+ assert!(cm != cm_clone);
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn check_default() {
|
|
|
+ let cm = super::ConfusionMatrix::default();
|
|
|
+ assert_eq!(0, cm.total());
|
|
|
+ }
|
|
|
}
|
|
|
|