2 Commits 258c846d7d ... ee873f20a2

Author SHA1 Message Date
  Peter Lane ee873f20a2 added support for traits to clone, test equality and index into a confusion matrix 1 year ago
  Peter Lane 0a28adfa65 updated to edition 2021 1 year ago
2 changed files with 76 additions and 6 deletions
  1. 2 2
      Cargo.toml
  2. 74 4
      src/lib.rs

+ 2 - 2
Cargo.toml

@@ -5,11 +5,11 @@ authors = ["Peter Lane <peterlane@gmx.com>"]
 categories = ["science"]
 description = "Confusion matrix implementation for storing results from a classification experiment and providing statistical information."
 repository = "https://notabug.org/peterlane/confusion-matrix-rust"
-edition = "2018"
+edition = "2021"
 exclude = ["examples/*.rs"]
 keywords = ["analysis", "data-science", "machine-learning"]
 license = "MIT"
 name = "confusion_matrix"
 readme = "about.md"
-version = "1.0.1"
+version = "1.1.0"
 

+ 74 - 4
src/lib.rs

@@ -9,6 +9,7 @@
 
 use std::collections::HashMap;
 use std::fmt;
+use std::ops;
 
 /// A confusion matrix is used to record pairs of (actual class, predicted class)
 /// as typically produced by a classification algorithm.
@@ -108,6 +109,7 @@ use std::fmt;
 ///   5  10   | pos
 /// ```
 /// 
+#[derive(Clone,Debug,Default,Eq,PartialEq)]
 pub struct ConfusionMatrix {
     matrix: HashMap<String, HashMap<String, usize>>,
 }
@@ -167,6 +169,40 @@ impl fmt::Display for ConfusionMatrix {
     }
 }
 
+impl ops::Index<(&str, &str)> for ConfusionMatrix {
+    type Output = usize;
+
+    /// Returns the count for an (actual, prediction) pair, or 0 if the pair 
+    /// is not known.
+    ///
+    /// * `actual` - the actual class of the instance, which we are hoping 
+    ///              the classifier will predict.
+    /// * `prediction` - the predicted class for the instance, as output from
+    ///                  the classifier.
+    /// 
+    /// # Example 
+    /// (using `cm` from [`Self::add_for()`])
+    ///
+    /// ```
+    /// # let mut cm = confusion_matrix::new();
+    /// # for _ in 0..2 { cm.add_for("positive", "positive"); }
+    /// # for _ in 0..5 { cm.add_for("positive", "negative"); }
+    /// # for _ in 0..1 { cm.add_for("negative", "positive"); }
+    /// # for _ in 0..3 { cm.add_for("negative", "negative"); }
+    /// assert_eq!(2, cm[("positive", "positive")]);
+    /// assert_eq!(0, cm[("positive", "not_known")]);
+    /// ```
+    ///
+    fn index(&self, (actual, prediction): (&str, &str)) -> &usize {
+        if let Some(predictions) = self.matrix.get(actual) {
+            if let Some(count) = predictions.get(prediction) {
+                return count;
+            }
+        }
+        &0
+    }
+}
+
 impl ConfusionMatrix {
     /// Adds one result to the matrix.
     /// 
@@ -681,10 +717,10 @@ mod tests {
         }
 
         assert_eq!(vec!["neg", "pos"], cm.labels());
-        assert_eq!(10, cm.count_for("pos", "pos"));
-        assert_eq!(5, cm.count_for("pos", "neg"));
-        assert_eq!(20, cm.count_for("neg", "neg"));
-        assert_eq!(5, cm.count_for("neg", "pos"));
+        assert_eq!(10, cm[("pos", "pos")]);
+        assert_eq!(5, cm[("pos", "neg")]);
+        assert_eq!(20, cm[("neg", "neg")]);
+        assert_eq!(5, cm[("neg", "pos")]);
 
         assert_eq!(40, cm.total());
         assert_eq!(10, cm.true_positive("pos"));
@@ -833,5 +869,39 @@ mod tests {
         assert_eq!(20, cm.false_positive("green"));
         assert_eq!(15, cm.true_negative("green"));
     }
+
+    #[test]
+    fn check_traits_clone_eq() {
+        let mut cm = new();
+        for _ in 0..3 {
+            cm.add_for("pos", "pos");
+        }
+        for _ in 0..2 {
+            cm.add_for("pos", "neg");
+        }
+        for _ in 0..4 {
+            cm.add_for("neg", "neg");
+        }
+        for _ in 0..5 {
+            cm.add_for("neg", "pos");
+        }
+        let mut cm_clone = cm.clone();
+
+        assert_eq!(cm.total(), cm_clone.total());
+        assert_eq!(14, cm.total());
+        assert_eq!(14, cm_clone.total());
+        assert_eq!(cm, cm_clone);
+
+        cm_clone.add_for("pos", "pos");
+        assert_eq!(14, cm.total());
+        assert_eq!(15, cm_clone.total());
+        assert!(cm != cm_clone);
+    }
+
+    #[test]
+    fn check_default() {
+        let cm = super::ConfusionMatrix::default();
+        assert_eq!(0, cm.total());
+    }
 }