hxtorch
hxtorch is a deep learning Python library used for numerical simulation, neuromorphic emulation and training of spiking neural networks (SNNs). Built on top of PyTorch, it integrates the automatic differentiation and modular design of the PyTorch ecosystem with neuromorphic experiment execution, enabling hardware-in-the-loop training workflows on the neuromorphic hardware system BrainScaleS-2.
Supported primitives in hxtorch
This library supports conversion of the following nodes to NIR:
- Linear
- CubaLI
- CubaLIF
This library supports conversion of the following nodes from NIR:
- Linear
- CubaLI
- CubaLIF
Export an hxtorch network to NIR
import hxtorch.spiking as hxsnn
from hxtorch.spiking.utils.to_nir import SNN, to_nir
import torch
# Your SNN definition in hxtorch
class One_Layer_SNN(SNN):
    """
    Example SNN with one hidden LIF layer followed by an LI readout.

    Topology: 5 inputs -> Synapse -> 10 LIF neurons -> Synapse -> 2 LI
    readout neurons.
    """

    def __init__(self, dt: float = 1.0e-6, mock: bool = True,
                 device: torch.device = torch.device("cpu")):
        """
        Initialize an SNN with one hidden LIF layer and an LI readout.

        :param dt: Simulation time step in seconds.
        :param mock: Whether to simulate the neurons on CPU/GPU (True) or
            to run them on BrainScaleS-2 hardware (False).
        :param device: Device to use for simulation if mock is True.
        """
        super().__init__(dt, mock, device)
        # Layers. Every layer is registered with the experiment `self.exp`,
        # which is presumably set up by the SNN base class in __init__.
        self.linear_h = hxsnn.Synapse(
            in_features=5, out_features=10, experiment=self.exp)
        self.lif_h = hxsnn.LIF(size=10, experiment=self.exp)
        self.linear_o = hxsnn.Synapse(
            in_features=10, out_features=2, experiment=self.exp)
        self.li_readout = hxsnn.LI(size=2, experiment=self.exp)

    def forward(self, spikes: torch.Tensor) -> torch.Tensor:
        """
        Perform a forward pass for an SNN with one hidden LIF layer and an LI
        readout.

        :param spikes: torch.Tensor holding spikes as input. Its first
            dimension is passed to ``hxsnn.run`` as the runtime (presumably
            the number of time steps; the example input has shape
            (10, 1, 5) -- confirm the layout against the hxtorch docs).
        :return: Returns the output of the network, i.e. membrane traces of
            the readout neurons.
        """
        # Wrap the raw input spikes in an observables handle for the first
        # synapse layer.
        spikes_handle = hxsnn.LIFObservables(spikes=spikes)
        # Forward
        c_h = self.linear_h(spikes_handle)
        # Hidden-layer output is kept as an attribute so it can be inspected
        # after the forward pass.
        self.s_h = self.lif_h(c_h)
        c_o = self.linear_o(self.s_h)
        y_o = self.li_readout(c_o)
        # Execute the experiment: on BrainScaleS-2 hardware, or in simulation
        # when mock=True. The readout traces are read from y_o afterwards,
        # which suggests results are only available once run() has finished.
        hxsnn.run(self.exp, spikes.shape[0])
        return y_o.membrane_cadc
# Build the example network with its default settings (mock=True, i.e.
# simulated) and convert it to a NIR graph, using a random binary spike
# tensor of shape (10, 1, 5) as the example input.
snn = One_Layer_SNN()
input_sample = torch.randint(
    low=0, high=2, size=(10, 1, 5), dtype=torch.float32)
nir_graph = to_nir(snn, input_sample)
Import a NIR graph into hxtorch
import hxtorch.spiking as hxsnn
from hxtorch.spiking.utils.from_nir import ConversionConfig
import nir
import numpy as np
import torch
cfg = ConversionConfig()
nir_graph = nir.NIRGraph(
    nodes={
        # Input node carrying 3 channels.
        "input": nir.Input(input_type=np.array([3])),
        # Random dense weights mapping 3 inputs onto 5 neurons; the weight
        # matrix is laid out as (out_features, in_features).
        "linear": nir.Linear(weight=np.random.rand(5, 3)),
        # Five current-based LIF neurons sharing identical parameters.
        "lif": nir.CubaLIF(
            tau_mem=np.full(5, 0.02),
            tau_syn=np.full(5, 0.005),
            r=np.ones(5),
            v_leak=np.full(5, 0.1),
            v_reset=np.zeros(5),
            v_threshold=np.ones(5),
        ),
        "output": nir.Output(output_type=np.array([5])),
    },
    # Feed-forward chain: input -> linear -> lif -> output.
    edges=[
        ("input", "linear"),
        ("linear", "lif"),
        ("lif", "output"),
    ],
)
hxtorch_network = hxsnn.from_nir(nir_graph, cfg)

# Run the converted network on a random binary spike train with 10 time
# steps, a batch of 100 samples and 3 input channels. The input is passed
# as a dict keyed by the name of the graph's input node.
sample_input = {
    "input": torch.randint(
        low=0, high=2, size=(10, 100, 3), dtype=torch.float32),
}
output = hxtorch_network(sample_input)
The resulting hxtorch_network expects its input as a dict keyed by the
names of the graph's input nodes (here "input"). It likewise returns its
output as a dict.