# bad_grad.py
  1. # To add a new cell, type '# %%'
  2. # To add a new markdown cell, type '# %% [markdown]'
  3. # %% [markdown]
  4. # # Bad gradients in PyTorch graph
  5. # %%
  6. import torch
  7. from draugr.torch_utilities import register_bad_grad_hooks
  8. x = torch.randn(10, 10, requires_grad=True)
  9. y = torch.randn(10, 10, requires_grad=True)
  10. z = x / (y * 0)
  11. z = z.sum() * 2
  12. get_dot = register_bad_grad_hooks(z)
  13. z.backward()
  14. dot = get_dot()
  15. # dot.save('tmp.dot') # to get .dot
  16. # dot.render('tmp') # to get SVG
  17. dot # in Jupyter, you can just render the variable