# Source code for draugr.torch_utilities.optimisation.stopping.overfitting
import contextlib
from warg import sink, Number
# Names exported by ``from ... import *``.
__all__ = [
    "ImprovementDetector",
    "OverfitDetector",
]
class ImprovementDetector(contextlib.AbstractContextManager):
    """Report whether a tracked metric is still improving.

    Each call records a new metric value. The detector returns ``False`` once
    ``patience`` consecutive values have failed to improve on the best value
    seen so far; otherwise it returns ``True`` for the very first value and the
    callback's return value afterwards.
    """

    def __init__(
        self,
        patience: int,
        writer: callable = print,
        minimization: bool = True,
        callback: callable = None,
        verbose: bool = False,
    ):
        """
        NOTE: only a strictly smaller (minimization) or strictly larger
        (maximization) value counts as an improvement; an equal value counts
        against the patience.

        :param patience: number of consecutive non-improving values after which
            improvement is considered to have stopped
        :type patience: int
        :param writer: callable that receives status messages as a single string
        :type writer: callable
        :param minimization: ``True`` if lower values are better, ``False`` if
            higher values are better
        :type minimization: bool
        :param callback: zero-argument callable invoked while patience is not
            exhausted; its return value is returned by ``__call__``. Defaults to
            a callable returning ``True``.
        :type callback: callable
        :param verbose: if ``True``, write a message for every non-improving value
        :type verbose: bool
        """
        self._patience = patience
        self._writer = writer
        self._count = 0  # consecutive non-improving values seen
        self._minimization = minimization  # as opposed to maximization
        self._best_value = None
        self._callback = (lambda: True) if callback is None else callback
        self._best_idx = None
        # Must be set here: __call__ reads it on every non-improving value.
        self._verbose = verbose

    def __call__(self, value: Number) -> bool:
        """Record ``value`` and report whether the metric is still improving.

        :param value: latest metric value
        :return: ``True`` for the first value after construction/reset;
            ``False`` once ``patience`` consecutive non-improving values have
            been recorded; otherwise whatever the callback returns.
        """
        if self._best_value is None:
            self._best_value = value
            return True

        # Equality is deliberately treated as "no improvement".
        if self._minimization:
            stalled = value >= self._best_value
        else:
            stalled = value <= self._best_value

        if stalled:
            self._count += 1
            if self._verbose:
                relation = ">" if self._minimization else "<"
                self._writer(
                    f"No improvement since last update: {value}{relation}{self._best_value}"
                )
        else:
            self._count = 0
            self._best_value = value

        if self._count >= self._patience:
            self._writer("No improvement detected, patience reached")
            return False
        return self._callback()

    def reset(self) -> None:
        """Forget the best value and the non-improvement count."""
        self._count = 0
        self._best_value = None
        self._best_idx = None
        self._writer("Improvement detector reset")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Return a falsy value so exceptions raised inside the ``with`` body
        # propagate instead of being silently suppressed.
        return False
class OverfitDetector(contextlib.AbstractContextManager):
    """Report when a tracked (e.g. validation) metric keeps getting worse.

    Each call records a new metric value. The detector returns ``True`` once
    ``patience`` consecutive values have been worse than the best value seen so
    far; otherwise it returns ``False``, invoking the callback on every call
    except the first one after construction/reset.
    """

    def __init__(
        self,
        patience: int,
        writer: callable = print,
        minimization: bool = True,
        callback: callable = None,
        verbose: bool = False,
    ):
        """
        NOTE: only a strictly larger (minimization) or strictly smaller
        (maximization) value counts as worse; an equal value is not worse and
        resets the counter.

        :param patience: number of consecutive worse values after which
            overfitting is reported
        :type patience: int
        :param writer: callable that receives status messages as a single string
        :type writer: callable
        :param minimization: ``True`` if lower values are better, ``False`` if
            higher values are better
        :type minimization: bool
        :param callback: callable invoked with no arguments while patience is not
            exhausted; its return value is ignored. Defaults to ``warg.sink``.
        :type callback: callable
        :param verbose: if ``True``, write a message for every worse value
        :type verbose: bool
        """
        self._patience = patience
        self._writer = writer
        self._count = 0  # consecutive worse values seen
        self._minimization = minimization  # as opposed to maximization
        self._best_value = None
        self._callback = sink if callback is None else callback
        self._verbose = verbose

    def __call__(self, value: Number) -> bool:
        """Record ``value`` and report whether overfitting has been detected.

        :param value: latest metric value
        :return: ``True`` once ``patience`` consecutive worse values have been
            recorded, ``False`` otherwise (always ``False`` for the first value
            after construction/reset).
        """
        if self._best_value is None:
            self._best_value = value
            return False

        if self._minimization:
            worse = value > self._best_value
        else:
            worse = value < self._best_value
        # Direction of "worse" for log messages.
        relation = ">" if self._minimization else "<"

        if worse:
            self._count += 1
            if self._verbose:
                self._writer(
                    f"Worse than last update: {value}{relation}{self._best_value}"
                )
        else:
            self._count = 0
            self._best_value = value

        if self._count >= self._patience:
            self._writer(
                f"Overfit detected, patience reached {value}{relation}{self._best_value} with patience {self._patience}"
            )
            return True

        self._callback()
        return False

    def reset(self) -> None:
        """Forget the best value and the worse-value count."""
        self._count = 0
        self._best_value = None
        self._writer("Overfit detector reset")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Return a falsy value so exceptions raised inside the ``with`` body
        # propagate instead of being silently suppressed.
        return False
if __name__ == "__main__":

    def _run_demo(detector, phase_label):
        """Feed ``detector`` a rising run, a falling run, then a train/regress run.

        Every value and result is reported through ``print``/the detector's writer.
        """
        with detector as check:
            for step in range(10):
                check(step)
            check.reset()

            for step in range(10):
                check(-step)
            check.reset()

            print("start training")
            for step in range(2):
                check(step)
            check(0)

            print(phase_label)
            for step in range(10):
                print(check(step), step)

    _run_demo(OverfitDetector(patience=3, writer=print), "start overfitting")
    print("\n\n")
    _run_demo(ImprovementDetector(patience=3, writer=print), "start not improving")