NIRTorch API Documentation

This page lists functions and classes exposed by nirtorch and their corresponding documentation strings.

NIRTorch functions

nirtorch.extract_nir_graph(model: Module, model_map: Callable[[Module], NIRNode], sample_data: Any, model_name: str | None = 'model', ignore_submodules_of=None, model_fwd_args=[], ignore_dims: Sequence[int] | None = None) → NIRNode

DEPRECATED: Use nirtorch.torch_to_nir instead.

Given a model, generate an NIR representation using the specified model_map.

Assumptions and known issues:
  • Cannot handle layers such as torch.nn.Identity(), because their input and output tensors are the same object, which leads to cyclic connections in the graph.
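The cyclic-connection issue can be illustrated with a toy tracer. This is a minimal, stdlib-only sketch that assumes connections are inferred by matching the object a layer returns with the object a later layer receives. `Tensor`, `fresh`, `identity` and `trace` are hypothetical stand-ins, not nirtorch internals.

```python
from typing import Any, Callable, List, Set, Tuple


class Tensor:
    """Stand-in for a tensor; only its object identity matters here."""


def fresh(x: Any) -> Tensor:
    return Tensor()  # most layers return a new object


def identity(x: Any) -> Any:
    return x  # like torch.nn.Identity: returns the very same object


def trace(layers: List[Tuple[str, Callable[[Any], Any]]], x: Any) -> Set[Tuple[str, str]]:
    """Connect layer A to layer B whenever A's output *is* B's input (hypothetical tracer)."""
    records = []
    for name, fn in layers:
        y = fn(x)
        records.append((name, x, y))  # keep references so `is` comparisons stay valid
        x = y
    return {
        (src, dst)
        for src, _, src_out in records
        for dst, dst_in, _ in records
        if src_out is dst_in
    }


print(sorted(trace([("fc1", fresh), ("fc2", fresh)], Tensor())))
# [('fc1', 'fc2')]
print(("id", "id") in trace([("fc1", fresh), ("id", identity), ("fc2", fresh)], Tensor()))
# True: the identity layer ends up connected to itself
```

Because the identity layer's output is the same object as its input, an identity-based tracer cannot tell "produced by" from "consumed by", and a self-loop appears.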

Args:

  • model (nn.Module): The model of interest.
  • model_map (Callable[[nn.Module], nir.NIRNode]): A method that converts a given module type to an NIRNode type.
  • sample_data (Any): Sample input data to be used for model extraction.
  • model_name (Optional[str], optional): The name of the top-level module. Defaults to “model”.
  • ignore_submodules_of (Optional[Sequence[nn.Module]]): If specified, the children of the corresponding modules will not be traversed for the graph.
  • ignore_dims (Optional[Sequence[int]]): Dimensions of the data to be ignored for type/shape inference. Typically these are the batch and time dimensions.

Returns:

nir.NIRNode: The generated NIR graph representation.
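To illustrate ignore_dims, the following sketch assumes that ignoring a dimension simply removes it from the shape recorded for type inference. `inferred_shape` is a hypothetical helper written for this example, not part of nirtorch.

```python
from typing import Optional, Sequence, Tuple


def inferred_shape(shape: Sequence[int], ignore_dims: Optional[Sequence[int]] = None) -> Tuple[int, ...]:
    """Hypothetical helper: drop ignored dimensions (e.g. batch and time) from a sample's shape."""
    if not ignore_dims:
        return tuple(shape)
    ndim = len(shape)
    # Normalise negative indices so that -1 refers to the last dimension
    ignored = {d % ndim for d in ignore_dims}
    return tuple(size for dim, size in enumerate(shape) if dim not in ignored)


# A sample of shape (batch=8, time=100, features=20)
print(inferred_shape((8, 100, 20), ignore_dims=[0, 1]))  # (20,)
```

Dropping batch and time this way keeps the inferred node types independent of how many samples or time steps happen to be in the sample data.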

nirtorch.load(nir_graph: NIRNode | str, model_map: Callable[[NIRNode], Module], return_state: bool = True) → Module

DEPRECATED: Use nirtorch.nir_to_torch instead.

Load a NIR graph and convert it to a torch module using the given model map.

If you do not wish to operate with state, set return_state=False.

Args:

  • nir_graph (Union[nir.NIRNode, str]): The NIR object to load, or a string representing the path to the NIR object.
  • model_map (Callable[[nir.NIRNode], nn.Module]): A method that returns the torch module that corresponds to each NIR node.
  • return_state (bool): If True, executing the loaded graph returns a tuple (output, state), where state is a GraphExecutorState object. If False, only the output of the NIR graph is returned. Note that state is required for recurrence to work in the graphs.

Returns:

nn.Module: The generated torch module
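The (output, state) calling convention used with return_state=True can be sketched without torch. `LeakyIntegrator` and `LIState` below are hypothetical stand-ins (not nirtorch classes) that follow the same protocol; the point is that the state returned by one call must be passed into the next for recurrence to persist across time steps.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class LIState:
    v: float = 0.0  # accumulator carried between time steps


class LeakyIntegrator:
    """Hypothetical stateful module: __call__(x, state) -> (output, state)."""

    def __init__(self, leak: float = 0.5):
        self.leak = leak

    def __call__(self, x: float, state: Optional[LIState] = None) -> Tuple[float, LIState]:
        if state is None:
            state = LIState()  # fresh state on the first call
        v = self.leak * state.v + x
        return v, LIState(v=v)


module = LeakyIntegrator(leak=0.5)
output, state = module(1.0)  # first step: no state yet
outputs = [output]
for x in [1.0, 1.0]:
    output, state = module(x, state)  # feed the previous state back in
    outputs.append(output)
print(outputs)  # [1.0, 1.5, 1.75]
```

Calling `module(1.0)` repeatedly without passing the state would return 1.0 every time, which is the behaviour you get when the recurrent state is discarded.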

nirtorch.nir_to_torch(nir_node: NIRNode, node_map: Dict[NIRNode, Callable[[NIRNode], Module]], default_map: Dict[NIRNode, Callable[[NIRNode], Module]] = {nir.Conv1d: _default_map_conv1d, nir.Conv2d: _default_map_conv2d, nir.Input: <lambda>, nir.Output: <lambda>, nir.Affine: _default_map_affine, nir.Linear: _default_map_linear}) → GraphModule

Maps a NIRGraph to an executable PyTorch GraphModule (torch.fx.GraphModule). We first map all individual nodes using node_map, where a common set of mappings is provided by default (e.g. Linear, Conv). Then, we wire all the nodes together into an executable torch.fx.GraphModule. Finally, we wrap the execution in a StatefulInterpreter to ensure that the internal state of the modules is handled correctly.
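The wiring step amounts to executing the mapped modules in dependency order. The sketch below uses Python's graphlib instead of torch.fx and is not nirtorch's implementation; the node names, the plain-callable modules and the summing of multiple incoming edges are assumptions made for illustration.

```python
from graphlib import TopologicalSorter
from typing import Callable, Dict, List, Tuple


def execute_graph(
    modules: Dict[str, Callable[[float], float]],
    edges: List[Tuple[str, str]],
    x: float,
) -> float:
    """Run every node after its predecessors (hypothetical executor, not torch.fx)."""
    predecessors: Dict[str, List[str]] = {name: [] for name in modules}
    for src, dst in edges:
        predecessors[dst].append(src)
    values: Dict[str, float] = {}
    for name in TopologicalSorter(predecessors).static_order():
        inputs = [values[p] for p in predecessors[name]]
        # The "input" node receives the external data; other nodes sum their inputs (assumption)
        values[name] = modules[name](x if name == "input" else sum(inputs))
    return values["output"]


modules = {
    "input": lambda v: v,
    "double": lambda v: 2 * v,
    "inc": lambda v: v + 1,
    "output": lambda v: v,
}
edges = [("input", "double"), ("input", "inc"), ("double", "output"), ("inc", "output")]
print(execute_graph(modules, edges, 3.0))  # 2 * 3 + (3 + 1) = 10.0
```

A topological order guarantees that every node's inputs are computed before the node itself runs, which is the property any executable wiring of a feed-forward graph has to satisfy.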

Example:

>>> import nir, nirtorch, numpy as np, torch
>>> # First, we describe the NIR graph
>>> nir_avgpool = nir.AvgPool2d(kernel_size=np.array([2, 2]), stride=np.array([1, 1]), padding=np.array([0, 0]))
>>> nir_linear = nir.Linear(weight=np.ones((5, 5), dtype=np.float32))
>>> nir_graph = nir.NIRGraph.from_list(nir_avgpool, nir_linear)  # Constructs a graph with two nodes: AvgPool2d -> Linear
>>> # Second, we define the mapping (nir.Linear is already covered by the default map)
>>> nir_to_torch_map = {
...     nir.AvgPool2d: lambda node: torch.nn.AvgPool2d(
...         kernel_size=tuple(int(k) for k in node.kernel_size),
...         stride=tuple(int(s) for s in node.stride),
...         padding=tuple(int(p) for p in node.padding),
...     )
... }
>>> # Finally, we call nirtorch with the graph and the dictionary
>>> torch_module = nirtorch.nir_to_torch(nir_graph, nir_to_torch_map)
>>> x = torch.ones(1, 1, 6, 6)  # 6x6 input -> 5x5 after pooling, matching the 5x5 linear weight
>>> output, state = torch_module(x)         # Note that the module returns a tuple of (output, state)
>>> output, state = torch_module(x, state)  # This can go on for many (time)steps

Args:

  • nir_node (nir.NIRNode): The input NIR node to convert to torch, typically a nir.NIRGraph.
  • node_map (Dict[nir.NIRNode, Callable[[nir.NIRNode], torch.nn.Module]]): A dictionary that maps NIR nodes to Torch modules.
  • default_map (Dict[nir.NIRNode, Callable[[nir.NIRNode], torch.nn.Module]]): A dictionary of default mappings used when an entry is missing from node_map. The default value of this parameter defines mappings for simple modules like nir.Linear and nir.Input. Override this to provide custom defaults.
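The relationship between node_map and default_map described above can be sketched as a dictionary merge in which user entries take precedence. `Linear`, `Conv2d`, `LIF` and `resolve_mapper` are hypothetical stand-ins written for this example, not nirtorch or NIR objects.

```python
from typing import Any, Callable, Dict, Type


class Linear: ...  # hypothetical stand-ins for NIR node classes
class Conv2d: ...
class LIF: ...


def resolve_mapper(node: Any, node_map: Dict[Type, Callable], default_map: Dict[Type, Callable]) -> Callable:
    """Pick the mapping function for a node: node_map first, then default_map."""
    merged = {**default_map, **node_map}  # later entries (node_map) override defaults
    try:
        return merged[type(node)]
    except KeyError:
        raise ValueError(f"No mapping for node type {type(node).__name__}") from None


default_map = {Linear: lambda node: "default Linear", Conv2d: lambda node: "default Conv2d"}
node_map = {LIF: lambda node: "custom LIF", Linear: lambda node: "custom Linear"}

node = Conv2d()
print(resolve_mapper(node, node_map, default_map)(node))  # default Conv2d
```

Looking up `type(node)` exactly, as done here, means subclasses need their own entries; whether nirtorch matches subclasses is not stated above.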

nirtorch.torch_to_nir(module: Module, module_map: Dict[Module, Callable[[Module], NIRNode]], default_dict: Dict[Module, Callable[[Module], NIRNode]] = {torch.nn.Linear: <lambda>}) → NIRGraph

Traces a PyTorch module and converts it to a NIR graph using the specified module map.

>>> import nir, nirtorch, numpy as np, torch
>>> # First, we describe the PyTorch module we want to convert
>>> torch_module = torch.nn.Sequential(
...     torch.nn.AvgPool2d(kernel_size=(2, 2), stride=(1, 1), padding=(0, 0))
... )
>>> # Second, we define the dictionary
>>> torch_to_nir_map = {
...     torch.nn.AvgPool2d: lambda module: nir.AvgPool2d(
...         kernel_size=np.array(module.kernel_size),
...         stride=np.array(module.stride),
...         padding=np.array(module.padding)
...     )
... }
>>> # Finally, we call nirtorch with the module and the dictionary
>>> nir_graph = nirtorch.torch_to_nir(torch_module, torch_to_nir_map)  # Returns a NIRGraph, not a torch module

Args:

  • module (torch.nn.Module): The module of interest.
  • module_map (Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]): A dictionary that maps a given module type to a function that converts the module to an NIRNode type.
  • default_dict (Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]): A dictionary of default mappings that, by default, maps trivial modules like torch.nn.Linear. Override the dictionary to provide custom mappings.
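For a purely sequential model, the conversion to a graph can be sketched as a chain of nodes framed by an input and an output node. This is a simplified, stdlib-only illustration under stated assumptions: real tracing with torch.fx handles arbitrary dataflow, plain dictionaries stand in for NIR nodes, and `chain_to_graph` and `Pool` are hypothetical names.

```python
from typing import Any, Callable, Dict, List, Tuple


def chain_to_graph(
    layers: List[Tuple[str, Any]],
    module_map: Dict[type, Callable[[Any], Any]],
) -> Tuple[Dict[str, Any], List[Tuple[str, str]]]:
    """Map each layer with module_map and connect them in order (hypothetical sketch)."""
    nodes: Dict[str, Any] = {"input": "Input"}  # placeholder for an input node
    edges: List[Tuple[str, str]] = []
    previous = "input"
    for name, module in layers:
        if type(module) not in module_map:
            raise ValueError(f"Unsupported module type: {type(module).__name__}")
        nodes[name] = module_map[type(module)](module)
        edges.append((previous, name))
        previous = name
    nodes["output"] = "Output"  # placeholder for an output node
    edges.append((previous, "output"))
    return nodes, edges


class Pool:  # hypothetical stand-in for a torch module
    def __init__(self, kernel_size: int):
        self.kernel_size = kernel_size


module_map = {Pool: lambda m: {"type": "AvgPool2d", "kernel_size": m.kernel_size}}
nodes, edges = chain_to_graph([("pool1", Pool(2)), ("pool2", Pool(3))], module_map)
print(edges)  # [('input', 'pool1'), ('pool1', 'pool2'), ('pool2', 'output')]
```

Raising on unmapped module types, rather than silently skipping them, keeps a missing entry in module_map from producing a graph with a gap in it.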