Source code for draugr.torch_utilities.optimisation.debugging.layer_fetching
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 07/07/2020
"""
import functools
from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple
import torch
from torch import nn
__all__ = ["IntermediateLayerGetter"]
class IntermediateLayerGetter:
    """Wraps a PyTorch module and captures the outputs of selected submodules via forward hooks."""
    def __init__(
        self,
        model: torch.nn.Module,
        return_layers: Optional[Dict[str, str]] = None,
    ):
"""
Wraps a Pytorch module to get intermediate values, eg for getting intermediate activations
Arguments:
model {nn.module} -- The Pytorch module to call
return_layers {dict} -- Dictionary with the selected submodules
to return the output (format: {[current_module_name]: [desired_output_name]},
current_module_name can be a nested submodule, e.g. submodule1.submodule2.submodule3)
Returns:
(mid_outputs {OrderedDict}, model_output {any}) -- mid_outputs keys are
your desired_output_name (s) and their values are the returned tensors
of those submodules (OrderedDict([(desired_output_name,tensor(...)), ...).
In case a submodule is called more than one time, all it's outputs are
stored in a list."""
self._model = model
        if return_layers:
            self.return_layers = return_layers.items()
        else:
            # Default: capture every named submodule (the unnamed root module is skipped in __call__).
            self.return_layers = {k: k for k, v in model.named_modules()}.items()
    @staticmethod
    def reduce_getattr(obj, attr, *args):
        """
        Recursively resolve a dotted attribute path (e.g. "submodule1.submodule2.submodule3").

        Uses wonder's simplification: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427

        :param obj: object to start the lookup from
        :param attr: dotted attribute path
        :param args: optional default value forwarded to getattr
        :return: the resolved attribute"""
def _getattr(obj, attr):
return getattr(obj, attr, *args)
return functools.reduce(_getattr, (obj, *attr.split(".")))
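
    # Illustrative note (not part of the original source): for the Model in the __main__ example
    # below, IntermediateLayerGetter.reduce_getattr(model, "nested.0.1") resolves to the same
    # nn.Linear(2, 3) instance as model.nested[0][1], since nn.Module exposes registered
    # submodules (including the numerically named children of nn.Sequential) via attribute access.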
    def __call__(self, *args, **kwargs) -> Tuple[OrderedDict, Any]:
ret = OrderedDict()
handles = []
for name, new_name in self.return_layers:
            if name == "":
                continue  # Skip the unnamed root module. TODO: Fail maybe?
layer = IntermediateLayerGetter.reduce_getattr(self._model, name)
if isinstance(layer, torch.nn.Module): # Should be a torch module!
                def hook(
                    module,
                    i,
                    output,
                    *,
                    new_name_=new_name,  # Bind the current new_name via a default argument so each hook keeps its own key (avoids the late-binding closure problem).
                ):
"""
:param new_name_:
:param module:
:type module:
:param i:
:type i:
:param output:
:type output:"""
                    if new_name_ in ret:
                        # The submodule has been called before; collect all of its outputs in a list.
                        cur_val = ret[new_name_]
                        if isinstance(cur_val, list):
                            ret[new_name_].append(output)
                        else:
                            ret[new_name_] = [cur_val, output]
                    else:
                        ret[new_name_] = output
                try:
                    h = layer.register_forward_hook(hook)
                    handles.append(h)
                except AttributeError as e:
                    raise AttributeError(f"Could not register a forward hook on module {name}") from e
            else:
                raise AttributeError(
                    f"Requested activation for {name}, but the resolved attribute is not a module (got {type(layer)})"
                )
        output = self._model(*args, **kwargs)  # Run the wrapped model; the hooks populate ret as submodules are called.
        for h in handles:
            h.remove()  # Remove the hooks so repeated calls do not accumulate handlers.
        return ret, output
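

# Minimal usage sketch (added for illustration, not part of the original module): wraps a small
# Sequential model, whose layer names "fc1"/"fc2" are assumptions, and captures the output of "fc2".
# It mirrors the __main__ examples below.
def _example_usage() -> None:
    """Illustrative only."""
    model = nn.Sequential(OrderedDict([("fc1", nn.Linear(2, 2)), ("fc2", nn.Linear(2, 1))]))
    getter = IntermediateLayerGetter(model, return_layers={"fc2": "fc2_out"})
    mid_outputs, final_output = getter(torch.randn(1, 2))
    print(mid_outputs["fc2_out"], final_output)  # Intermediate fc2 activation and the model output.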
if __name__ == "__main__":
    def adsad() -> None:
        """Demonstrates IntermediateLayerGetter with an explicit return_layers mapping.

        :rtype: None
        """
        class Model(nn.Module):
            """Toy model with nested submodules and an identity layer for capturing a non-module operation."""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(2, 2)
self.fc2 = nn.Linear(2, 2)
self.nested = nn.Sequential(
nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 3)), nn.Linear(3, 1)
)
self.interaction_idty = (
nn.Identity()
) # Simple trick for operations not performed as modules
            def forward(self, x):
                """
                :param x: input tensor of shape (batch, 2)
                :return: output tensor of shape (batch, 1)"""
                x1 = self.fc1(x)
                x2 = self.fc2(x)
                interaction = x1 * x2
                self.interaction_idty(interaction)  # Route the elementwise product through the identity so a hook can capture it.
                return self.nested(interaction)
model = Model()
return_layers = {
"fc2": "fc2",
"nested.0.1": "nested",
"interaction_idty": "interaction",
}
mid_getter = IntermediateLayerGetter(model, return_layers=return_layers)
mid_outputs, model_output = mid_getter(torch.randn(1, 2))
print(model_output)
print(mid_outputs)
    def adsad2() -> None:
        """Demonstrates IntermediateLayerGetter with the default behaviour of capturing all named submodules.

        :rtype: None
        """
        class Model(nn.Module):
            """Toy model with nested submodules and an identity layer for capturing a non-module operation."""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(2, 2)
self.fc2 = nn.Linear(2, 2)
self.nested = nn.Sequential(
nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 3)), nn.Linear(3, 1)
)
self.interaction_idty = (
nn.Identity()
) # Simple trick for operations not performed as modules
            def forward(self, x):
                """
                :param x: input tensor of shape (batch, 2)
                :return: output tensor of shape (batch, 1)"""
x1 = self.fc1(x)
x2 = self.fc2(x)
interaction = x1 * x2
self.interaction_idty(interaction)
return self.nested(interaction)
model = Model()
mid_getter = IntermediateLayerGetter(model)
mid_outputs, model_output = mid_getter(torch.randn(1, 2))
print(model_output)
print(mid_outputs)
adsad()
# adsad2()