PyTorch via NIRTorch#
PyTorch is a popular deep learning framework that many of the NIR-supported libraries are built on.
We have built the nirtorch package to make it easier to develop PyTorch extensions for the NIR-supported libraries. nirtorch helps you write PyTorch code that (1) exports NIR models from PyTorch and (2) imports NIR models into PyTorch.
Exporting NIR models from PyTorch#
See also
Read more about exporting NIR models from PyTorch in the page about NIR Tracing with NIRTorch.
Exporting a NIR model requires two things: exporting the model's nodes and exporting its edges.
Exporting edges#
Exporting edges is slightly complicated, both because PyTorch modules can have multiple inputs and outputs, and because PyTorch modules are connected via function calls, which only happen at runtime. Therefore, we need to trace the PyTorch module with some sample input to recover the edges.
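To see why, consider a module whose connectivity exists only inside its forward method; nothing in the module's attributes records that lin1 feeds into lin2 (a minimal, hypothetical sketch):

import torch

class TwoLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = torch.nn.Linear(2, 3)
        self.lin2 = torch.nn.Linear(3, 1)

    def forward(self, x):
        # The edge lin1 -> lin2 exists only as this function call,
        # so it can only be observed by executing (tracing) the module
        return self.lin2(self.lin1(x))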
Luckily, the nirtorch package helps you do exactly that. It works behind the scenes, but you can read more about it in To NIR: Tracing PyTorch.
Exporting nodes#
Exporting nodes in PyTorch is typically a 1:1 mapping between a PyTorch module and a NIR node. This is done in nirtorch by simply providing a function for each PyTorch module type that returns the corresponding NIR node. In Python types, this is a Dict[Type[torch.nn.Module], Callable[[torch.nn.Module], nir.NIRNode]].
import nir
import torch

class MyLeakyIntegrator(torch.nn.Module):
    tau: torch.Tensor
    r: torch.Tensor
    v_leak: torch.Tensor

    def __init__(self, tau, r, v_leak):
        super().__init__()  # Required for subclasses of torch.nn.Module
        self.tau = tau
        self.r = r
        self.v_leak = v_leak

    def forward(self, x, state):
        if state is None:  # Initialize the membrane state on the first call
            state = torch.tensor([0.])
        x = self.tau * (self.v_leak - state + self.r * x)
        return x, x  # Return both output and state

my_torch_dictionary = {
    MyLeakyIntegrator: lambda module: nir.LI(tau=module.tau, r=module.r, v_leak=module.v_leak)
}
Why does the forward method have a state parameter?
Read more about the structure of MyLeakyIntegrator in the "Stateful execution" section below or in the page on State management in NIRTorch.
The dictionary my_torch_dictionary explains how to convert a custom MyLeakyIntegrator module into a NIR LI (leaky integrator) node. Note that we only have to add entries for the modules we support and want to export; unsupported modules can simply be left out.
Putting it all together#
We can now put together a short, self-contained example that exports a NIR node using nirtorch. Recall that the edges are traced automatically by the nirtorch package, so the only thing we really have to define is the dictionary from above, my_torch_dictionary. The rest is taken care of by nirtorch's torch_to_nir function:
import nir, nirtorch, torch

class MyLeakyIntegrator(torch.nn.Module):
    tau: torch.Tensor
    r: torch.Tensor
    v_leak: torch.Tensor

    def __init__(self, tau, r, v_leak):
        super().__init__()  # Required for subclasses of torch.nn.Module
        self.tau = tau
        self.r = r
        self.v_leak = v_leak

my_torch_dictionary = {
    MyLeakyIntegrator: lambda module: nir.LI(tau=module.tau, r=module.r, v_leak=module.v_leak)
}

# Create some mock data
tau, r, v_leak = torch.ones(3)
# ... And an example module
my_module = MyLeakyIntegrator(tau, r, v_leak)

# Use nirtorch to map my_module using my_torch_dictionary to convert modules
my_nir_graph = nirtorch.torch_to_nir(my_module, my_torch_dictionary)
my_nir_graph
LI(tau=tensor(1.), r=tensor(1.), v_leak=tensor(1.), input_type={'input': array([], dtype=float64)}, output_type={'output': array([], dtype=float64)}, metadata={})
We now have a NIR graph. You can inspect it by exploring its nodes and edges (see how to work with nodes and edges in Working with NIR) or send it to another platform for continued processing.
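For instance, here is a sketch of how you might inspect and serialize the result. Mapping a single module yields a single NIR node, so we first wrap it in a graph; NIRGraph.from_list, nir.write, and nir.read are part of the nir package, while the exact node names printed below depend on how the graph was constructed:

graph = nir.NIRGraph.from_list(my_nir_graph)  # Adds input and output nodes
for name, node in graph.nodes.items():
    print(name, type(node).__name__)  # Node names and node types
print(graph.edges)                    # List of (source, target) name pairs

nir.write("my_graph.nir", graph)      # Serialize to an HDF5 file ...
reloaded = nir.read("my_graph.nir")   # ... that any NIR-supported library can read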
Importing NIR models into PyTorch#
See also
Read more about importing NIR models into PyTorch on the page about To PyTorch: Interpreting NIR.
Assuming you have a NIR graph in the Python object nir_graph (see Usage), we need to inform nirtorch how to map NIR nodes into your simulator. That is, for each node type, we need a function (nir.NIRNode -> torch.nn.Module), given by a dictionary of type Dict[Type[nir.NIRNode], Callable[[nir.NIRNode], torch.nn.Module]] (read about why in To PyTorch: Interpreting NIR).
With that dictionary, we can call nirtorch's nir_to_torch function with the NIR graph we want to map. Here's a complete example that defines a simple mapper from the nir.LI node to the MyLeakyIntegrator module (also used above):
import nir, nirtorch, numpy as np, torch

class MyLeakyIntegrator(torch.nn.Module):
    tau: torch.Tensor
    r: torch.Tensor
    v_leak: torch.Tensor

    def __init__(self, tau, r, v_leak):
        super().__init__()  # Required for subclasses of torch.nn.Module
        self.tau = tau
        self.r = r
        self.v_leak = v_leak

    def forward(self, x, state):
        if state is None:  # Initialize the membrane state on the first call
            state = torch.tensor([0.])
        x = self.tau * (self.v_leak - state + self.r * x)
        return x, x  # Return both output and state

my_nir_dictionary = {
    nir.LI: lambda node: MyLeakyIntegrator(
        torch.from_numpy(node.tau), torch.from_numpy(node.r), torch.from_numpy(node.v_leak)
    )
}

tau = np.ones(1)
r = np.ones(1)
v_leak = np.ones(1)
my_nir_graph = nir.NIRGraph.from_list(nir.LI(tau, r, v_leak))
my_torch_module = nirtorch.nir_to_torch(my_nir_graph, my_nir_dictionary)

# We can now execute the Torch module
output, state = my_torch_module(torch.rand(10))
output
tensor([1.8656, 1.9804, 1.8317, 1.0708, 1.9365, 1.4059, 1.6528, 1.6479, 1.3619,
1.0023], dtype=torch.float64)
Stateful execution#
See also
Read more about state handling with NIRTorch on the page about State handling with NIRTorch.
Note the stateful execution above, both in MyLeakyIntegrator and in the second element of the tuple returned by the call to my_torch_module!
Many NIR primitives can be seen as recurrent neurons, which require us to maintain state.
That can be done either implicitly or explicitly.
Implicit state handling means storing the state in a variable on the module that is updated automatically, so the user does not have to worry about it. The downside is that the user has no control over the state and may forget to reset it; in the worst case, the module behaves incorrectly without the user noticing.
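For illustration, an implicitly stateful version of the leaky integrator might look like the following (a hypothetical sketch; this is not how nirtorch works):

import torch

class ImplicitLeakyIntegrator(torch.nn.Module):
    def __init__(self, tau, r, v_leak):
        super().__init__()
        self.tau, self.r, self.v_leak = tau, r, v_leak
        self.state = torch.tensor([0.])  # Hidden state lives on the module

    def forward(self, x):
        # The state is updated as a side effect, invisible to the caller
        self.state = self.tau * (self.v_leak - self.state + self.r * x)
        return self.state

    def reset(self):
        self.state = torch.tensor([0.])  # Easy to forget between sequences!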
Explicit state handling requires that state is both sent as input and returned as output. Typically, this means that the module takes two inputs (data and state) and returns a tuple of (data, state). This grants complete control to the user, with the downside that the user has to handle the state.
Since some PyTorch libraries explicitly declare state, nirtorch uses explicit state handling. Specifically, the state is a dictionary in which each entry corresponds to the state of one submodule; the dictionary may be nested if the module itself contains submodules. Read more about the distinction between implicit and explicit state, as well as how this is handled in nirtorch, in the page on State management in NIRTorch.
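As a minimal sketch of what this looks like in practice, assuming the module returned by nir_to_torch accepts the previous state as an optional second argument (mirroring the (output, state) tuple it returns in the example above):

state = None  # nirtorch initializes the state on the first call
for x in torch.rand(5, 10):  # Five time steps with ten values each
    output, state = my_torch_module(x, state)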