Source code for draugr.tensorboard_utilities.experimental.pretty_cf

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 20-04-2021
           """

import itertools
from typing import Sequence

import numpy
from matplotlib import pyplot
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

__all__ = ["pretty_print_conf_matrix"]


[docs]def pretty_print_conf_matrix( y_true: Sequence, y_pred: Sequence, classes: Sequence, normalize: bool = False, title: str = "Confusion matrix", cmap=pyplot.cm.Blues, ) -> None: """ Mostly stolen from: http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py Normalization changed, classification_report stats added below plot """ cm = confusion_matrix(y_true, y_pred) # Configure Confusion Matrix Plot Aesthetics (no text yet) pyplot.imshow(cm, interpolation="nearest", cmap=cmap) pyplot.title(title, fontsize=14) tick_marks = numpy.arange(len(classes)) pyplot.xticks(tick_marks, classes, rotation=45) pyplot.yticks(tick_marks, classes) pyplot.ylabel("True label", fontsize=12) pyplot.xlabel("Predicted label", fontsize=12) # Calculate normalized values (so all cells sum to 1) if desired if normalize: cm = numpy.round(cm.astype("float") / cm.sum(), 2) # (axis=1)[:, numpy.newaxis] # Place Numbers as Text on Confusion Matrix Plot thresh = cm.max() / 2.0 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): pyplot.text( j, i, cm[i, j], horizontalalignment="center", color="white" if cm[i, j] > thresh else "black", fontsize=12, ) # Add Precision, Recall, F-1 Score as Captions Below Plot rpt = classification_report(y_true, y_pred) rpt = rpt.replace("avg / total", " avg") rpt = rpt.replace("support", "N Obs") pyplot.annotate( rpt, xy=(0, 0), xytext=(-50, -140), xycoords="axes fraction", textcoords="offset points", fontsize=12, ha="left", ) # Plot pyplot.tight_layout()
if __name__ == "__main__": from sklearn import datasets from sklearn.svm import SVC # get data, make predictions (X, y) = datasets.load_iris(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5) clf = SVC() clf.fit(X_train, y_train) y_test_pred = clf.predict(X_test) # Plot Confusion Matrix pyplot.style.use("classic") pyplot.figure(figsize=(3, 3)) pretty_print_conf_matrix( y_test, y_test_pred, classes=["0", "1", "2"], normalize=True, title="Confusion Matrix", ) pyplot.show()