2 Commits 66557da109 ... 41d9ffb8fa

Author SHA1 Message Date
  Peter Lane 41d9ffb8fa updated documentation 1 year ago
  Peter Lane a3313bdb2a added IndexMut implementation 1 year ago
2 changed files with 75 additions and 55 deletions
  1. 16 0
      examples/demo.rs
  2. 59 55
      src/lib.rs

+ 16 - 0
examples/demo.rs

@@ -0,0 +1,16 @@
+use confusion_matrix;
+
+fn main() {
+    let mut cm = confusion_matrix::new();
+
+    cm[("pos", "pos")] = 10;
+    cm[("pos", "neg")] = 3;
+    cm[("neg", "neg")] = 20;
+    cm[("neg", "pos")] = 5;
+    
+    println!("Precision: {}", cm.precision("pos"));
+    println!("Recall: {}", cm.recall("pos"));
+    println!("MCC: {}", cm.matthews_correlation("pos"));
+    println!("");
+    println!("{}", cm);
+}

+ 59 - 55
src/lib.rs

@@ -76,18 +76,11 @@ use std::ops;
 /// fn main() {
 ///     let mut cm = confusion_matrix::new();
 /// 
-///     for _ in 0..10 {
-///         cm.add_for("pos", "pos");
-///     }
-///     for _ in 0..3 {
-///         cm.add_for("pos", "neg");
-///     }
-///     for _ in 0..20 {
-///         cm.add_for("neg", "neg");
-///     }
-///     for _ in 0..5 {
-///         cm.add_for("neg", "pos");
-///     }
+///     cm[("pos", "pos")] = 10;
+///     cm[("pos", "neg")] = 3;
+///     cm[("neg", "neg")] = 20;
+///     cm[("neg", "pos")] = 5;
+///     
 ///     println!("Precision: {}", cm.precision("pos"));
 ///     println!("Recall: {}", cm.recall("pos"));
 ///     println!("MCC: {}", cm.matthews_correlation("pos"));
@@ -203,6 +196,45 @@ impl ops::Index<(&str, &str)> for ConfusionMatrix {
     }
 }
 
+impl ops::IndexMut<(&str, &str)> for ConfusionMatrix {
+    /// Provides a mutable reference to the count for an (actual, prediction) pair.
+    ///
+    /// * `actual` - the actual class of the instance, which we are hoping 
+    ///              the classifier will predict.
+    /// * `prediction` - the predicted class for the instance, as output from
+    ///                  the classifier.
+    /// 
+    /// # Example 
+    ///
+    /// ```
+    /// let mut cm = confusion_matrix::new();
+    /// for _ in 0..2 { cm[("positive", "positive")] += 1; }
+    /// cm[("positive", "negative")] = 5;
+    /// cm[("negative", "positive")] = 1;
+    /// for _ in 0..3 { cm[("negative", "negative")] += 1; }
+    /// assert_eq!(2, cm[("positive", "positive")]);
+    /// assert_eq!(5, cm[("positive", "negative")]);
+    /// assert_eq!(0, cm[("positive", "not_known")]);
+    /// ```
+    ///
+    fn index_mut(&mut self, (actual, prediction): (&str, &str)) -> &mut usize {
+        // make sure there is a slot for (actual, prediction)
+        if !self.matrix.contains_key(actual) {
+            self.matrix.insert(String::from(actual), HashMap::new());
+        }
+        if let Some(predictions) = self.matrix.get_mut(actual) {
+            if None == predictions.get(prediction) {
+                predictions.insert(String::from(prediction), 0);
+            }
+        }
+        // return a mutable reference to (actual, prediction) slot
+        self.matrix.get_mut(actual)
+            .expect("Confusion matrix must contain actual value")
+            .get_mut(prediction)
+            .expect("Confusion matrix must contain predicted value")
+    }
+}
+
 impl ConfusionMatrix {
     /// Adds one result to the matrix.
     /// 
@@ -707,14 +739,10 @@ mod tests {
             cm.add_for("pos", "pos");
         }
         for _ in 0..5 {
-            cm.add_for("pos", "neg");
-        }
-        for _ in 0..20 {
-            cm.add_for("neg", "neg");
-        }
-        for _ in 0..5 {
-            cm.add_for("neg", "pos");
+            cm[("pos", "neg")] += 1;
         }
+        cm[("neg", "neg")] = 20;
+        cm[("neg", "pos")] = 5;
 
         assert_eq!(vec!["neg", "pos"], cm.labels());
         assert_eq!(10, cm[("pos", "pos")]);
@@ -753,18 +781,10 @@ mod tests {
     #[test]
     fn test_two_classes_2() {
         let mut cm = new();
-        for _ in 0..5 {
-            cm.add_for("pos", "pos")
-        }
-        for _ in 0..1 {
-            cm.add_for("pos", "neg")
-        }
-        for _ in 0..3 {
-            cm.add_for("neg", "neg")
-        }
-        for _ in 0..2 {
-            cm.add_for("neg", "pos")
-        }
+        cm[("pos", "pos")] = 5;
+        cm[("pos", "neg")] = 1;
+        cm[("neg", "neg")] = 3;
+        cm[("neg", "pos")] = 2;
 
         assert_eq!(11, cm.total());
         assert_eq!(5, cm.true_positive("pos"));
@@ -796,18 +816,10 @@ mod tests {
         i: f64,
         ) {
         let mut cm = new();
-        for _ in 0..a {
-            cm.add_for("pos", "pos");
-        }
-        for _ in 0..b {
-            cm.add_for("pos", "neg");
-        }
-        for _ in 0..c {
-            cm.add_for("neg", "neg");
-        }
-        for _ in 0..d {
-            cm.add_for("neg", "pos");
-        }
+        cm[("pos", "pos")] = a;
+        cm[("pos", "neg")] = b;
+        cm[("neg", "neg")] = c;
+        cm[("neg", "pos")] = d;
 
         test_approx_same(e, cm.matthews_correlation("pos"));
         test_approx_same(f, cm.precision("pos"));
@@ -873,18 +885,10 @@ mod tests {
     #[test]
     fn check_traits_clone_eq() {
         let mut cm = new();
-        for _ in 0..3 {
-            cm.add_for("pos", "pos");
-        }
-        for _ in 0..2 {
-            cm.add_for("pos", "neg");
-        }
-        for _ in 0..4 {
-            cm.add_for("neg", "neg");
-        }
-        for _ in 0..5 {
-            cm.add_for("neg", "pos");
-        }
+        cm[("pos", "pos")] = 3;
+        cm[("pos", "neg")] = 2;
+        cm[("neg", "neg")] = 4;
+        cm[("neg", "pos")] = 5;
         let mut cm_clone = cm.clone();
 
         assert_eq!(cm.total(), cm_clone.total());