
snnTorch to Norse (Sim2Sim)#

NIR for deep Spiking Neural Networks - From snnTorch to Norse#

Written by Jason Eshraghian and Bernhard Vogginger#

What you will learn:

  • Learn how spiking neurons are implemented as a recurrent network

  • Download event-based data and train a spiking neural network with it

  • Export it to the neuromorphic intermediate representation

  • Import it to Norse and run inference

Install the latest PyPI distribution of snnTorch by clicking into the following cell and pressing Shift+Enter.

1. Imports#

!pip install snntorch --quiet
!pip install tonic --quiet
# imports
import snntorch as snn

import torch
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

2. Handling Event-based Data with Tonic#

2.1 PokerDVS Dataset#

The dataset used in this tutorial is POKERDVS by T. Serrano-Gotarredona and B. Linares-Barranco:

Serrano-Gotarredona, Teresa, and Bernabé Linares-Barranco. "Poker-DVS and MNIST-DVS. Their history, how they were made, and other details." Frontiers in neuroscience 9 (2015): 481.

It comprises four classes, one for each suit of a playing card deck: clubs, spades, hearts, and diamonds. The data consists of 131 poker pip symbols and was collected by flipping poker cards in front of a DVS128 camera.

import tonic

poker_train = tonic.datasets.POKERDVS(save_to='./data', train=True)
poker_test = tonic.datasets.POKERDVS(save_to='./data', train=False)

events, target = poker_train[0]
print(events)
tonic.utils.plot_event_grid(events)
import tonic.transforms as transforms
from tonic import DiskCachedDataset

# time_window
frame_transform = tonic.transforms.Compose([tonic.transforms.Denoise(filter_time=10000),
                                            tonic.transforms.ToFrame(
                                            sensor_size=tonic.datasets.POKERDVS.sensor_size,
                                            time_window=1000)
                                            ])

batch_size = 8
cached_trainset = DiskCachedDataset(poker_train, transform=frame_transform, cache_path='./cache/pokerdvs/train')
cached_testset = DiskCachedDataset(poker_test, transform=frame_transform, cache_path='./cache/pokerdvs/test')

train_loader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)
test_loader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)

data, labels = next(iter(train_loader))
print(data.size())
print(labels)

3. Define the SNN#

num_inputs = 35*35*2
num_hidden = 128
num_outputs = 4
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In the following code-block, note how the decay rate beta has two alternative definitions:

  • beta1 is set to a global decay rate for all neurons in the first spiking layer.

  • beta2 is randomly initialized to a vector of 4 different numbers. Each spiking neuron in the output layer (which not-so-coincidentally has 4 neurons, one per class) therefore has a unique, and random, decay rate.

# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        alpha1 = 0.5
        beta1 = 0.9 # global decay rate for all leaky neurons in layer 1
        beta2 = torch.rand((num_outputs), dtype = torch.float) # independent decay rate for each leaky neuron in layer 2: [0, 1)
        threshold2 = torch.ones_like(beta2) # threshold parameter must have the same shape as beta for NIR
        alpha2 = torch.ones_like(beta2)*0.9

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Synaptic(alpha=alpha1, beta=beta1) # not a learnable decay rate
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Synaptic(alpha=alpha2, beta=beta2, threshold=threshold2, learn_beta=True) # learnable decay rate

    def forward(self, x):
        syn1, mem1 = self.lif1.init_synaptic() # reset/init hidden states at t=0
        syn2, mem2 = self.lif2.init_synaptic() # reset/init hidden states at t=0

        spk2_rec = [] # record output spikes
        mem2_rec = [] # record output hidden states

        for step in range(x.size(0)): # loop over time
            cur1 = self.fc1(x[step].flatten(1))
            spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1)
            cur2 = self.fc2(spk1)
            spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)

            spk2_rec.append(spk2) # record spikes
            mem2_rec.append(mem2) # record membrane

        return torch.stack(spk2_rec), torch.stack(mem2_rec)

# Load the network onto CUDA if available
net = Net().to(device)

The code in the forward() function will only be called once the input argument x is explicitly passed into net.

  • fc1 applies a linear transformation to all input pixels from the POKERDVS dataset;

  • lif1 integrates the weighted input over time, emitting a spike if the threshold condition is met;

  • fc2 applies a linear transformation to the output spikes of lif1;

  • lif2 is another spiking neuron layer, integrating the weighted spikes over time.

A ‘biophysical’ interpretation is that fc1 and fc2 generate current injections that are fed into a set of \(128\) and \(4\) spiking neurons in lif1 and lif2, respectively.

Note: the number of spiking neurons is automatically inferred from the dimensions of the current injection.
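To make the roles of alpha and beta concrete, here is a minimal plain-Python sketch of a second-order leaky integrate-and-fire update with one decay rate per neuron. It is a simplified illustration, not snnTorch's implementation: the helper name synaptic_step, the reset-by-subtraction rule, and the input values are assumptions chosen for the example.

```python
def synaptic_step(inputs, syn, mem, alpha, beta, threshold):
    """Advance a layer of second-order LIF neurons by one time step.

    Hypothetical helper: every argument is a list with one entry per neuron.
    """
    new_syn, new_mem, spikes = [], [], []
    for x, s, m, a, b, thr in zip(inputs, syn, mem, alpha, beta, threshold):
        reset = thr if m > thr else 0.0  # subtract the threshold if the neuron fired last step
        s = a * s + x                    # synaptic current decays with rate alpha
        m = b * m + s - reset            # membrane decays with rate beta and integrates the current
        new_syn.append(s)
        new_mem.append(m)
        spikes.append(1 if m > thr else 0)
    return spikes, new_syn, new_mem

# Two neurons driven by the same constant input, differing only in beta
alpha = [0.5, 0.5]
beta = [0.9, 0.2]
threshold = [1.0, 1.0]
syn, mem = [0.0, 0.0], [0.0, 0.0]

spike_counts = [0, 0]
for _ in range(20):
    spk, syn, mem = synaptic_step([0.2, 0.2], syn, mem, alpha, beta, threshold)
    spike_counts = [c + s for c, s in zip(spike_counts, spk)]

print(spike_counts)
```

With these values the slowly decaying neuron (beta = 0.9) accumulates enough charge to fire, while the fast-decaying one (beta = 0.2) settles below threshold and stays silent. The list lengths also fix the number of neurons, which mirrors how the layer size follows from the shape of the input current.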

4. Training the SNN#

4.1 Accuracy Metric#

Below is a function that takes a batch of data, counts up all the spikes from each neuron (i.e., a rate code over the simulation time), and compares the index of the highest count with the actual target. If they match, then the network correctly predicted the target.

def measure_accuracy(model, dataloader):
  with torch.no_grad():
    model.eval()
    running_length = 0
    running_accuracy = 0

    for data, targets in iter(dataloader):
      data = data.to(device)
      targets = targets.to(device)

      # forward-pass
      spk_rec, _ = model(data)
      spike_count = spk_rec.sum(0) # batch x num_outputs
      _, max_spike = spike_count.max(1)

      # correct classes for one batch
      num_correct = (max_spike == targets).sum()

      # total accuracy
      running_length += len(targets)
      running_accuracy += num_correct

    accuracy = (running_accuracy / running_length)

    return accuracy.item()
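
The rate-code decoding inside measure_accuracy can be traced by hand on a toy spike record. The sketch below uses plain Python lists in place of tensors; the spike values and targets are made up for illustration.

```python
# Toy spike record indexed as [time][sample][output neuron]; values are illustrative.
spk_rec = [
    [[0, 1, 0, 0], [0, 0, 0, 1]],  # t = 0
    [[0, 1, 0, 0], [1, 0, 0, 0]],  # t = 1
    [[1, 0, 0, 0], [0, 0, 0, 1]],  # t = 2
]
targets = [1, 2]  # hypothetical labels for the two samples

num_samples = len(spk_rec[0])
num_outputs = len(spk_rec[0][0])

# Sum spikes over time for every sample and output neuron (like spk_rec.sum(0)).
spike_count = [
    [sum(frame[b][n] for frame in spk_rec) for n in range(num_outputs)]
    for b in range(num_samples)
]

# The predicted class is the neuron with the highest count (like spike_count.max(1)).
predictions = [counts.index(max(counts)) for counts in spike_count]

accuracy = sum(p == t for p, t in zip(predictions, targets)) / len(targets)
print(spike_count)  # [[1, 2, 0, 0], [1, 0, 0, 2]]
print(predictions)  # [1, 3]
print(accuracy)     # 0.5
```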

4.2 Loss Definition#

The nn.CrossEntropyLoss function in PyTorch automatically handles taking the softmax of the output layer as well as generating a loss at the output.

loss = nn.CrossEntropyLoss()
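
For class-index targets, cross-entropy is the negative log of the softmax probability assigned to the correct class, averaged over the batch. The plain-Python sketch below spells that out. Here cross_entropy is a hypothetical helper written for this illustration, not the PyTorch implementation, and the logits are example values standing in for the output layer's membrane potentials.

```python
import math

def cross_entropy(logits, targets):
    """Mean over the batch of -log(softmax(logits)[target])."""
    total = 0.0
    for row, target in zip(logits, targets):
        peak = max(row)  # subtract the max for numerical stability
        log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_sum_exp - row[target]
    return total / len(targets)

# With identical logits the softmax is uniform over 4 classes, so the loss is ln(4).
uniform = cross_entropy([[0.0, 0.0, 0.0, 0.0]], [2])
print(round(uniform, 4))  # 1.3863

# A large logit on the correct class drives the loss toward zero.
confident = cross_entropy([[0.0, 0.0, 10.0, 0.0]], [2])
print(confident < 0.001)  # True
```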

4.3 Optimizer#

Adam is a robust optimizer that performs well on recurrent networks, so let’s use that with a learning rate of \(5\times10^{-4}\).

optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

4.4 One Iteration of Training#

Take the first batch of data and load it onto CUDA if available.

data, targets = next(iter(train_loader))
data = data.to(device)
targets = targets.to(device)

Pass the input data to the network.

spk_rec, mem_rec = net(data)
print(mem_rec.size())

The recording of the membrane potential is taken across:

  • 29 time steps

  • 8 samples of data

  • 4 output neurons

We calculate the loss at every time step and sum the results:

\[\mathcal{L}_{Total-CE} = \sum_t\mathcal{L}_{CE}[t]\]
# initialize the total loss value
loss_val = torch.zeros((1), dtype=dtype, device=device)

# sum loss at every step
for step in range(mem_rec.size(0)):
  loss_val += loss(mem_rec[step], targets)

print(f"Training loss: {loss_val.item():.3f}")

The loss is quite large because it is summed over 29-ish time steps. The accuracy is also bad (roughly 25%, which is chance level for four classes) as the network is untrained:

measure_accuracy(net, train_loader)
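
As a rough sanity check on the size of the untrained loss: if the network assigned equal probability to all 4 classes, each time step would contribute ln(4) to the sum. The numbers below are a back-of-the-envelope estimate under that assumption, not measured output, and the real initial loss will differ because the untrained membrane potentials are not exactly uniform.

```python
import math

num_classes = 4
num_steps = 29  # matches the example batch above; padding makes this vary per batch

per_step_loss = math.log(num_classes)    # cross-entropy of a uniform guess
summed_loss = num_steps * per_step_loss  # loss summed over all time steps
chance_accuracy = 1 / num_classes

print(f"{per_step_loss:.3f}")  # 1.386
print(f"{summed_loss:.1f}")    # 40.2
print(chance_accuracy)         # 0.25
```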

A single weight update is applied to the network as follows:

# clear previously stored gradients
optimizer.zero_grad()

# calculate the gradients
loss_val.backward()

# weight update
optimizer.step()

Now, re-run the loss calculation and accuracy after a single iteration:

# calculate new network outputs using the same data
spk_rec, mem_rec = net(data)

# initialize the total loss value
loss_val = torch.zeros((1), dtype=dtype, device=device)

# sum loss at every step
for step in range(mem_rec.size(0)):
  loss_val += loss(mem_rec[step], targets)

print(f"Training loss: {loss_val.item():.3f}")
measure_accuracy(net, train_loader)

After only one iteration, the loss should have decreased and the accuracy should have increased. Note that the membrane potential is used to calculate the cross-entropy loss, while the spike count is used to measure accuracy. It is also possible to use the spike count in the loss (see Tutorial 6 in the snnTorch docs).
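
A minimal sketch of a spike-count loss: sum the output spikes over time and pass the counts to cross-entropy as logits. The tensor shapes mimic spk_rec from above, but the spike values and targets are random placeholders. This is one simple way to build such a loss, not necessarily the formulation used in Tutorial 6.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_steps, batch_size, num_outputs = 29, 8, 4

# Placeholder output spikes with shape [time, batch, neurons]
spk_rec = (torch.rand(num_steps, batch_size, num_outputs) > 0.7).float()
targets = torch.randint(0, num_outputs, (batch_size,))

loss_fn = nn.CrossEntropyLoss()

# Rate-coded loss: the spike count of each output neuron acts as its logit
spike_count = spk_rec.sum(0)  # [batch, neurons]
count_loss = loss_fn(spike_count, targets)

print(spike_count.shape)  # torch.Size([8, 4])
print(count_loss.item())
```

Unlike the membrane-potential loss, this variant yields a single loss value per batch instead of one term per time step.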

4.5 Training Loop#

Let’s combine everything into a training loop. We will train for 10 epochs (feel free to change num_epochs), exposing our network to each sample of data once per epoch.

num_epochs = 10
loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        spk_rec, mem_rec = net(data)

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(mem_rec.size(0)):
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data)

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step in range(test_mem.size(0)):
                test_loss += loss(test_mem[step], test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            # if counter % 50 == 0:
            print(f"Iteration: {counter} \t Accuracy: {measure_accuracy(net, test_loader)}")
            counter += 1
            iter_counter +=1

If this was your first time training an SNN, then congratulations. I’m proud of you and I always believed in you.

measure_accuracy(net, test_loader)

5. Export to NIR#

import nir
nir_model = snn.export_to_nir(net.cpu(), data.cpu())
nir.write("nir_model.nir", nir_model)

6. Run the model with Norse#

6.1 Import NIR model to Norse#

!pip install norse --quiet
import norse.torch as norse
nir_model = nir.read("nir_model.nir")
norse_model = norse.from_nir(nir_model, dt=0.0001) # dt is the simulation step width assumed by snntorch
norse_model

norse.from_nir(..) returns a GraphExecutor object. It is callable like an nn.Module.

In this case it contains:

  • Two Linear Layers

  • Two CubaLIF layers, each composed of a leaky-integrator and an LIF neuron

  • Identity layers for input and output

The order in which the layers are called can also be obtained:

print([elem.name for elem in norse_model.get_execution_order()])
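
The dt argument passed to norse.from_nir matters because a per-step decay rate such as beta only corresponds to a continuous time constant once a step width is fixed. The sketch below assumes a forward-Euler discretization, in which beta = 1 - dt/tau. That mapping and the helper names beta_to_tau and tau_to_beta are assumptions made for illustration, not code taken from snnTorch, NIR, or Norse.

```python
dt = 1e-4  # simulation step passed to norse.from_nir in this tutorial

def beta_to_tau(beta, dt):
    """Time constant implied by a per-step decay under forward Euler (assumed mapping)."""
    return dt / (1.0 - beta)

def tau_to_beta(tau, dt):
    """Inverse mapping: per-step decay for a given time constant."""
    return 1.0 - dt / tau

tau_mem = beta_to_tau(0.9, dt)  # beta1 from the network definition
print(f"{tau_mem:.4f} s")       # 0.0010 s

# Stepping the same time constant with a different dt gives a different decay,
# so a mismatched dt would change the neuron dynamics after import.
print(f"{tau_to_beta(tau_mem, 2e-4):.2f}")  # 0.80
```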

6.2 Run the model with a single batch of data#

The graph executor can run a single forward step. Let’s write a function to apply the data for all time steps…

def apply(data):
    """
    apply an input data batch to the norse model
    """
    state = None
    hid_rec = []
    out = []

    for i, t in enumerate(data):
        z, state = norse_model(t.flatten(1), state)
        out.append(z)
        hid_rec.append(state)
    spk_out = torch.stack(out)
    # hid_rec = torch.stack(hid_rec)
    return spk_out, hid_rec

Apply the model to a batch of data:

data, targets = next(iter(test_loader))

# data = data.to(device)

spk, hid = apply(data)

# count the number of spikes for each neuron and assess the winner
predictions = spk.sum(axis=0).argmax(axis=-1)
print(f"Predicted classes: {predictions}")
print(f"Actual classes:    {targets}")

6.3 Measure accuracy for test dataset#

def measure_accuracy2(model, dataloader):
  with torch.no_grad():
    # model.eval()  # not needed!
    running_length = 0
    running_accuracy = 0

    for data, targets in iter(dataloader):
      # data = data.to(device)
      # targets = targets.to(device)

      # forward-pass
      spk_rec, _ = model(data)
      spike_count = spk_rec.sum(0) # batch x num_outputs
      _, max_spike = spike_count.max(1)

      # correct classes for one batch
      num_correct = (max_spike == targets).sum()

      # total accuracy
      running_length += len(targets)
      running_accuracy += num_correct

    accuracy = (running_accuracy / running_length)

    return accuracy.item()
measure_accuracy2(apply, test_loader)
#@title Run this block for a good time
import requests
from IPython.display import Image, display

def display_image_from_url(url):
    response = requests.get(url, stream=True)
    display(Image(response.content))

url = "http://www.clker.com/cliparts/7/8/a/0/1498553633398980412very-nice-borat.med.png"
display_image_from_url(url)

Conclusion#

That covers how to train a spiking neural network, how to convert it into the neuromorphic intermediate representation, and how to load it into another PyTorch-based framework.

There are a lot of ways to alter this, e.g. for SNN training, by using different neuron models, surrogate gradients, learnable beta and threshold values, or modifying the fully-connected layers by replacing them with convolutions or whatever else you fancy.