Source code for draugr.torch_utilities.optimisation.debugging.gradients.guided
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 14-02-2021
"""
import numpy
import torch
from torch.autograd import Function
__all__ = ["GuidedBackPropReLUModel", "GuidedBackPropReLU"]
[docs]class GuidedBackPropReLU(Function):
[docs] @staticmethod
def forward(self, input_img):
"""
:param self:
:param input_img:
:return:
"""
positive_mask = (input_img > 0).type_as(input_img)
output = torch.addcmul(
torch.zeros(input_img.size()).type_as(input_img), input_img, positive_mask
)
self.save_for_backward(input_img, output)
return output
[docs] @staticmethod
def backward(self, grad_output):
"""
:param self:
:param grad_output:
:return:
"""
input_img, output = self.saved_tensors
positive_mask_1 = (input_img > 0).type_as(grad_output)
positive_mask_2 = (grad_output > 0).type_as(grad_output)
grad_input = torch.addcmul(
torch.zeros(input_img.size()).type_as(input_img),
torch.addcmul(
torch.zeros(input_img.size()).type_as(input_img),
grad_output,
positive_mask_1,
),
positive_mask_2,
)
return grad_input
[docs]class GuidedBackPropReLUModel:
"""description"""
[docs] def __init__(self, model, use_cuda):
self._model = model
self._model.eval()
self._use_cuda = use_cuda
if self._use_cuda:
self._model = self._model.cuda()
def recursive_relu_apply(module_top):
"""
:param module_top:
"""
for idx, module in module_top._modules.items():
recursive_relu_apply(module)
if module.__class__.__name__ == "ReLU":
module_top._modules[idx] = GuidedBackPropReLU.apply
recursive_relu_apply(self._model) # replace ReLU with GuidedBackpropReLU
[docs] def forward(self, input_img):
"""
:param input_img:
:return:
"""
return self._model(input_img)
def __call__(self, input_img, target_category=None):
if self._use_cuda:
input_img = input_img.cuda()
input_img = input_img.requires_grad_(True)
output = self.forward(input_img)
if target_category is None:
target_category = numpy.argmax(output.cpu().data.numpy())
one_hot = numpy.zeros((1, output.size()[-1]), dtype=numpy.float32)
one_hot[0][target_category] = 1
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
if self._use_cuda:
one_hot = one_hot.cuda()
one_hot = torch.sum(one_hot * output)
one_hot.backward(retain_graph=True)
output = input_img.grad.cpu().data.numpy()
output = output[0, :, :, :]
return output