To PyTorch: Interpreting NIR#
We rely on torch.fx
to interpret NIR graphs.
We first translate all the NIR nodes into PyTorch modules by going through the nodes in the NIR graph one by one.
This mechanism relies on a dictionary given by the user that tells nirtorch
(1) which modules can be mapped and (2) how to map them.
That is, we need a dictionary of type Dict[nir.NIRNode, Callable[[nir.NIRNode], torch.nn.Module]]
.
One may wonder why we don't just use a function from nir.NIRNode -> torch.nn.Module.
The answer is that an explicit set of supported node types helps nirtorch
simplify the parsing.
It is entirely possible to provide only a partial mapping; nirtorch
handles this by skipping the unmapped nodes, except in cases where a mapping is required.
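To make the shape of this dictionary concrete, here is a minimal, hedged sketch of a partial mapping for nir.Linear (we assume that nir.Linear stores its weight as an (out_features, in_features) NumPy array; the full example below applies the same pattern to nir.AvgPool2d):
import nir, torch
def _linear_to_torch(node: nir.Linear) -> torch.nn.Module:
    # Assumption: node.weight is a NumPy array of shape (out_features, in_features)
    out_features, in_features = node.weight.shape
    module = torch.nn.Linear(in_features, out_features, bias=False)
    module.weight.data = torch.from_numpy(node.weight)
    return module
# Keys are NIR node classes, values map a node instance to a torch.nn.Module.
# Node types missing from the dictionary fall back to nirtorch's defaults or are skipped.
node_map = {nir.Linear: _linear_to_torch}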
Interpreting in detail#
The interpretation happens in two steps.
First, we map all the nodes individually, using the dictionary defined above.
Second, we trace the graph and plug the translated nodes into a Torch FX graph. This second step gives us an executable torch.fx.GraphModule. Note that the execution is stateful, as described in State management in NIRTorch.
All this is implemented in the function nirtorch.nir_to_torch
which has the following signature:
def nir_to_torch(
nir_node: nir.NIRGraph,
node_map: NodeMapType,
default_map: NodeMapType = DEFAULT_MAP,
) -> torch.fx.GraphModule:
The nir_node
parameter is the NIR node we wish to convert.
It can be a singular node or a nir.NIRGraph
, which can contain multiple nodes (and subgraphs).
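For instance, a single node can be converted directly, without wrapping it in a graph. A minimal, hedged sketch (we pass an empty node_map and rely on the default mappings discussed below):
import nir, nirtorch, numpy as np
# Convert a bare nir.Linear node; the empty dictionary means nirtorch
# falls back entirely on its built-in default mappings
single_node = nir.Linear(weight=np.ones((3, 4), dtype=np.float32))
single_module = nirtorch.nir_to_torch(single_node, {})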
The node_map
parameter is the dictionary above, with type signature Dict[nir.NIRNode, Callable[[nir.NIRNode], torch.nn.Module]]
, that we use to look up supported nodes and convert nodes in the first step mentioned above.
Here is a short, self-contained example of how you can map a nir.AvgPool2d
to torch.nn.AvgPool2d
:
import nir, nirtorch, numpy as np, torch
# First, we describe the NIR graph we need as input
nir_avgpool = nir.AvgPool2d(kernel_size=np.array([2, 2]), stride=np.array([1]), padding=np.array([0, 0]))
nir_linear = nir.Linear(weight=np.ones((5, 5), dtype=np.float32))
nir_graph = nir.NIRGraph.from_list(nir_avgpool, nir_linear) # Constructs a graph with two nodes (AvgPool2d and Linear), plus automatic Input and Output nodes
# Second, we define the mapping
nir_to_torch_map = {
nir.AvgPool2d: lambda node: torch.nn.AvgPool2d(
kernel_size=tuple(torch.from_numpy(node.kernel_size)),
stride=torch.from_numpy(node.stride),
padding=tuple(torch.from_numpy(node.padding))
)
}
# Finally, we call nirtorch with the node and dictionary
converted_module = nirtorch.nir_to_torch(nir_graph, nir_to_torch_map)
converted_module
GraphModule(
(avgpool2d): AvgPool2d(kernel_size=(tensor(2), tensor(2)), stride=tensor([1]), padding=(tensor(0), tensor(0)))
(linear): Linear(in_features=5, out_features=5, bias=False)
)
Note the torch.from_numpy
call, which ensures that the NumPy arrays coming from NIR are correctly converted to PyTorch tensors.
You may also have noticed that we cast some of the parameters to tuples, to match the two-dimensional arguments expected by torch.nn.AvgPool2d.
Note also that we did not specify a mapping for the nir.Linear
module. That's because nirtorch
provides default mappings for the simplest modules (like nir.Linear).
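Because the result is an ordinary torch.fx.GraphModule, the standard FX introspection tools apply. A quick sketch (output not shown):
# Inspect the module produced above using the usual torch.fx tooling
print(converted_module.code)             # Python source of the generated forward method
converted_module.graph.print_tabular()   # tabular listing of the FX graph nodes (requires the tabulate package)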
# Now, we can execute it with a 3-dimensional tensor arranged according to (batch, width, height)
converted_module(torch.ones(1, 10, 10))
(tensor([[[5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5.]]], grad_fn=<UnsafeViewBackward0>),
{'input': None, 'avgpool2d': None, 'linear': None, 'output': None})
Note that the output is a tuple, where the second element is the state dictionary (here every entry is None, because both the average pooling and the linear layer are stateless).
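In practice you will typically unpack the two elements directly:
# Unpack the (output, state) pair returned by the converted module
output, state = converted_module(torch.ones(1, 10, 10))
print(output.shape)  # torch.Size([1, 5, 5])
print(state)         # {'input': None, 'avgpool2d': None, 'linear': None, 'output': None}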
Overwriting default dictionaries#
There is a third parameter, default_map
, which provides defaults for the mapping.
By default, nirtorch
maps simple nodes such as nir.Input
, nir.Linear -> torch.nn.Linear
and nir.Affine -> torch.nn.Linear (with a bias term)
, but you can override this behavior if you want to provide a different mapping, or remove the defaults altogether.
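For example, to replace the built-in handling of nir.Affine, simply add your own entry to node_map; entries in node_map take precedence over default_map. A hedged sketch (we assume nir.Affine stores weight as an (out_features, in_features) array and bias as an (out_features,) array):
import nir, torch
def _affine_to_torch(node: nir.Affine) -> torch.nn.Module:
    # Assumption: node.weight has shape (out_features, in_features), node.bias has shape (out_features,)
    out_features, in_features = node.weight.shape
    module = torch.nn.Linear(in_features, out_features, bias=True)
    module.weight.data = torch.from_numpy(node.weight)
    module.bias.data = torch.from_numpy(node.bias)
    return module
# This entry overrides only the default nir.Affine mapping; all other defaults remain in place
node_map = {nir.Affine: _affine_to_torch}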
Observe what happens when we override the default dictionary (DEFAULT_MAP
) with an empty dictionary:
import nir, nirtorch, numpy as np, torch
# First, we describe the NIR graph we need as input
nir_avgpool = nir.AvgPool2d(kernel_size=np.array([2, 2]), stride=np.array([1]), padding=np.array([0, 0]))
nir_linear = nir.Linear(weight=np.ones((5, 5), dtype=np.float32))
nir_graph = nir.NIRGraph.from_list(nir_avgpool, nir_linear) # Constructs a graph with two nodes (AvgPool2d and Linear), plus automatic Input and Output nodes
# Second, we define the mapping
nir_to_torch_map = {
nir.AvgPool2d: lambda node: torch.nn.AvgPool2d(
kernel_size=tuple(torch.from_numpy(node.kernel_size)),
stride=torch.from_numpy(node.stride),
padding=tuple(torch.from_numpy(node.padding))
)
}
# Finally, we call nirtorch with the node and dictionary
converted_module = nirtorch.nir_to_torch(nir_graph, nir_to_torch_map, {})
converted_module
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[46], line 18
9 nir_to_torch_map = {
10 nir.AvgPool2d: lambda node: torch.nn.AvgPool2d(
11 kernel_size=tuple(torch.from_numpy(node.kernel_size)),
(...)
14 )
15 }
17 # Finally, we call nirtorch with the node and dictionary
---> 18 converted_module = nirtorch.nir_to_torch(nir_graph, nir_to_torch_map, {})
19 converted_module
File ~/.local/lib/python3.12/site-packages/nirtorch/nir_interpreter.py:278, in nir_to_torch(nir_graph, node_map, default_map)
276 map_with_defaults.update(node_map) # Overwrite defaults with node_map
277 # First convert all nodes into a module dictionary
--> 278 owning_module = _construct_module_dict_recursive(nir_graph, map_with_defaults)
279 # Then wire the graph recursively
280 return _construct_fx_graph(owning_module=owning_module, nir_graph=nir_graph)
File ~/.local/lib/python3.12/site-packages/nirtorch/nir_interpreter.py:91, in _construct_module_dict_recursive(nir_graph, node_map)
89 owning_module[name] = _construct_module_dict_recursive(node, node_map)
90 else:
---> 91 mapped_module = _map_nir_node_to_torch(node, node_map=node_map)
92 if mapped_module is not None:
93 owning_module[name] = mapped_module
File ~/.local/lib/python3.12/site-packages/nirtorch/nir_interpreter.py:77, in _map_nir_node_to_torch(node, node_map)
75 return node_map[type(node)](node)
76 else:
---> 77 raise ValueError(
78 f"Unknown node type {type(node)}, mapping does not exist in node_map"
79 )
ValueError: Unknown node type <class 'nir.ir.graph.Input'>, mapping does not exist in node_map
You may wonder why the graph has a nir.Input
node.
It is automatically added when constructing a NIR graph via NIRGraph
(which we do via nir.NIRGraph.from_list
) to ensure that the graph is well-formed and that the inputs and outputs of the graph are known when constructing the Torch module.
Without the default mapping, nirtorch
does not know how to map the input node or the linear node and raises the ValueError shown above.
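If you do want to drop the defaults entirely, one option is to supply these mappings yourself. A hedged sketch; mapping a node to None (so that it is skipped) is an assumption based on the "if mapped_module is not None" check visible in the traceback above, not documented behavior:
# Extend the AvgPool2d mapping from before with explicit entries for the remaining node types
def _linear_to_torch(node: nir.Linear) -> torch.nn.Module:
    module = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=False)
    module.weight.data = torch.from_numpy(node.weight)
    return module
explicit_map = dict(nir_to_torch_map)           # keep the nir.AvgPool2d mapping from above
explicit_map[nir.Linear] = _linear_to_torch
explicit_map[nir.Input] = lambda node: None     # assumption: bookkeeping nodes mapped to None are skipped
explicit_map[nir.Output] = lambda node: None    # assumption: bookkeeping nodes mapped to None are skipped
converted_module = nirtorch.nir_to_torch(nir_graph, explicit_map, {})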
How do I map tensors to specific devices?#
nirtorch
does not care which device your tensors are located on, but you may want to place tensors on a specific device.
This can be done with partial function application: first define your mapping function with an additional device
parameter, partially apply it once you know the device, and then hand the partially applied function to nirtorch
. Here is an example:
import nir, nirtorch, numpy as np, torch
import functools
# First, we describe the NIR graph we need as input
nir_avgpool = nir.AvgPool2d(kernel_size=np.array([2, 2]), stride=np.array([1]), padding=np.array([0, 0]))
nir_linear = nir.Linear(weight=np.ones((5, 5), dtype=np.float32))
nir_graph = nir.NIRGraph.from_list(nir_avgpool, nir_linear) # Constructs a graph with two nodes (AvgPool2d and Linear), plus automatic Input and Output nodes
# Second, we define the mapping
nir_to_torch_map = {
nir.AvgPool2d: lambda node, device: torch.nn.AvgPool2d( # <--- Note the additional device parameter
kernel_size=tuple(torch.from_numpy(node.kernel_size).to(device)),
stride=torch.from_numpy(node.stride).to(device),
padding=tuple(torch.from_numpy(node.padding).to(device))
)
}
# We can now partially apply the function at a point in time where we know the device type
nir_to_torch_map[nir.AvgPool2d] = functools.partial(
nir_to_torch_map[nir.AvgPool2d],
device="cpu"
)
# The dictionary now contains a partially applied function that only requires one input: the NIR node,
# so it is safe to pass onto nirtorch
converted_module = nirtorch.nir_to_torch(nir_graph, nir_to_torch_map)
converted_module
GraphModule(
(avgpool2d): AvgPool2d(kernel_size=(tensor(2), tensor(2)), stride=tensor([1]), padding=(tensor(0), tensor(0)))
(linear): Linear(in_features=5, out_features=5, bias=False)
)
You can see a live example of this pattern in the NIR implementation for the spiking neural network library Norse.