Source code for draugr.torch_utilities.evaluation.classification
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Module metadata: author and a short module docstring (creation date only).
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 02-12-2020
"""
from typing import Callable, Optional

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from warg import kws_sink

from draugr.torch_utilities.sessions.model_sessions import TorchEvalSession
from draugr.torch_utilities.system.device import global_torch_device
# Public API of this module.
__all__ = ["find_n_misclassified"]
def find_n_misclassified(
    model: torch.nn.Module,
    evaluation_loader: DataLoader,
    *,
    mapper: Callable = kws_sink,
    n: int = 10,
    device: Optional[torch.device] = None,
) -> None:
    """Print up to ``n`` misclassified data points from ``evaluation_loader``.

    Each ``(input, target)`` pair yielded by the loader is fed through
    ``model`` inside a ``TorchEvalSession`` and without gradient tracking.
    The prediction is the ``argmax`` over the last output dimension,
    squeezed. Prediction and target are both passed through ``mapper``
    before being compared and printed. Iteration stops as soon as ``n``
    misclassified points have been printed. If the whole loader was
    evaluated without a single misclassification, the last data point is
    printed instead. An empty loader prints nothing.

    NOTE(review): ``output != truth`` must evaluate to a single truth value.
    That presumably means a batch size of 1, or a ``mapper`` that reduces a
    batch to a scalar/label. TODO confirm against callers.

    :param model: Classifier to evaluate.
    :param evaluation_loader: Loader yielding ``(input, target)`` pairs.
    :param mapper: Applied to the prediction and to the target before they
        are compared and printed; defaults to ``warg.kws_sink``.
    :param n: Maximum number of misclassified data points to print.
    :param device: Device the inputs are moved to. ``None`` (the default)
        resolves to ``global_torch_device()`` at call time rather than once
        at import time, so later changes to the global device are honoured.
    :return: None; results are printed to stdout.
    """
    if device is None:
        device = global_torch_device()

    # len() of a DataLoader counts batches; this equals the sample count only for batch_size=1.
    num_batches = len(evaluation_loader)
    num_misclassified = 0
    last_seen = None  # (index, truth, output) of the most recently evaluated data point

    # Only the argmax of the output is used, so no autograd graph is needed.
    with TorchEvalSession(model), torch.no_grad():
        # The bar spans the whole loader; ``n`` only bounds how many
        # misclassifications are reported, not how many batches are visited.
        for i, (waveform, target) in tqdm(
            enumerate(evaluation_loader), total=num_batches
        ):
            output = mapper(model(waveform.to(device)).argmax(dim=-1).squeeze())
            truth = mapper(target)
            last_seen = (i, truth, output)
            if output != truth:
                # tqdm.write keeps the message from being interleaved with the bar.
                tqdm.write(
                    f"Data point #{i}/{num_batches}. Expected: {truth}. Predicted: {output}."
                )
                num_misclassified += 1
                if num_misclassified >= n:
                    break

    if last_seen is None:  # empty loader: nothing was evaluated, nothing to report
        return

    # Report a perfect score only when nothing at all was misclassified;
    # finding fewer than ``n`` misclassifications is not a perfect score.
    if num_misclassified == 0:
        i, truth, output = last_seen
        print("All examples in this dataset were correctly classified!")
        print("In this case, let's just look at the last data point")
        print(f"Data point #{i}. Expected: {truth}. Predicted: {output}.")