Source code for draugr.torch_utilities.sessions.jit_sessions
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 10/06/2020
"""
import torch
from torch import jit
from warg import AlsoDecorator
__all__ = ["TorchJitSession", "TorchIgnoreJitSession"]
class TorchIgnoreJitSession(AlsoDecorator):
    """
    Session that sets ``torch.jit._enabled`` to ``False`` while it is active.

    On exit the flag is either restored to the value it had when the session
    was entered (``no_side_effect=True``) or unconditionally set to ``True``
    (``no_side_effect=False``).

    NOTE(review): the class derives from ``warg.AlsoDecorator``; presumably
    that also makes instances usable as function decorators - confirm against
    warg.
    """

    def __init__(self, no_side_effect: bool = True):
        """
        :param no_side_effect: If True, restore on exit the value that
            ``jit._enabled`` had on entry; if False, set it to True on exit.
        :type no_side_effect: bool
        """
        self._no_side_effect = no_side_effect
        if no_side_effect:
            # Kept so the attribute exists right after construction, as before;
            # the value actually restored is re-captured in __enter__.
            self.prev_state = jit._enabled

    def __enter__(self) -> bool:
        """
        Remember the current flag (when restoring is requested) and disable it.

        :return: Always True.
        :rtype: bool
        """
        if self._no_side_effect:
            # Capture at entry, not at construction: the flag may have changed
            # in between, and the same instance may be entered more than once.
            self.prev_state = jit._enabled
        jit._enabled = False
        return True

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """
        Restore the flag captured on entry, or set it to True.

        Returns None, so an exception raised inside the session propagates.
        """
        if self._no_side_effect:
            jit._enabled = self.prev_state
        else:
            jit._enabled = True
class TorchJitSession(AlsoDecorator):
    """
    Session that sets ``torch.jit._enabled`` to a chosen value while active.

    On entry the flag is set to ``enabled``. On exit it is either restored to
    the value it had when the session was entered (``no_side_effect=True``) or
    unconditionally set to ``True`` (``no_side_effect=False``).

    NOTE(review): the class derives from ``warg.AlsoDecorator``; presumably
    that also makes instances usable as function decorators - confirm against
    warg.
    """

    def __init__(self, enabled=False, no_side_effect: bool = True):
        """
        :param enabled: Value assigned to ``jit._enabled`` for the duration of
            the session.
        :param no_side_effect: If True, restore on exit the value that
            ``jit._enabled`` had on entry; if False, set it to True on exit.
        :type no_side_effect: bool
        """
        self._no_side_effect = no_side_effect
        self._effect = enabled
        if no_side_effect:
            # Kept so the attribute exists right after construction, as before;
            # the value actually restored is re-captured in __enter__.
            self.prev_state = jit._enabled

    def __enter__(self) -> bool:
        """
        Remember the current flag (when restoring is requested) and apply
        the requested value.

        :return: Always True.
        :rtype: bool
        """
        if self._no_side_effect:
            # Capture at entry, not at construction: the flag may have changed
            # in between, and the same instance may be entered more than once.
            self.prev_state = jit._enabled
        jit._enabled = self._effect
        return True

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """
        Restore the flag captured on entry, or set it to True.

        Returns None, so an exception raised inside the session propagates.
        """
        if self._no_side_effect:
            jit._enabled = self.prev_state
        else:
            jit._enabled = True
if __name__ == "__main__":

    def _trace_demo() -> None:
        """
        Trace a plain function that calls a scripted one and print the result.

        :rtype: None
        """

        @torch.jit.script
        def scripted_fn(x: torch.Tensor):
            """Add the tensor to itself twelve times in a row."""
            for _step in range(12):
                x = x + x
            return x

        def fn(x):
            """Negate the input, then hand it to the scripted function."""
            negated = torch.neg(x)
            return scripted_fn(negated)

        example_inputs = (torch.rand(4, 5),)
        traced = torch.jit.trace(fn, example_inputs)
        traced(torch.rand(3, 4))

        print(type(traced))  # torch.jit.ScriptFunction
        if isinstance(traced, torch.jit.ScriptFunction):
            # Show the compiled graph rendered as Python code.
            print(traced.code)

    # First run inside the session, then once more outside of it.
    with TorchIgnoreJitSession():
        _trace_demo()
    _trace_demo()