hxtorch

hxtorch is a deep learning Python library used for numerical simulation, neuromorphic emulation, and training of spiking neural networks (SNNs). Built on top of PyTorch, it integrates the automatic differentiation and modular design of the PyTorch ecosystem with neuromorphic experiment execution, enabling hardware-in-the-loop training workflows on the neuromorphic hardware system BrainScaleS-2.

Supported Primitives in hxtorch

This library supports conversion of the following nodes to NIR:

  • Linear

  • CubaLI

  • CubaLIF

This library supports conversion of the following nodes from NIR:

  • Linear

  • CubaLI

  • CubaLIF

Export a NIR graph from hxtorch

import hxtorch.spiking as hxsnn
from hxtorch.spiking.utils.to_nir import SNN, to_nir
import torch

# Your SNN definition in hxtorch
class One_Layer_SNN(SNN):
    """
    Minimal hxtorch SNN used to demonstrate the export to NIR via ``to_nir``.

    Topology: 5 inputs -> Synapse -> 10 LIF neurons -> Synapse -> 2 LI
    readout neurons.
    """

    def __init__(self, dt: float = 1.0e-6, mock: bool = True,
                 device: torch.device = torch.device("cpu")):
        """
        Initialize an SNN with one hidden LIF layer and an LI readout.

        :param dt: Simulation time step in seconds.
        :param mock: Whether to simulate the neurons on CPU/GPU (True) or to
            emulate them on BrainScaleS-2 hardware (False).
        :param device: Device to use for simulation if mock is True.
        """
        super().__init__(dt, mock, device)

        # Layers. Every layer is registered with the experiment ``self.exp``,
        # which is presumably set up by the ``SNN`` base class (not shown here).
        self.linear_h = hxsnn.Synapse(in_features=5, out_features=10, experiment=self.exp)
        self.lif_h = hxsnn.LIF(size=10, experiment=self.exp)
        self.linear_o = hxsnn.Synapse(in_features=10, out_features=2, experiment=self.exp)
        self.li_readout = hxsnn.LI(size=2, experiment=self.exp)

    def forward(self, spikes: torch.Tensor) -> torch.Tensor:
        """
        Perform a forward pass for an SNN with one hidden LIF layer and an LI
        readout.

        :param spikes: torch.Tensor holding spikes as input.

        :return: Returns the output of the network, i.e. membrane traces of the
            readout neurons.
        """
        # Wrap the raw input spikes in an observables handle so they can be
        # passed to the first Synapse layer.
        spikes_handle = hxsnn.LIFObservables(spikes=spikes)

        # Forward: input -> hidden synapses -> LIF -> readout synapses -> LI.
        c_h = self.linear_h(spikes_handle)
        # Hidden-layer handle is stored on the module, presumably so the hidden
        # spikes can be inspected after the run.
        self.s_h = self.lif_h(c_h)
        c_o = self.linear_o(self.s_h)
        y_o = self.li_readout(c_o)

        # Execute the experiment: simulated when ``mock=True``, otherwise on
        # BrainScaleS-2.
        # NOTE(review): assumes ``spikes`` is time-major, i.e. ``shape[0]`` is
        # the number of time steps -- confirm against ``hxsnn.run``.
        hxsnn.run(self.exp, spikes.shape[0])

        return y_o.membrane_cadc

# Build the network with its defaults (dt = 1e-6 s, mock mode, CPU device).
snn = One_Layer_SNN()

# Random binary spike input handed to ``to_nir`` together with the model.
# Shape (10, 1, 5): presumably (time steps, batch, input features); the last
# axis has to match ``in_features=5`` of the first Synapse layer.
input_sample = torch.randint(
    low=0, high=2, size=(10, 1, 5), dtype=torch.float32
)

nir_graph = to_nir(snn, input_sample)

Import a NIR graph into hxtorch

import hxtorch.spiking as hxsnn
from hxtorch.spiking.utils.from_nir import ConversionConfig
import nir
import numpy as np
import torch

# Conversion settings; the library defaults are used here.
cfg = ConversionConfig()

# Sizes of the toy graph: 3 input channels feeding 5 CubaLIF neurons.
n_inputs, n_neurons = 3, 5

# NIR nodes, keyed by node name (the names are referenced by ``edges``).
nodes = {
    "input": nir.Input(input_type=np.array([n_inputs])),
    # Dense projection; the weight has shape (n_neurons, n_inputs), i.e.
    # (output size, input size).
    "linear": nir.Linear(weight=np.random.rand(n_neurons, n_inputs)),
    # Homogeneous population: every neuron gets the same parameter values.
    "lif": nir.CubaLIF(
        tau_mem=np.full(n_neurons, 0.02),
        tau_syn=np.full(n_neurons, 0.005),
        r=np.ones(n_neurons),
        v_leak=np.full(n_neurons, 0.1),
        v_reset=np.zeros(n_neurons),
        v_threshold=np.ones(n_neurons),
    ),
    "output": nir.Output(output_type=np.array([n_neurons])),
}

# Feed-forward chain: input -> linear -> lif -> output.
edges = [("input", "linear"), ("linear", "lif"), ("lif", "output")]

nir_graph = nir.NIRGraph(nodes=nodes, edges=edges)

# Convert the NIR description into an hxtorch network.
hxtorch_network = hxsnn.from_nir(nir_graph, cfg)

# Example usage (10 time steps, 100 samples, input size 3). The input is
# passed as a dict keyed by the NIR input node name.
sample_input = {
    "input": torch.randint(0, 2, (10, 100, n_inputs), dtype=torch.float32)
}
output = hxtorch_network(sample_input)

The resulting `hxtorch_network` expects its input as a dict keyed by the names of the NIR input nodes (here, `"input"`). The network also returns its output as a dict.