SpiNNaker2

SpiNNaker2

SpiNNaker2 is a neuromorphic chip based around a grid of ARM Cortex-M4F processors that are tightly coupled with accelerators and a network-on-chip optimized for, but not limited to, the transmission of spikes.

Running this example requires the py-spinnaker2 library (currently its py-spinnaker2_nir branch) and access to SpiNNaker2 hardware. Alternatively, the example can be run on the Brian2 simulator backend, as shown in the code below.

Create a graph

import numpy as np

import nir

# Build a minimal NIR graph with the chain
#   input (3) -> affine (3 -> 2) -> LIF (2 neurons) -> output (2)
# and save it to disk so the conversion step can re-load it from file.
n_in, n_out = 3, 2

nir_model = nir.NIRGraph(
    nodes={
        "in": nir.Input(input_type=np.array([n_in])),
        "affine": nir.Affine(
            # One row per output neuron, one column per input: shape (n_out, n_in).
            weight=np.array([[8, 2, 10], [14, 3, 14]]),
            bias=np.array([0, 8]),
        ),
        # Every LIF neuron shares the same parameters.
        "lif": nir.LIF(
            tau=np.full(n_out, 4),
            r=np.full(n_out, 1.25),
            v_leak=np.full(n_out, 0.5),
            v_threshold=np.full(n_out, 5),
        ),
        "out": nir.Output(output_type=np.array([n_out])),
    },
    edges=[("in", "affine"), ("affine", "lif"), ("lif", "out")],
)
nir.write("nir_model.hdf5", nir_model)
print(nir_model)

Import to SpiNNaker2

import matplotlib.pyplot as plt
import numpy as np
from spinnaker2 import hardware, helpers, s2_nir, brian2_sim

import nir

# Seed NumPy's global RNG so the random input spikes drawn with
# np.random in this script are reproducible from run to run.
np.random.seed(123)

# Number of time steps of random input to generate.
timesteps = 50

# load NIR model
nir_model = nir.read("nir_model.hdf5")

# make sure all nodes have necessary shape information
# as not all shapes might be stored in the NIR graph
nir_model.infer_types()

print(nir_model)

# Configuration for converting NIR graph to SpiNNaker2
conversion_cfg = s2_nir.ConversionConfig()
# Quantities to record; they are read back later via
# outp[0].get_voltages() and outp[0].get_spikes().
conversion_cfg.output_record = ["v", "spikes"]
# Simulation step size (units not shown here — TODO confirm against s2_nir docs).
conversion_cfg.dt = 1
# Connection delay for converted projections (presumably in time steps — confirm).
conversion_cfg.conn_delay = 0
conversion_cfg.scale_weights = True # Scale weights to dynamic range on chip
conversion_cfg.reset = s2_nir.ResetMethod.ZERO # Reset voltage to zero at spike
conversion_cfg.integrator = s2_nir.IntegratorMethod.FORWARD # Euler-Forward

# net: the network handed to the backend's run();
# inp / outp: indexable collections of the input and output populations
# (only element [0] of each is used in this example).
net, inp, outp = s2_nir.from_nir(nir_model, conversion_cfg)


# Create some input spikes: a random binary raster in which an input neuron
# fires at a time step whenever its standard-normal sample exceeds 1
# (about 16% of steps).
input_size = inp[0].size
input_data = (np.random.randn(input_size, timesteps) > 1) * 1
print(input_data)

# Map each input neuron index to the list of time steps at which it spikes.
input_spikes = {
    neuron: np.flatnonzero(row).tolist()
    for neuron, row in enumerate(input_data)
}

inp[0].params = input_spikes


# Load up hardware + run
# Use this to run on a SpiNNaker2 board, enter its correct IP address:
#hw = hardware.SpiNNaker2Chip(eth_ip="192.168.1.x")
# To instead run on the Brian2 simulator:
hw = brian2_sim.Brian2Backend()

# NOTE(review): the run is 3 steps longer than the generated input, presumably
# so spikes from the last input steps can still reach the outputs — confirm.
# The extended value is also used for the plot x-limits below.
timesteps += 3
hw.run(net, timesteps)


# get results and plot
spike_times = outp[0].get_spikes()
voltages = outp[0].get_voltages()

fig, axs = plt.subplots(1, 3, sharex=True)

# Panel 0: input spike raster, one row per input neuron.
indices, times = helpers.spike_times_dict_to_arrays(input_spikes)
axs[0].plot(times, indices, ".")
axs[0].set_xlim(0, timesteps)
axs[0].set_ylim(-0.5, len(input_spikes) - 0.5)
axs[0].set_xlabel("time step")
axs[0].set_ylabel("neuron")

# Panel 1: output spike raster, one row per output neuron.
indices, times = helpers.spike_times_dict_to_arrays(spike_times)
axs[1].plot(times, indices, ".")
axs[1].set_xlim(0, timesteps)
axs[1].set_ylim(-0.5, outp[0].size - 0.5)
axs[1].set_xlabel("time step")
axs[1].set_ylabel("neuron")

# Panel 2: membrane-potential trace of each recorded output neuron.
for neuron, trace in voltages.items():
    axs[2].plot(trace, label=neuron)

# Configure the axis once, after all traces are drawn. Axes.grid() called with
# no arguments *toggles* the grid, so pass True explicitly; toggling once per
# neuron would hide the grid whenever the neuron count is even.
axs[2].set_xlabel("time step")
axs[2].set_ylabel("membrane potential")
axs[2].set_xlim(0, timesteps)
axs[2].grid(True)
if voltages:  # legend() warns when no labelled artists exist
    axs[2].legend()

plt.show()