To PyTorch: Interpreting NIR

We rely on torch.fx to interpret NIR graphs. We first translate the NIR nodes into PyTorch modules, going through the nodes of the NIR graph one by one. This mechanism relies on a dictionary, supplied by the user, that tells nirtorch (1) which node types can be mapped and (2) how to map them. That is, we need a dictionary of type Dict[nir.NIRNode, Callable[[nir.NIRNode], torch.nn.Module]] whose keys are NIR node classes. One may wonder why we don't simply use a function from nir.NIRNode -> torch.nn.Module. The answer is that knowing the set of supported node types up front helps nirtorch simplify the parsing. It is entirely possible to provide only partial mappings, which nirtorch will handle by skipping those nodes, except in cases where the mapping is required.
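To make the type-keyed lookup concrete, here is a minimal plain-Python sketch. It makes several simplifying assumptions:

- AvgPool2d and Linear are hypothetical stand-ins for NIR node classes.
- The converter returns a string instead of a torch.nn.Module.
- map_node is an illustrative helper modeled on the error message nirtorch raises in the traceback later on this page. It is not nirtorch's own code.

```python
from typing import Any, Callable, Dict, Type


# Hypothetical stand-ins for NIR node classes (the real ones live in the `nir` package).
class AvgPool2d:
    def __init__(self, kernel_size):
        self.kernel_size = kernel_size


class Linear:
    def __init__(self, weight):
        self.weight = weight


# Keys are node classes; values turn a node instance into a "module"
# (a descriptive string here, instead of a torch.nn.Module).
NodeMap = Dict[Type[Any], Callable[[Any], Any]]

node_map: NodeMap = {
    AvgPool2d: lambda node: f"AvgPool2d(kernel_size={node.kernel_size})",
}


def map_node(node: Any, node_map: NodeMap) -> Any:
    """Convert a node by looking up its exact type in node_map."""
    node_type = type(node)
    if node_type not in node_map:
        raise ValueError(f"Unknown node type {node_type}, mapping does not exist in node_map")
    return node_map[node_type](node)


# The keys double as the set of supported types, so support can be checked without converting.
print(AvgPool2d in node_map, Linear in node_map)  # True False
print(map_node(AvgPool2d((2, 2)), node_map))      # AvgPool2d(kernel_size=(2, 2))
```

A bare function could only signal an unsupported node by failing during conversion. The dictionary, in contrast, exposes the supported types before any conversion happens.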

Interpreting in detail

The interpreting happens in two steps.

  1. First, we map all the nodes individually, using the dictionary defined above.

  2. Second, we trace the graph and plug the translated nodes into a torch.fx graph. This gives us a fully executable torch.fx.GraphModule. Note that the execution is stateful, as described in State management in NIRTorch.
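The two steps can be mimicked in plain Python to show how they relate. This is a sketch under stated assumptions, not nirtorch's implementation:

- Scale and Offset are hypothetical node types.
- Modules are plain functions rather than torch.nn.Module instances.
- Step 2 uses the standard library's graphlib instead of torch.fx.
- The graph is assumed to have a single output node.

```python
from graphlib import TopologicalSorter
from typing import Any, Callable, Dict, List, Tuple


# Hypothetical stand-in node types; real NIR nodes come from the `nir` package.
class Scale:
    def __init__(self, factor: float):
        self.factor = factor


class Offset:
    def __init__(self, amount: float):
        self.amount = amount


# Type-keyed map: each entry turns a node into a callable "module".
node_map: Dict[type, Callable[[Any], Callable[[float], float]]] = {
    Scale: lambda node: (lambda x: x * node.factor),
    Offset: lambda node: (lambda x: x + node.amount),
}


def interpret(nodes: Dict[str, Any], edges: List[Tuple[str, str]]) -> Callable[[float], float]:
    # Step 1: map every node individually, without looking at the edges.
    modules = {name: node_map[type(node)](node) for name, node in nodes.items()}

    # Step 2: wire the mapped modules together following the edges.
    predecessors: Dict[str, List[str]] = {name: [] for name in nodes}
    for src, dst in edges:
        predecessors[dst].append(src)
    order = list(TopologicalSorter(predecessors).static_order())
    edge_sources = {src for src, _ in edges}
    (sink,) = [name for name in nodes if name not in edge_sources]  # assume a single output

    def run(x: float) -> float:
        values: Dict[str, float] = {}
        for name in order:
            inputs = [values[p] for p in predecessors[name]]
            # Nodes without predecessors receive the external input; others sum their inputs.
            values[name] = modules[name](sum(inputs) if inputs else x)
        return values[sink]

    return run


forward = interpret({"scale": Scale(2.0), "offset": Offset(1.0)}, [("scale", "offset")])
print(forward(3.0))  # 7.0
```

Keeping the steps separate means a missing mapping is reported in step 1, before any wiring happens. The traceback later on this page shows the same ordering.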

All this is implemented in the function nirtorch.nir_to_torch, which has the following signature:

def nir_to_torch(
    nir_node: nir.NIRGraph,
    node_map: NodeMapType,
    default_map: NodeMapType = DEFAULT_MAP,
) -> torch.fx.GraphModule:

The nir_node parameter is the NIR node we wish to convert. It can be a single node or a nir.NIRGraph, which can contain multiple nodes (and subgraphs). The node_map parameter is the dictionary described above, with type signature Dict[nir.NIRNode, Callable[[nir.NIRNode], torch.nn.Module]]. nirtorch uses it in the first step to look up which nodes are supported and to convert them. The third parameter, default_map, is discussed further below.

Here is a short, self-contained example of how to map a nir.AvgPool2d node to torch.nn.AvgPool2d:

import nir, nirtorch, numpy as np, torch

# First, we describe the NIR graph we need as input
nir_avgpool = nir.AvgPool2d(kernel_size=np.array([2, 2]), stride=np.array([1]), padding=np.array([0, 0]))
nir_linear = nir.Linear(weight=np.ones((5, 5), dtype=np.float32))
nir_graph = nir.NIRGraph.from_list(nir_avgpool, nir_linear)  # Constructs a graph AvgPool2d -> Linear (input/output nodes are added automatically)

# Second, we define the mapping
nir_to_torch_map = {
    nir.AvgPool2d: lambda node: torch.nn.AvgPool2d(
        kernel_size=tuple(torch.from_numpy(node.kernel_size)),
        stride=torch.from_numpy(node.stride),
        padding=tuple(torch.from_numpy(node.padding))
    )
}

# Finally, we call nirtorch with the node and dictionary
converted_module = nirtorch.nir_to_torch(nir_graph, nir_to_torch_map)
converted_module
GraphModule(
  (avgpool2d): AvgPool2d(kernel_size=(tensor(2), tensor(2)), stride=tensor([1]), padding=(tensor(0), tensor(0)))
  (linear): Linear(in_features=5, out_features=5, bias=False)
)

Note the torch.from_numpy calls, which convert the NumPy arrays coming from NIR into PyTorch tensors. You may also have noticed that we cast some of the torch.nn.AvgPool2d arguments to tuples. This is because 2-dimensional average pooling expects a pair of values for the kernel size and the padding.
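Passing tensors works, as the module printout above shows. However, torch.nn.AvgPool2d documents its arguments as ints or tuples of ints, so you may prefer plain integers. A hypothetical helper like to_int_tuple below (not part of nirtorch) does this conversion:

- It turns a NIR parameter array into a tuple of Python ints.
- It broadcasts a single value, such as the stride [1], to both spatial dimensions.
- It only iterates and calls int(), so it accepts NumPy arrays as well as lists.

```python
from typing import Iterable, Tuple


def to_int_tuple(values: Iterable, ndim: int = 2) -> Tuple[int, ...]:
    """Convert a parameter array (e.g. a NumPy array from NIR) to a tuple of Python ints.

    A single value is repeated `ndim` times, so a stride stored as [1] becomes (1, 1).
    """
    items = tuple(int(v) for v in values)
    if len(items) == 1:
        return items * ndim
    if len(items) != ndim:
        raise ValueError(f"expected 1 or {ndim} values, got {len(items)}")
    return items


print(to_int_tuple([2, 2]))  # (2, 2)
print(to_int_tuple([1]))     # (1, 1)
```

Inside the mapping you could then write kernel_size=to_int_tuple(node.kernel_size), stride=to_int_tuple(node.stride) and padding=to_int_tuple(node.padding) instead of the torch.from_numpy calls.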

Note also that we did not specify a mapping for the nir.Linear node. That's because nirtorch provides default mappings for the simplest node types, such as nir.Linear.

# Now we can execute it with a 3-dimensional (channels, height, width) tensor. A 6x6 input pools down to 5x5,
# which matches the 5 input features of the linear layer
converted_module(torch.ones(1, 6, 6))
(tensor([[[5., 5., 5., 5., 5.],
          [5., 5., 5., 5., 5.],
          [5., 5., 5., 5., 5.],
          [5., 5., 5., 5., 5.],
          [5., 5., 5., 5., 5.]]], grad_fn=<UnsafeViewBackward0>),
 {'input': None, 'avgpool2d': None, 'linear': None, 'output': None})

Note that the output is a tuple whose second element is the state dictionary. Every entry is None here, because neither average pooling nor the linear layer is stateful.
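Shape compatibility between the two layers is easy to check by hand. The linear layer has in_features=5, so the last dimension of the pooled tensor must be 5. With the default ceil_mode=False, the output size of torch.nn.AvgPool2d along each spatial dimension is floor((size + 2*padding - kernel_size) / stride) + 1. The small helper below computes this; the name pooled_size is illustrative, not part of nirtorch.

```python
def pooled_size(size: int, kernel_size: int, stride: int, padding: int = 0) -> int:
    """Output length of one spatial dimension for average pooling with ceil_mode=False."""
    return (size + 2 * padding - kernel_size) // stride + 1


# Kernel 2, stride 1, padding 0, as in the nir.AvgPool2d node above.
for size in (6, 10):
    print(size, "->", pooled_size(size, kernel_size=2, stride=1))
# 6 -> 5
# 10 -> 9
```

With a 2x2 kernel and stride 1, a 6x6 input pools to 5x5 and can feed the 5x5 linear layer. A 10x10 input pools to 9x9 and would not.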

Overwriting default dictionaries

There is a third parameter, default_map, which provides default mappings. By default, nirtorch maps simple nodes such as nir.Input, nir.Linear -> torch.nn.Linear and nir.Affine -> torch.nn.Linear (with a bias). You can override this behavior to provide a different mapping, or remove the defaults altogether. Entries in node_map take precedence over those in default_map. Observe what happens when we replace the default dictionary (DEFAULT_MAP) with an empty dictionary:

import nir, nirtorch, numpy as np, torch

# First, we describe the NIR graph we need as input
nir_avgpool = nir.AvgPool2d(kernel_size=np.array([2, 2]), stride=np.array([1]), padding=np.array([0, 0]))
nir_linear = nir.Linear(weight=np.ones((5, 5), dtype=np.float32))
nir_graph = nir.NIRGraph.from_list(nir_avgpool, nir_linear)  # Constructs a graph AvgPool2d -> Linear (input/output nodes are added automatically)

# Second, we define the mapping
nir_to_torch_map = {
    nir.AvgPool2d: lambda node: torch.nn.AvgPool2d(
        kernel_size=tuple(torch.from_numpy(node.kernel_size)),
        stride=torch.from_numpy(node.stride),
        padding=tuple(torch.from_numpy(node.padding))
    )
}

# Finally, we call nirtorch with the node and dictionary
converted_module = nirtorch.nir_to_torch(nir_graph, nir_to_torch_map, {})
converted_module
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[46], line 18
      9 nir_to_torch_map = {
     10     nir.AvgPool2d: lambda node: torch.nn.AvgPool2d(
     11         kernel_size=tuple(torch.from_numpy(node.kernel_size)),
   (...)
     14     )
     15 }
     17 # Finally, we call nirtorch with the node and dictionary
---> 18 converted_module = nirtorch.nir_to_torch(nir_graph, nir_to_torch_map, {})
     19 converted_module

File ~/.local/lib/python3.12/site-packages/nirtorch/nir_interpreter.py:278, in nir_to_torch(nir_graph, node_map, default_map)
    276 map_with_defaults.update(node_map)  # Overwrite defaults with node_map
    277 # First convert all nodes into a module dictionary
--> 278 owning_module = _construct_module_dict_recursive(nir_graph, map_with_defaults)
    279 # Then wire the graph recursively
    280 return _construct_fx_graph(owning_module=owning_module, nir_graph=nir_graph)

File ~/.local/lib/python3.12/site-packages/nirtorch/nir_interpreter.py:91, in _construct_module_dict_recursive(nir_graph, node_map)
     89     owning_module[name] = _construct_module_dict_recursive(node, node_map)
     90 else:
---> 91     mapped_module = _map_nir_node_to_torch(node, node_map=node_map)
     92     if mapped_module is not None:
     93         owning_module[name] = mapped_module

File ~/.local/lib/python3.12/site-packages/nirtorch/nir_interpreter.py:77, in _map_nir_node_to_torch(node, node_map)
     75     return node_map[type(node)](node)
     76 else:
---> 77     raise ValueError(
     78         f"Unknown node type {type(node)}, mapping does not exist in node_map"
     79     )

ValueError: Unknown node type <class 'nir.ir.graph.Input'>, mapping does not exist in node_map

You may wonder why the graph has a nir.Input node. nir.NIRGraph.from_list adds input and output nodes automatically. This keeps the graph well formed and tells nirtorch where the graph starts and ends. Without the default mapping, nirtorch has no way to map the input node (or the linear node), so it raises a ValueError for the first node it cannot map.
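A quick way to anticipate this error is to merge the two dictionaries the way the traceback shows, with user entries overwriting defaults via update. You can then compare the merged result against the node types in the graph. The sketch below uses hypothetical stand-in classes and hypothetical defaults instead of real NIR nodes and nirtorch's DEFAULT_MAP. With a real graph you could pass nir_graph.nodes.values(), since nir.NIRGraph keeps its nodes in a name-to-node dictionary. The check does not descend into subgraphs.

```python
from typing import Any, Callable, Dict, Iterable, Set

NodeMap = Dict[type, Callable[[Any], Any]]


# Hypothetical stand-ins for nir.Input, nir.AvgPool2d, nir.Linear and nir.Output.
class Input: ...
class AvgPool2d: ...
class Linear: ...
class Output: ...


def merge_maps(default_map: NodeMap, node_map: NodeMap) -> NodeMap:
    """Copy the defaults, then let user-provided entries overwrite them."""
    merged = dict(default_map)
    merged.update(node_map)
    return merged


def missing_node_types(nodes: Iterable[Any], node_map: NodeMap) -> Set[type]:
    """Return the node types (non-recursively) that have no entry in node_map."""
    return {type(node) for node in nodes} - set(node_map)


defaults: NodeMap = {Input: lambda n: "identity", Output: lambda n: "identity", Linear: lambda n: "linear"}
user: NodeMap = {AvgPool2d: lambda n: "avgpool"}
graph_nodes = [Input(), AvgPool2d(), Linear(), Output()]

print(missing_node_types(graph_nodes, merge_maps(defaults, user)))  # set()
print(sorted(t.__name__ for t in missing_node_types(graph_nodes, merge_maps({}, user))))
# ['Input', 'Linear', 'Output']
```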

How do I map tensors to specific devices?

nirtorch does not care on which device your tensors are located, but you may want to place tensors on a specific device. This can be done with partial function application, in three steps:

  1. Define your mapping function with an additional device parameter.
  2. Partially apply it once you know the device.
  3. Pass the partially applied function to nirtorch.

Here is an example:

import nir, nirtorch, numpy as np, torch
import functools

# First, we describe the NIR graph we need as input
nir_avgpool = nir.AvgPool2d(kernel_size=np.array([2, 2]), stride=np.array([1]), padding=np.array([0, 0]))
nir_linear = nir.Linear(weight=np.ones((5, 5), dtype=np.float32))
nir_graph = nir.NIRGraph.from_list(nir_avgpool, nir_linear)  # Constructs a graph AvgPool2d -> Linear (input/output nodes are added automatically)

# Second, we define the mapping
nir_to_torch_map = {
    nir.AvgPool2d: lambda node, device: torch.nn.AvgPool2d(    # <--- Note the additional device parameter
        kernel_size=tuple(torch.from_numpy(node.kernel_size).to(device)),
        stride=torch.from_numpy(node.stride).to(device),
        padding=tuple(torch.from_numpy(node.padding).to(device))
    )
}
# We can now partially apply the function at a point in time where we know the device type
nir_to_torch_map[nir.AvgPool2d] = functools.partial(
    nir_to_torch_map[nir.AvgPool2d], 
    device="cpu"
)
# The dictionary now contains a partially applied function that only requires one input: the NIR node,
# so it is safe to pass onto nirtorch
converted_module = nirtorch.nir_to_torch(nir_graph, nir_to_torch_map)
converted_module
GraphModule(
  (avgpool2d): AvgPool2d(kernel_size=(tensor(2), tensor(2)), stride=tensor([1]), padding=(tensor(0), tensor(0)))
  (linear): Linear(in_features=5, out_features=5, bias=False)
)
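When several mapping functions take the extra device parameter, you can apply the partial application to the whole dictionary at once. The helper bind_device below is a hypothetical convenience, not part of nirtorch. It assumes every function in the dictionary accepts a device keyword argument, and the Pool class stands in for a NIR node type.

```python
import functools
from typing import Any, Callable, Dict


def bind_device(raw_map: Dict[type, Callable[..., Any]], device: str) -> Dict[type, Callable[[Any], Any]]:
    """Return a new map whose functions take only the node, with `device` already bound."""
    return {node_type: functools.partial(fn, device=device) for node_type, fn in raw_map.items()}


class Pool:  # hypothetical stand-in for a NIR node class
    pass


raw_map = {Pool: lambda node, device: f"{type(node).__name__} on {device}"}
bound_map = bind_device(raw_map, device="cuda:0")
print(bound_map[Pool](Pool()))  # Pool on cuda:0
```

Binding the device by keyword leaves the node as the only positional argument. The resulting dictionary can therefore be passed as node_map, just like the single partially applied entry above.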

You can see a live example of this pattern in the NIR implementation for the spiking neural network library Norse.