State management in NIRTorch#
Managing state is crucial when modeling (real) neurons because quantities such as the membrane potential or leak must be carried from one time step to the next. There are two main ways of doing that: implicitly and explicitly. NIRTorch uses explicit state handling. Below, we briefly explain the difference and show how state handling works in practice when executing NIRTorch graphs.
Implicit vs Explicit State: A High-Level Comparison#
This high-level comparison shows the fundamental differences between the implicit and explicit approaches to state handling. Let's break down each aspect; a short code sketch after the list makes the difference concrete:
Control of State:
Implicit: The framework manages state changes
Explicit: The developer directly controls state transitions
Visibility:
Implicit: State changes happen automatically, out of sight of the calling code
Explicit: State changes must be explicitly coded
Traceability:
Implicit: State transitions may be hidden
Explicit: Clear state transition flow
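To make the difference concrete, here is a minimal sketch of a leaky integrator written both ways. The class and function names, the leak constant, and the update rule are invented for this illustration and are not part of NIRTorch:

# Implicit state: the module owns and silently mutates its own state
class ImplicitLeakyIntegrator:
    def __init__(self, leak: float = 0.9):
        self.leak = leak
        self.voltage = 0.0  # hidden state, mutated on every call

    def __call__(self, data: float) -> float:
        self.voltage = self.leak * self.voltage + data
        return self.voltage

# Explicit state: the caller passes the state in and gets a new state back
def explicit_leaky_integrator(data: float, voltage: float, leak: float = 0.9):
    new_voltage = leak * voltage + data
    return new_voltage, new_voltage  # (output, new state)

The implicit version is shorter to call, but its output depends on hidden history; the explicit version makes that history a visible argument.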
Dataflow Patterns#
To understand how these patterns work in practice, let’s examine their dataflow characteristics:
flowchart LR
    subgraph Implicit["Module with Internal State"]
        direction TB
        I_Input[/"Input Data"/]
        I_State[("Internal State")]
        I_Process["Process"]
        I_Output[/"Output"/]
        I_Input --> I_Process
        I_Process <--> I_State
        I_Process --> I_Output
        style I_State fill:#ff9999
    end
    subgraph Explicit["Module with External State"]
        direction TB
        E_Input[/"Input Data"/]
        E_State[("Current State")]
        E_Process["Process"]
        E_Output[/"Output + New State"/]
        E_Input --> E_Process
        E_State --> E_Process
        E_Process --> E_Output
        style E_State fill:#99ff99
    end
    ImplicitNote["State lives inside module and is mutated during processing. The user never sees the state, but may have to reset it"]
    ExplicitNote["State flows through module as input/output, never mutated internally"]
    Implicit -.-> ImplicitNote
    Explicit -.-> ExplicitNote
    style ImplicitNote fill:#fff,stroke:#999
    style ExplicitNote fill:#fff,stroke:#999
Advantages and Trade-offs#
Implicit State#
Advantages:
Less boilerplate code
Can feel more intuitive for simple applications
Automatic state synchronization
Disadvantages:
Harder to test
State changes can be difficult to track
Can lead to unexpected side effects
Explicit State#
Advantages:
Predictable data flow
Easier to test (see the example after this list)
Clear state transitions
Better debugging experience
Disadvantages:
More verbose
Can feel overengineered for simple cases
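To illustrate the testing point above, the explicit version from the earlier sketch can be exercised as a pure function, with no hidden attribute to reset between test cases (this test assumes the explicit_leaky_integrator sketch from the comparison section is in scope):

def test_explicit_leaky_integrator():
    # With a known input and state, the result is fully determined
    output, new_state = explicit_leaky_integrator(data=1.0, voltage=0.0, leak=0.9)
    assert output == 1.0 and new_state == 1.0
    # Repeating the call with the same arguments gives the same result;
    # there is no hidden attribute to reset between test cases
    assert explicit_leaky_integrator(data=1.0, voltage=0.0, leak=0.9) == (1.0, 1.0)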
State handling in Python and NIRTorch#
NIRTorch uses explicit state management, which may be more cumbersome to write but makes data flow more visible:
from dataclasses import dataclass

# A dataclass gives us the MyState(voltage=...) constructor used below
@dataclass
class MyState:
    voltage: float

def stateful_function(data, state):
    # 1. Calculate a new voltage
    new_voltage = ...
    # 2. Calculate the function output
    output = ...
    # 3. Define a new state
    new_state = MyState(voltage=new_voltage)
    # 4. A tuple of (data, state) is returned
    #    Note that a new state is returned; the original remains unchanged
    return output, new_state
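Once the placeholders above are filled in, a caller threads the state through the computation by passing the previous state into each call and keeping the returned one, for example (a sketch; `inputs` and the initial voltage are hypothetical):

state = MyState(voltage=0.0)  # hypothetical initial state
outputs = []
for data in inputs:  # `inputs` stands for some sequence of input data
    # Pass the current state in; keep the returned state for the next step
    output, state = stateful_function(data, state)
    outputs.append(output)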
Once NIRTorch has parsed a NIR graph into Torch modules (read more about that in the page about To PyTorch: Interpreting NIR), the resulting module expects a second state parameter, like the function above. Similarly, it will return a tuple of (data, state).

Here is a full example where we first initialize a Torch module from NIRTorch and then apply it several times with the correct state:
import nir
import nirtorch
import numpy as np
import torch

nir_weight = np.ones((2, 2))
nir_graph = nir.NIRGraph.from_list(nir.Linear(weight=nir_weight))
torch_module = nirtorch.nir_to_torch(
    nir_graph=nir_graph,
    # We can leave node_map empty since we only use a linear layer,
    # which has a default mapping
    node_map={},
)
##
## This is where the state handling happens
##
# Assume some time-series data with 100 entries, each with two datapoints
time_data = torch.rand(100, 2)
state = None
results = []
# Loop through the time-series data, one entry at a time
for single_data in time_data:
    if state is None:
        # If state is None, we leave the state blank
        output, state = torch_module(single_data)
    else:
        # If state is not None, we need to feed it back in
        output, state = torch_module(single_data, state)
    results.append(output)
# Stack the outputs into a tensor of shape (100, 2)
results = torch.stack(results)
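Note that the module never stores this state internally: to process a new, independent sequence, we simply set state back to None and run the loop again, without resetting anything inside torch_module.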