NIRTorch API Documentation

This page lists functions and classes exposed by nirtorch and their corresponding documentation strings.

NIRTorch functions

nirtorch.extract_nir_graph(model: Module, model_map: Callable[[Module], NIRNode], sample_data: Any, model_name: str | None = 'model', ignore_submodules_of=None, model_fwd_args=[], ignore_dims: Sequence[int] | None = None) -> NIRNode

DEPRECATED: Use nirtorch.torch_to_nir instead.

Given a model, generate an NIR representation using the specified model_map.

Assumptions and known issues:
  • Cannot deal with layers like torch.nn.Identity(), since the input tensor and output tensor will be the same object, and therefore lead to cyclic connections.
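The Identity issue can be reproduced without any tracing machinery. The following is a toy sketch, not nirtorch's actual tracer: it assumes a tracer that recognises values by Python object identity (`id(...)`), and `FakeTensor`, `identity`, `scale`, and `trace_edge` are hypothetical stand-ins introduced only for illustration.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor; only object identity matters here."""


def identity(x):
    # Like torch.nn.Identity: returns the very same object it received.
    return x


def scale(x):
    # Like most layers: returns a freshly allocated output object.
    return FakeTensor()


def trace_edge(layer, x):
    """Record one layer call as an edge (id of input, id of output)."""
    y = layer(x)
    return (id(x), id(y))


x = FakeTensor()

src, dst = trace_edge(scale, x)
normal_edge_is_self_loop = src == dst    # distinct objects, so no self-loop

src, dst = trace_edge(identity, x)
identity_edge_is_self_loop = src == dst  # same object, so the edge points back at its source
```

Under this assumption, any layer that hands back its input object unchanged is indistinguishable from a connection of a value to itself, which is the cyclic connection described above.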

Args:
  • model (nn.Module): The model of interest.
  • model_map (Callable[[nn.Module], nir.NIRNode]): A method that converts a given module type to an NIRNode type.
  • sample_data (Any): Sample input data to be used for model extraction.
  • model_name (Optional[str], optional): The name of the top level module. Defaults to “model”.
  • ignore_submodules_of (Optional[Sequence[nn.Module]]): If specified, the corresponding module’s children will not be traversed when building the graph.
  • ignore_dims (Optional[Sequence[int]]): Dimensions of data to be ignored for type/shape inference. Typically the dimensions that you will want to ignore are for batch and time.

Returns:

nir.NIR: Returns the generated NIR graph representation.
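To make the role of ignore_dims concrete, here is a minimal sketch of removing batch and time dimensions from a sample shape before shape inference. `drop_dims` is a hypothetical helper written for this page, not a nirtorch function, and the handling of negative indices is an assumption.

```python
from typing import Optional, Sequence, Tuple


def drop_dims(shape: Sequence[int], ignore_dims: Optional[Sequence[int]]) -> Tuple[int, ...]:
    """Return `shape` without the dimensions listed in `ignore_dims`.

    Hypothetical helper: negative indices count from the end, as in Python indexing.
    """
    if not ignore_dims:
        return tuple(shape)
    ndim = len(shape)
    ignored = {d % ndim for d in ignore_dims}
    return tuple(size for axis, size in enumerate(shape) if axis not in ignored)


# Sample data laid out as (batch, time, channels, height, width).
sample_shape = (8, 100, 2, 34, 34)
per_sample_shape = drop_dims(sample_shape, ignore_dims=[0, 1])
```

With `ignore_dims=[0, 1]`, only the channel and spatial dimensions remain, which is the part of the shape that describes a single input to the model.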

nirtorch.load(nir_graph: NIRNode | str, model_map: Callable[[NIRNode], Module], return_state: bool = True) -> Module

DEPRECATED: Use nirtorch.nir_to_torch instead.

Load a NIR graph and convert it to a torch module using the given model map.

If you do not wish to operate with state, set return_state=False.

Args:
  • nir_graph (Union[nir.NIRNode, str]): The NIR object to load, or a string representing the path to the NIR object.
  • model_map (Callable[[nir.NIRNode], nn.Module]): A method that returns the torch module that corresponds to each NIR node.
  • return_state (bool): If True, the execution of the loaded graph will return a tuple of (output, state), where state is a GraphExecutorState object. If False, only the NIR graph output will be returned. Note that state is required for recurrence to work in the graphs.

Returns:

nn.Module: The generated torch module
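The difference between the two return_state modes can be illustrated with a toy stateful callable. This is only a sketch of the calling convention: `RunningSum` is a hypothetical class, and its state is a plain float rather than the GraphExecutorState object that the loaded graph returns.

```python
from typing import Optional, Tuple, Union


class RunningSum:
    """Hypothetical stateful module: accumulates its inputs across calls."""

    def __init__(self, return_state: bool = True):
        self.return_state = return_state

    def __call__(
        self, x: float, state: Optional[float] = None
    ) -> Union[float, Tuple[float, float]]:
        total = (0.0 if state is None else state) + x
        # With return_state=True the caller receives the state and must pass it back in.
        return (total, total) if self.return_state else total


stateful = RunningSum(return_state=True)
state = None
outputs = []
for x in [1.0, 2.0, 3.0]:
    out, state = stateful(x, state)  # thread the state through every time step
    outputs.append(out)

# Without a returned state, every call starts from scratch.
stateless = RunningSum(return_state=False)
stateless_outputs = [stateless(x) for x in [1.0, 2.0, 3.0]]
```

The stateful loop accumulates across steps, while the stateless calls cannot carry anything forward, which mirrors why state is required for recurrence.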

nirtorch.nir_to_torch(nir_node: str | NIRNode, node_map: Dict[NIRNode, Callable[[NIRNode], Module]], default_map: Dict[NIRNode, Callable[[NIRNode], Module]] = {<class 'nir.ir.conv.Conv1d'>: <function _default_map_conv1d>, <class 'nir.ir.conv.Conv2d'>: <function _default_map_conv2d>, <class 'nir.ir.flatten.Flatten'>: <function _default_map_flatten>, <class 'nir.ir.graph.Input'>: <function <lambda>>, <class 'nir.ir.graph.Output'>: <function <lambda>>, <class 'nir.ir.linear.Affine'>: <function _default_map_affine>, <class 'nir.ir.linear.Linear'>: <function _default_map_linear>, <class 'nir.ir.pooling.AvgPool2d'>: <function _default_map_avgpool2d>, <class 'nir.ir.pooling.SumPool2d'>: <function _default_map_sumpool2d>}, device: torch.device = 'cpu', dtype: torch.dtype = torch.float32) -> GraphModule

Maps a NIRGraph to an executable PyTorch GraphModule (torch.fx.GraphModule). We first map all individual nodes using the node_map, where a common set of mappings is provided by default (e.g. Linear, Conv, etc.). Then, we wire all the nodes together into an executable torch.fx.GraphModule. Finally, we wrap the execution in a StatefulInterpreter to ensure that the internal state of the modules is handled correctly.

Example:

>>> import nir, nirtorch, numpy as np, torch
>>> # Using an existing graph
>>> nir_graph = ...
>>> torch_module = nirtorch.nir_to_torch(nir_graph, {}) # Empty node_map: rely on the default_map
>>> torch_module(torch.randn(...)) # The module is now ready to use
>>> # Using a custom graph
>>> # First, we construct the graph
>>> nir_avgpool = nir.AvgPool2d(kernel_size=np.array([2, 2]), stride=np.array([1]), padding=np.array([0, 0]))
>>> nir_linear = nir.Linear(weight=np.ones((5, 5), dtype=np.float32))
>>> nir_graph = nir.NIRGraph.from_list(nir_avgpool, nir_linear) # Constructs a graph with two nodes: AvgPool2d followed by Linear
>>> # Second, we define the mapping
>>> nir_to_torch_map = {
>>>     nir.AvgPool2d: lambda node: torch.nn.AvgPool2d(
>>>         kernel_size=tuple(torch.from_numpy(node.kernel_size)),
>>>         stride=torch.from_numpy(node.stride),
>>>         padding=tuple(torch.from_numpy(node.padding))
>>>     )
>>> }
>>> # Finally, we call nirtorch with the graph and dictionary
>>> converted_module = nirtorch.nir_to_torch(nir_graph, nir_to_torch_map)
>>> output, state = converted_module(torch.ones(1)) # Note that the module returns a tuple of (output, state)
>>> output, state = converted_module(input, state)  # This can go on for many (time)steps
>>>                                                 # Note that state is mutable!

Args:
  • nir_node (Union[nir.NIRNode, str, pathlib.Path]): The input NIR node to convert to torch, typically a nir.NIRGraph. Can also be a string or a path, in which case we use nir.read to fetch the graph from the file first.
  • node_map (Dict[nir.NIRNode, Callable[[nir.NIRNode], torch.nn.Module]]): A dictionary that maps NIR nodes into Torch modules.
  • default_map (Dict[nir.NIRNode, Callable[[nir.NIRNode], torch.nn.Module]]): A dictionary with default values to use when node_map entries are missing. The default value of this parameter defines mappings for simple modules like nir.Linear and nir.Input. Override this to provide custom defaults.
  • device (torch.device): The device to load the modules and parameters on. Defaults to “cpu”.
  • dtype (torch.dtype): The precision with which to load the modules and parameters. Defaults to torch.float32.
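The interplay between node_map and default_map can be pictured as a type-keyed lookup with a fallback. The sketch below is not nirtorch's implementation: the node classes, the `convert` helper, and the error raised for unmapped types are all hypothetical, and converters return strings so that the selected mapping is easy to inspect.

```python
from typing import Any, Callable, Dict


class Linear:
    """Hypothetical stand-in for a NIR node type covered by the defaults."""


class Flatten:
    """Hypothetical stand-in for a second node type covered by the defaults."""


class LIF:
    """Hypothetical stand-in for a node type with no default mapping."""


ConverterMap = Dict[type, Callable[[Any], str]]

default_map: ConverterMap = {
    Linear: lambda node: "default Linear",
    Flatten: lambda node: "default Flatten",
}
node_map: ConverterMap = {
    Linear: lambda node: "custom Linear",  # overrides the default entry
    LIF: lambda node: "custom LIF",        # adds a type the defaults lack
}


def convert(node: Any, node_map: ConverterMap, default_map: ConverterMap) -> str:
    """Look up the converter by node type; node_map entries take precedence."""
    merged = {**default_map, **node_map}
    converter = merged.get(type(node))
    if converter is None:
        # Assumed behaviour for this sketch; the library may report this differently.
        raise KeyError(f"No mapping for node type {type(node).__name__}")
    return converter(node)


results = [convert(n, node_map, default_map) for n in (Linear(), Flatten(), LIF())]
```

In this sketch a user only needs to supply entries for node types that the defaults do not cover, or for those where a different torch module is wanted.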

nirtorch.torch_to_nir(module: Module, module_map: Dict[Module, Callable[[Module], NIRNode]], default_dict: Dict[Module, Callable[[Module], NIRNode]] = {<class 'torch.nn.modules.linear.Linear'>: <function _map_linear>}, type_check: bool = True) -> NIRGraph

Traces a PyTorch module and converts it to a NIR graph using the specified module map.

>>> import nir, nirtorch, numpy as np, torch
>>> # First, we describe the PyTorch module we want to convert
>>> torch_module = torch.nn.AvgPool2d(kernel_size=(2, 2), stride=1, padding=1)
>>> # Second, we define the dictionary
>>> torch_to_nir_map = {
>>>     torch.nn.AvgPool2d: lambda module: nir.AvgPool2d(
>>>         kernel_size=np.array(module.kernel_size),
>>>         stride=np.array(module.stride),
>>>         padding=np.array(module.padding)
>>>     )
>>> }
>>> # Finally, we call nirtorch with the module and dictionary
>>> nir_graph = nirtorch.torch_to_nir(torch_module, torch_to_nir_map)
>>> # The result is a nir.NIRGraph, not a callable module.
>>> # Use nirtorch.nir_to_torch to turn it back into an executable PyTorch module.

Args:
  • module (torch.nn.Module): The module of interest.
  • module_map (Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]): A dictionary that maps a given module type to a function that can convert the module to an NIRNode type.
  • default_dict (Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]): A dictionary of default mappings that, by default, maps trivial modules like torch.nn.Linear. Override the dictionary to provide custom mappings.
  • type_check (bool): Whether to run type checking on generated NIRGraphs.