Source code for draugr.torch_utilities.optimisation.debugging.gradients.flow

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 29/07/2020
           """

__all__ = ["plot_grad_flow"]

import numpy
import torch
from matplotlib import pyplot
from matplotlib.lines import Line2D

from draugr.torch_utilities.optimisation.parameters import normal_init_weights


def plot_grad_flow(
    model: torch.nn.Module,
    lines: bool = True,
    alpha: float = 0.5,
    line_width: float = 1.0,
) -> None:
    """
    Plots the gradients flowing through the different layers of the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.

    Usage: after loss.backward(), call plot_grad_flow(model) to visualise the gradient flow of model.

    :param model: the module whose parameter gradients are plotted
    :type model: torch.nn.Module
    :param lines: if True draw line plots, otherwise bar plots
    :type lines: bool
    :param alpha: opacity of the plotted lines / bars, in (0, 1]
    :type alpha: float
    :param line_width: width of the plotted lines / bars
    :type line_width: float"""
    assert 0.0 < alpha <= 1.0
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in model.named_parameters():
        # Skip biases and parameters whose .grad has not been populated yet.
        if p.requires_grad and ("bias" not in n) and (p.grad is not None):
            layers.append(n)
            grad_abs = p.grad.abs()
            ave_grads.append(grad_abs.mean().item())  # .item() detaches to a plain float
            max_grads.append(grad_abs.max().item())
    if lines:
        pyplot.plot(max_grads, alpha=alpha, linewidth=line_width, color="r")
        pyplot.plot(ave_grads, alpha=alpha, linewidth=line_width, color="g")
    else:
        pyplot.bar(
            numpy.arange(len(max_grads)),
            max_grads,
            alpha=alpha,
            linewidth=line_width,
            color="r",
        )
        pyplot.bar(
            numpy.arange(len(max_grads)),
            ave_grads,
            alpha=alpha,
            linewidth=line_width,
            color="g",
        )
    pyplot.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
    pyplot.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    pyplot.xlim(left=0, right=len(ave_grads))
    max_g = max(max_grads)
    margin = max_g * 1.1
    pyplot.ylim(
        bottom=max_g - margin, top=margin
    )  # pad the y-range by 10% of the largest gradient at each end
    pyplot.xlabel("Layers")
    pyplot.ylabel("Gradient Magnitude")
    pyplot.title("Gradient Flow")
    pyplot.grid(True)
    pyplot.legend(
        [
            Line2D([0], [0], color="r", lw=4),
            Line2D([0], [0], color="g", lw=4),
            Line2D([0], [0], color="k", lw=4),
        ],
        ["max-gradient", "mean-gradient", "zero-gradient"],
    )
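
# A minimal usage sketch, not part of the original module: `optimiser`,
# `criterion` and `data_loader` are hypothetical stand-ins for whatever the
# surrounding training code provides. The point is the ordering: call
# plot_grad_flow after loss.backward() has populated the .grad fields and
# before the next zero_grad() clears them.
def _example_training_loop(model, optimiser, criterion, data_loader) -> None:
    for batch, target in data_loader:
        optimiser.zero_grad()
        loss = criterion(model(batch), target)
        loss.backward()
        plot_grad_flow(model)  # .grad is populated at this point
        pyplot.show()
        optimiser.step()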
if __name__ == "__main__":

    def a() -> None:
        """
        Build a small MLP, run one forward / backward pass and plot its gradient flow.

        :rtype: None"""
        i = torch.randn(10, 50, requires_grad=True)
        target = torch.empty(10, dtype=torch.long).random_(2)
        model = torch.nn.Sequential(
            torch.nn.Linear(50, 50),
            torch.nn.ReLU(),
            torch.nn.Linear(50, 50),
            torch.nn.ReLU(),
            torch.nn.Linear(50, 50),
            torch.nn.ReLU(),
            torch.nn.Linear(50, 50),
            torch.nn.ReLU(),
            torch.nn.Linear(50, 2),
        )
        normal_init_weights(model, std=1.2)

        criterion = torch.nn.CrossEntropyLoss()
        outputs = model(i)
        loss = criterion(outputs, target)
        loss.backward()  # populates .grad on the model parameters

        plot_grad_flow(model)
        pyplot.show()

    a()