|
@@ -76,18 +76,11 @@ use std::ops;
|
|
|
/// fn main() {
|
|
|
/// let mut cm = confusion_matrix::new();
|
|
|
///
|
|
|
-/// for _ in 0..10 {
|
|
|
-/// cm.add_for("pos", "pos");
|
|
|
-/// }
|
|
|
-/// for _ in 0..3 {
|
|
|
-/// cm.add_for("pos", "neg");
|
|
|
-/// }
|
|
|
-/// for _ in 0..20 {
|
|
|
-/// cm.add_for("neg", "neg");
|
|
|
-/// }
|
|
|
-/// for _ in 0..5 {
|
|
|
-/// cm.add_for("neg", "pos");
|
|
|
-/// }
|
|
|
+/// cm[("pos", "pos")] = 10;
|
|
|
+/// cm[("pos", "neg")] = 3;
|
|
|
+/// cm[("neg", "neg")] = 20;
|
|
|
+/// cm[("neg", "pos")] = 5;
|
|
|
+///
|
|
|
/// println!("Precision: {}", cm.precision("pos"));
|
|
|
/// println!("Recall: {}", cm.recall("pos"));
|
|
|
/// println!("MCC: {}", cm.matthews_correlation("pos"));
|
|
@@ -203,6 +196,45 @@ impl ops::Index<(&str, &str)> for ConfusionMatrix {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+impl ops::IndexMut<(&str, &str)> for ConfusionMatrix {
|
|
|
+ /// Provides a mutable reference to the count for an (actual, prediction) pair.
|
|
|
+ ///
|
|
|
+ /// * `actual` - the actual class of the instance, which we are hoping
|
|
|
+ /// the classifier will predict.
|
|
|
+ /// * `prediction` - the predicted class for the instance, as output from
|
|
|
+ /// the classifier.
|
|
|
+ ///
|
|
|
+ /// # Example
|
|
|
+ ///
|
|
|
+ /// ```
|
|
|
+ /// let mut cm = confusion_matrix::new();
|
|
|
+ /// for _ in 0..2 { cm[("positive", "positive")] += 1; }
|
|
|
+ /// cm[("positive", "negative")] = 5;
|
|
|
+ /// cm[("negative", "positive")] = 1;
|
|
|
+ /// for _ in 0..3 { cm[("negative", "negative")] += 1; }
|
|
|
+ /// assert_eq!(2, cm[("positive", "positive")]);
|
|
|
+ /// assert_eq!(5, cm[("positive", "negative")]);
|
|
|
+ /// assert_eq!(0, cm[("positive", "not_known")]);
|
|
|
+ /// ```
|
|
|
+ ///
|
|
|
+ fn index_mut(&mut self, (actual, prediction): (&str, &str)) -> &mut usize {
|
|
|
+ // make sure there is a slot for (actual, prediction)
|
|
|
+ if !self.matrix.contains_key(actual) {
|
|
|
+ self.matrix.insert(String::from(actual), HashMap::new());
|
|
|
+ }
|
|
|
+ if let Some(predictions) = self.matrix.get_mut(actual) {
|
|
|
+ if None == predictions.get(prediction) {
|
|
|
+ predictions.insert(String::from(prediction), 0);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // return a mutable reference to (actual, prediction) slot
|
|
|
+ self.matrix.get_mut(actual)
|
|
|
+ .expect("Confusion matrix must contain actual value")
|
|
|
+ .get_mut(prediction)
|
|
|
+ .expect("Confusion matrix must contain predicted value")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
impl ConfusionMatrix {
|
|
|
/// Adds one result to the matrix.
|
|
|
///
|
|
@@ -707,14 +739,10 @@ mod tests {
|
|
|
cm.add_for("pos", "pos");
|
|
|
}
|
|
|
for _ in 0..5 {
|
|
|
- cm.add_for("pos", "neg");
|
|
|
- }
|
|
|
- for _ in 0..20 {
|
|
|
- cm.add_for("neg", "neg");
|
|
|
- }
|
|
|
- for _ in 0..5 {
|
|
|
- cm.add_for("neg", "pos");
|
|
|
+ cm[("pos", "neg")] += 1;
|
|
|
}
|
|
|
+ cm[("neg", "neg")] = 20;
|
|
|
+ cm[("neg", "pos")] = 5;
|
|
|
|
|
|
assert_eq!(vec!["neg", "pos"], cm.labels());
|
|
|
assert_eq!(10, cm[("pos", "pos")]);
|
|
@@ -753,18 +781,10 @@ mod tests {
|
|
|
#[test]
|
|
|
fn test_two_classes_2() {
|
|
|
let mut cm = new();
|
|
|
- for _ in 0..5 {
|
|
|
- cm.add_for("pos", "pos")
|
|
|
- }
|
|
|
- for _ in 0..1 {
|
|
|
- cm.add_for("pos", "neg")
|
|
|
- }
|
|
|
- for _ in 0..3 {
|
|
|
- cm.add_for("neg", "neg")
|
|
|
- }
|
|
|
- for _ in 0..2 {
|
|
|
- cm.add_for("neg", "pos")
|
|
|
- }
|
|
|
+ cm[("pos", "pos")] = 5;
|
|
|
+ cm[("pos", "neg")] = 1;
|
|
|
+ cm[("neg", "neg")] = 3;
|
|
|
+ cm[("neg", "pos")] = 2;
|
|
|
|
|
|
assert_eq!(11, cm.total());
|
|
|
assert_eq!(5, cm.true_positive("pos"));
|
|
@@ -796,18 +816,10 @@ mod tests {
|
|
|
i: f64,
|
|
|
) {
|
|
|
let mut cm = new();
|
|
|
- for _ in 0..a {
|
|
|
- cm.add_for("pos", "pos");
|
|
|
- }
|
|
|
- for _ in 0..b {
|
|
|
- cm.add_for("pos", "neg");
|
|
|
- }
|
|
|
- for _ in 0..c {
|
|
|
- cm.add_for("neg", "neg");
|
|
|
- }
|
|
|
- for _ in 0..d {
|
|
|
- cm.add_for("neg", "pos");
|
|
|
- }
|
|
|
+ cm[("pos", "pos")] = a;
|
|
|
+ cm[("pos", "neg")] = b;
|
|
|
+ cm[("neg", "neg")] = c;
|
|
|
+ cm[("neg", "pos")] = d;
|
|
|
|
|
|
test_approx_same(e, cm.matthews_correlation("pos"));
|
|
|
test_approx_same(f, cm.precision("pos"));
|
|
@@ -873,18 +885,10 @@ mod tests {
|
|
|
#[test]
|
|
|
fn check_traits_clone_eq() {
|
|
|
let mut cm = new();
|
|
|
- for _ in 0..3 {
|
|
|
- cm.add_for("pos", "pos");
|
|
|
- }
|
|
|
- for _ in 0..2 {
|
|
|
- cm.add_for("pos", "neg");
|
|
|
- }
|
|
|
- for _ in 0..4 {
|
|
|
- cm.add_for("neg", "neg");
|
|
|
- }
|
|
|
- for _ in 0..5 {
|
|
|
- cm.add_for("neg", "pos");
|
|
|
- }
|
|
|
+ cm[("pos", "pos")] = 3;
|
|
|
+ cm[("pos", "neg")] = 2;
|
|
|
+ cm[("neg", "neg")] = 4;
|
|
|
+ cm[("neg", "pos")] = 5;
|
|
|
let mut cm_clone = cm.clone();
|
|
|
|
|
|
assert_eq!(cm.total(), cm_clone.total());
|