To NIR: Tracing PyTorch
When creating NIR nodes from PyTorch, we go through the PyTorch modules to create a graph structure that we can populate with NIR nodes. This “going through” is what we refer to as “tracing”, because we have to track the path of signals through potentially complex module structures so that we know where to put input signals and where to read output signals.
The tracing happens in two steps:

- First, we find all the PyTorch modules that should be mapped to NIR nodes.
- Second, we trace through the PyTorch modules to find the edges for the NIR graph.
We use the symbolic tracing from `torch.fx` to go through the graphs, because it’s fast and it allows us to reconstruct NIR graphs without executing any code.
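For readers unfamiliar with `torch.fx`, the following minimal sketch (independent of nirtorch's own code) shows what symbolic tracing records for a small two-layer module. The class `TwoLayers` is a made-up example.

```python
import torch
import torch.fx


class TwoLayers(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = torch.nn.Linear(2, 3)
        self.lin2 = torch.nn.Linear(3, 1)

    def forward(self, x):
        return self.lin2(self.lin1(x))


# symbolic_trace runs forward() on Proxy placeholders instead of real input
# tensors and records every operation as a node in a torch.fx.Graph
graph_module = torch.fx.symbolic_trace(TwoLayers())

for node in graph_module.graph.nodes:
    # node.op is one of "placeholder", "call_module", "call_function",
    # "call_method", "get_attr" or "output"
    print(node.op, node.name, node.target)
```

Running it prints one line per node: the input placeholder `x`, the two `call_module` nodes `lin1` and `lin2`, and the `output` node. In a graph like this, the `call_module` nodes are natural candidates for NIR nodes, and the references between nodes are natural candidates for edges.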
Mapping nodes
For the first step, we need to know which modules can be mapped (the keys) and how they should be mapped (a function that maps `torch.nn.Module`s to `nir.NIRNode`s).
To that end, `nirtorch` expects a dictionary of type `Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]`. That is, a dictionary whose keys are `torch.nn.Module` classes (such as `torch.nn.Linear`) and whose values are functions that map `torch.nn.Module -> nir.NIRNode`.
One may wonder why we don’t just use a single function to convert `torch.nn.Module`s to `nir.NIRNode`s instead of a dictionary. The reason is that the keys of the dictionary tell us which modules to map.
We call the keys in the `module_map` leaf nodes because they are not processed further: if they are included in the `module_map` dictionary, the corresponding mapping function (`Callable[[torch.nn.Module], nir.NIRNode]`) needs to deal with any potential submodules.
Conversely, if a module is not in the dictionary (such as a `torch.nn.ModuleList`), we have to traverse the modules inside that module.
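The leaf rule can be sketched as a short recursive function. This is a simplified illustration under our own assumptions, not nirtorch's actual implementation: `find_leaf_modules` is a hypothetical name, and it matches keys by exact type, so subclasses of a key are not treated as leaves.

```python
import torch


def find_leaf_modules(module, module_map, name=""):
    """Hypothetical helper: collect (name, module) pairs whose type is a key in module_map."""
    if type(module) in module_map:
        # Leaf: stop here, the mapping function must handle any submodules
        return [(name, module)]
    leaves = []
    # Not mappable directly (e.g. Sequential, ModuleList): descend into children
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name else child_name
        leaves.extend(find_leaf_modules(child, module_map, full_name))
    return leaves


net = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.ModuleList([torch.nn.Linear(4, 2), torch.nn.Linear(4, 2)]),
)
module_map = {torch.nn.Linear: lambda m: None}  # placeholder mapping function
print([name for name, _ in find_leaf_modules(net, module_map)])
# ['0', '1.0', '1.1']
```

The `ModuleList` has no entry in the map, so the sketch looks inside it and finds the two `Linear` layers; the `Linear` layers themselves are never descended into.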
Tracing edges
Since edges in NIR do not have any logic, finding the edges in the graph is purely a matter of creating the input-output relations.
We trace edges by going through the calls in the `torch.nn.Module`s, finding the source node (input signal) of each call, and connecting it to the node that represents the module call.
Provided that all modules are well defined, this step is relatively straightforward, although there are some complications regarding pure function calls (as opposed to module calls), which we cover below in Mapping function calls.
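Similarly, here is a minimal sketch of edge tracing on top of `torch.fx`, assuming that only module calls and the final output matter. The helper `trace_edges` and the module `Chain` are hypothetical; nirtorch's real tracer also has to deal with function calls and naming details omitted here.

```python
import torch
import torch.fx


def trace_edges(module):
    """Hypothetical helper: return (source, target) pairs between torch.fx nodes.

    Only module calls and the final output are considered; pure function
    calls such as `+` are not handled in this sketch.
    """
    graph = torch.fx.symbolic_trace(module).graph
    edges = []
    for node in graph.nodes:
        if node.op in ("call_module", "output"):
            # all_input_nodes lists the fx nodes whose values this call consumes
            for source in node.all_input_nodes:
                edges.append((source.name, node.name))
    return edges


class Chain(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(2, 2)
        self.b = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.b(self.a(x))


print(trace_edges(Chain()))
# [('x', 'a'), ('a', 'b'), ('b', 'output')]
```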
Tracing in practice
In practice, the behavior above is implemented in `nirtorch.torch_to_nir` with the following signature:
```python
def torch_to_nir(
    module: torch.nn.Module,
    module_map: Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]],
    default_dict: Optional[
        Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]
    ] = None,
) -> nir.NIRGraph: ...
```
Here is a short, self-contained example of how to map a `torch.nn.AvgPool2d` to a `nir.AvgPool2d`:
```python
import nir, nirtorch, numpy as np, torch

# First, we describe the PyTorch module we want to convert
torch_module = torch.nn.AvgPool2d(kernel_size=(2, 2), stride=2, padding=1)

# Second, we define the dictionary
torch_to_nir_map = {
    torch.nn.AvgPool2d: lambda module: nir.AvgPool2d(
        kernel_size=np.array(module.kernel_size),
        stride=np.array(module.stride),
        padding=np.array(module.padding),
    )
}

# Finally, we call nirtorch with the module and the dictionary
converted_module = nirtorch.torch_to_nir(torch_module, torch_to_nir_map)
print(converted_module)
```

```
AvgPool2d(kernel_size=array([2, 2]), stride=array(2), padding=array(1))
```
Note that we convert the module parameters to NumPy arrays. In principle you can use the raw numbers, but we recommend NumPy arrays for consistency.
Note also that a mapping function is free to return any NIR node, not only the one whose name matches the PyTorch module.
Mapping function calls
The above method works well for modules, but what about function calls inside modules, such as addition (`+`)?
```python
import torch


def my_add(x: torch.Tensor, y: torch.Tensor):
    return x + y
```
This is where the difference between NIR and PyTorch becomes apparent: unlike PyTorch, NIR is not a procedural graph in which we execute one operation after the other and eventually return the output. Think of NIR as a physical switchboard where we plug wires into different sockets to form connections between nodes. Only functions that can be expressed as such wiring can be mapped.
Addition works well because we can rewire this PyTorch graph

```mermaid
flowchart LR
    plus[my_add]
    x --> plus
    y --> plus
    plus --> next_module
```

into

```mermaid
flowchart LR
    x --> next_module
    y --> next_module
```

by removing the `+` node and wiring both `x` and `y` directly to `next_module`.
This works because in NIR, signals that arrive at the same node are summed, so “addition” is simply two wires ending in the same place.
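The rewiring step itself is simple graph surgery. The sketch below works on a plain list of `(source, target)` edges; `remove_add_nodes` is a hypothetical helper, and it assumes every addition node has already been identified by name.

```python
def remove_add_nodes(edges, add_nodes):
    """Hypothetical sketch: drop each addition node and wire its inputs
    directly to its consumers, relying on NIR summing signals on arrival."""
    edges = list(edges)
    for add in add_nodes:
        sources = [src for src, dst in edges if dst == add]
        targets = [dst for src, dst in edges if src == add]
        # Keep every edge that does not touch the addition node...
        edges = [(src, dst) for src, dst in edges if add not in (src, dst)]
        # ...and connect each input of the addition to each consumer of its result
        edges += [(src, dst) for src in sources for dst in targets]
    return edges


edges = [("x", "my_add"), ("y", "my_add"), ("my_add", "next_module")]
print(remove_add_nodes(edges, ["my_add"]))
# [('x', 'next_module'), ('y', 'next_module')]
```

Because each addition is removed independently, a chain of additions (such as `x + y + z`) also collapses, leaving all three inputs wired to the final consumer.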
Here’s an example of a module that returns the sum of two calls to a linear layer:
```python
import nir, nirtorch, numpy as np, torch


# First, we describe the PyTorch module we want to convert, this time with addition
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.lin(x) + self.lin(x)


# torch.nn.Linear is randomly initialized, so the printed weights will vary
torch_module = MyModule()

# Second, we define the dictionary
torch_to_nir_map = {
    torch.nn.Linear: lambda module: nir.Affine(
        weight=module.weight.detach().numpy(),
        bias=module.bias.detach().numpy(),
    )
}

# Finally, we call nirtorch with the module and the dictionary
converted_module = nirtorch.torch_to_nir(torch_module, torch_to_nir_map)
print(converted_module)
```

```
NIRGraph(nodes={'x': Input(input_type={'input': array([1])}), 'lin': Affine(weight=array([[-0.29883528]], dtype=float32), bias=array([-0.36035502], dtype=float32), input_type={'input': array([1])}, output_type={'output': array([1])}, metadata={}), 'lin_1': Affine(weight=array([[-0.29883528]], dtype=float32), bias=array([-0.36035502], dtype=float32), input_type={'input': array([1])}, output_type={'output': array([1])}, metadata={}), 'output': Output(output_type={'output': array([1])})}, edges=[('x', 'lin'), ('x', 'lin_1'), ('lin', 'output'), ('lin_1', 'output')], input_type={'x': {'input': array([1])}}, output_type={'output': {'output': array([1])}}, metadata={})
```
This corresponds to the following graph:
```mermaid
flowchart LR
    x --> lin
    x --> lin_1
    lin --> output
    lin_1 --> output
```
Notice that there is no node for `+`. Instead, there are two nodes (`lin` and `lin_1`), and both are connected to the output via the edges `(lin, output)` and `(lin_1, output)`.
This is “addition” in NIR, because the signals will sum upon arrival.
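To make “summing upon arrival” concrete, here is a toy evaluator for the graph above. It is a hypothetical sketch written with NumPy, not part of `nir` or `nirtorch`: it copies the printed weights, treats `x` and `output` as identity nodes, and gives each node the sum of all signals on its incoming edges.

```python
import numpy as np

# Parameters copied from the printed NIRGraph above
weight = np.array([[-0.29883528]], dtype=np.float32)
bias = np.array([-0.36035502], dtype=np.float32)


def affine(x):
    # Affine node: weight @ x + bias
    return weight @ x + bias


def identity(x):
    return x


def evaluate(nodes, edges, inputs):
    """Hypothetical toy evaluator: each node receives the SUM of all signals
    on its incoming edges. Assumes `nodes` is listed in topological order."""
    values = dict(inputs)
    for name, fn in nodes.items():
        if name in values:  # input nodes are already set
            continue
        incoming = [values[src] for src, dst in edges if dst == name]
        values[name] = fn(sum(incoming))
    return values


nodes = {"x": identity, "lin": affine, "lin_1": affine, "output": identity}
edges = [("x", "lin"), ("x", "lin_1"), ("lin", "output"), ("lin_1", "output")]

x = np.array([2.0])
result = evaluate(nodes, edges, {"x": x})
print(np.allclose(result["output"], affine(x) + affine(x)))  # True
```

The result matches `self.lin(x) + self.lin(x)` in `MyModule`, even though the graph contains no addition node.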
Other function calls are currently not supported. Get in touch or open an issue if you think this should change! We are more than happy to hear your input and adapt to your needs.