snnTorch to Norse (Sim2Sim)#
NIR for deep Spiking Neural Networks - From snnTorch to Norse#
Written by Jason Eshraghian and Bernhard Vogginger#
What you will learn:
Learn how spiking neurons are implemented as a recurrent network
Download event-based data and train a spiking neural network with it
Export it to the neuromorphic intermediate representation
Import it to Norse and run inference
Install the latest PyPI distribution of snnTorch by clicking into the following cell and pressing Shift+Enter.
1. Imports#
!pip install snntorch --quiet
!pip install tonic --quiet
# imports
import snntorch as snn
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
2. Handling Event-based Data with Tonic#
2.1 PokerDVS Dataset#
The dataset used in this tutorial is POKERDVS by T. Serrano-Gotarredona and B. Linares-Barranco:
Serrano-Gotarredona, Teresa, and Bernabé Linares-Barranco. "Poker-DVS and MNIST-DVS. Their history, how they were made, and other details." Frontiers in neuroscience 9 (2015): 481.
It comprises four classes, each being a suit of a playing card deck: clubs, spades, hearts, and diamonds. The data consists of 131 poker pip symbols, and was collected by flipping poker cards in front of a DVS128 camera.
import tonic
poker_train = tonic.datasets.POKERDVS(save_to='./data', train=True)
poker_test = tonic.datasets.POKERDVS(save_to='./data', train=False)
events, target = poker_train[0]
print(events)
tonic.utils.plot_event_grid(events)
import tonic.transforms as transforms
from tonic import DiskCachedDataset
# denoise the events, then accumulate them into frames with a 1 ms time window
frame_transform = tonic.transforms.Compose([
    tonic.transforms.Denoise(filter_time=10000),
    tonic.transforms.ToFrame(
        sensor_size=tonic.datasets.POKERDVS.sensor_size,
        time_window=1000)
])
batch_size = 8
cached_trainset = DiskCachedDataset(poker_train, transform=frame_transform, cache_path='./cache/pokerdvs/train')
cached_testset = DiskCachedDataset(poker_test, transform=frame_transform, cache_path='./cache/pokerdvs/test')
train_loader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)
test_loader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)
data, labels = next(iter(train_loader))
print(data.size())
print(labels)
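As a quick sanity check, the layout of data should follow tonic's ToFrame transform combined with PadTensors(batch_first=False), i.e. time comes first. A minimal sketch (the exact number of time bins varies per recording and with padding):
# assumed layout: (num_time_bins, batch_size, polarities, height, width)
num_steps, bs, p, h, w = data.size()
print(f"{num_steps} time bins | batch of {bs} | {p} polarities | {h}x{w} pixels")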
3. Define the SNN#
num_inputs = 35*35*2
num_hidden = 128
num_outputs = 4
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
In the following code block, note how the decay rate beta has two alternative definitions:
beta1 is set to a global decay rate for all neurons in the first spiking layer.
beta2 is randomly initialized to a vector of 4 different numbers. Each spiking neuron in the output layer (which not-so-coincidentally has 4 neurons) therefore has a unique, and random, decay rate.
The synaptic current decay rate alpha is handled the same way: alpha1 is a single global value for the first layer, while alpha2 is a per-neuron vector for the output layer.
# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        alpha1 = 0.5
        beta1 = 0.9  # global decay rate for all leaky neurons in layer 1
        beta2 = torch.rand((num_outputs), dtype=torch.float)  # independent decay rate for each leaky neuron in layer 2: [0, 1)
        threshold2 = torch.ones_like(beta2)  # threshold parameter must have the same shape as beta for NIR
        alpha2 = torch.ones_like(beta2) * 0.9

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Synaptic(alpha=alpha1, beta=beta1)  # not a learnable decay rate
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Synaptic(alpha=alpha2, beta=beta2, threshold=threshold2, learn_beta=True)  # learnable decay rate

    def forward(self, x):
        syn1, mem1 = self.lif1.init_synaptic()  # reset/init hidden states at t=0
        syn2, mem2 = self.lif2.init_synaptic()  # reset/init hidden states at t=0

        spk2_rec = []  # record output spikes
        mem2_rec = []  # record output hidden states

        for step in range(x.size(0)):  # loop over time
            cur1 = self.fc1(x[step].flatten(1))
            spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1)
            cur2 = self.fc2(spk1)
            spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)

            spk2_rec.append(spk2)  # record spikes
            mem2_rec.append(mem2)  # record membrane

        return torch.stack(spk2_rec), torch.stack(mem2_rec)
# Load the network onto CUDA if available
net = Net().to(device)
The code in the forward() function will only be called once the input argument x is explicitly passed into net.
fc1 applies a linear transformation to all input pixels from the POKERDVS dataset;
lif1 integrates the weighted input over time, emitting a spike if the threshold condition is met;
fc2 applies a linear transformation to the output spikes of lif1;
lif2 is another spiking neuron layer, integrating the weighted spikes over time.
A ‘biophysical’ interpretation is that fc1 and fc2 generate current injections that are fed into a set of \(128\) and \(4\) spiking neurons in lif1 and lif2, respectively.
Note: the number of spiking neurons is automatically inferred from the dimensionality of the current injection.
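To see this size inference in action, here is a minimal sketch (lif_demo and the random current below are purely illustrative): feeding a current of shape (batch, 128) into a Synaptic neuron yields spikes of the same shape, so the layer effectively contains 128 neurons.
# illustrative only: a Synaptic layer has no fixed size of its own
lif_demo = snn.Synaptic(alpha=0.5, beta=0.9)
syn_demo, mem_demo = lif_demo.init_synaptic()
cur_demo = torch.rand(8, 128)  # batch of 8, 128 current injections
spk_demo, syn_demo, mem_demo = lif_demo(cur_demo, syn_demo, mem_demo)
print(spk_demo.shape)  # torch.Size([8, 128])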
4. Training the SNN#
4.1 Accuracy Metric#
Below is a function that takes a batch of data, counts up all the spikes from each neuron (i.e., a rate code over the simulation time), and compares the index of the highest count with the actual target. If they match, then the network correctly predicted the target.
def measure_accuracy(model, dataloader):
    with torch.no_grad():
        model.eval()
        running_length = 0
        running_accuracy = 0

        for data, targets in iter(dataloader):
            data = data.to(device)
            targets = targets.to(device)

            # forward-pass
            spk_rec, _ = model(data)
            spike_count = spk_rec.sum(0)  # batch x num_outputs
            _, max_spike = spike_count.max(1)

            # correct classes for one batch
            num_correct = (max_spike == targets).sum()

            # total accuracy
            running_length += len(targets)
            running_accuracy += num_correct

        accuracy = (running_accuracy / running_length)

    return accuracy.item()
4.2 Loss Definition#
The nn.CrossEntropyLoss function in PyTorch automatically handles taking the softmax of the output layer as well as generating a loss at the output.
loss = nn.CrossEntropyLoss()
4.3 Optimizer#
Adam is a robust optimizer that performs well on recurrent networks, so let’s use that with a learning rate of \(5\times10^{-4}\).
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
4.4 One Iteration of Training#
Take the first batch of data and load it onto CUDA if available.
data, targets = next(iter(train_loader))
data = data.to(device)
targets = targets.to(device)
Pass the input data to the network.
spk_rec, mem_rec = net(data)
print(mem_rec.size())
The recording of the membrane potential is taken across:
29 time steps
8 samples of data
4 output neurons
We wish to calculate the loss at every time step, and sum these up together:
# initialize the total loss value
loss_val = torch.zeros((1), dtype=dtype, device=device)
# sum loss at every step
for step in range(mem_rec.size(0)):
    loss_val += loss(mem_rec[step], targets)
print(f"Training loss: {loss_val.item():.3f}")
The loss is quite large, because it is summed over 29-ish time steps. The accuracy is also poor (it should be roughly 25%) as the network is untrained:
measure_accuracy(net, train_loader)
A single weight update is applied to the network as follows:
# clear previously stored gradients
optimizer.zero_grad()
# calculate the gradients
loss_val.backward()
# weight update
optimizer.step()
Now, re-run the loss calculation and accuracy after a single iteration:
# calculate new network outputs using the same data
spk_rec, mem_rec = net(data)
# initialize the total loss value
loss_val = torch.zeros((1), dtype=dtype, device=device)
# sum loss at every step
for step in range(mem_rec.size(0)):
    loss_val += loss(mem_rec[step], targets)
print(f"Training loss: {loss_val.item():.3f}")
measure_accuracy(net, train_loader)
After only one iteration, the loss should have decreased and accuracy should have increased. Note how the membrane potential is used to calculate the cross-entropy loss, while the spike count is used as the measure of accuracy. It is also possible to use the spike count in the loss (see Tutorial 6 in the snnTorch docs).
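As a rough sketch of that spike-count alternative (assuming snntorch.functional provides ce_count_loss as used in Tutorial 6; check the docs for the exact signature), the loss can be computed directly from spk_rec:
import snntorch.functional as SF

# hypothetical spike-count loss: cross entropy applied to the summed spike counts per class
count_loss = SF.ce_count_loss()
loss_val_count = count_loss(spk_rec, targets)  # spk_rec: (time, batch, num_outputs)
print(f"Spike-count loss: {loss_val_count.item():.3f}")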
4.5 Training Loop#
Let’s combine everything into a training loop. We will train for 10 epochs (feel free to change num_epochs), exposing our network to each sample of data once per epoch.
num_epochs = 10
loss_hist = []
test_loss_hist = []
counter = 0
# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        spk_rec, mem_rec = net(data)

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(mem_rec.size(0)):
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data)

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step in range(test_mem.size(0)):
                test_loss += loss(test_mem[step], test_targets)
            test_loss_hist.append(test_loss.item())

        # Print train/test loss/accuracy
        # if counter % 50 == 0:
        print(f"Iteration: {counter} \t Accuracy: {measure_accuracy(net, test_loader)}")
        counter += 1
        iter_counter += 1
If this was your first time training an SNN, then congratulations. I’m proud of you and I always believed in you.
measure_accuracy(net, test_loader)
5. Export to NIR#
import nir
nir_model = snn.export_to_nir(net.cpu(), data.cpu())
nir.write("nir_model.nir", nir_model)
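Before switching frameworks, it can be helpful to peek at the exported graph. Assuming the standard NIRGraph structure (a nodes dict and an edges list; see the NIR documentation if this differs), something like the following should list the layers and their connections:
# assumes nir_model exposes .nodes (name -> node) and .edges (list of (src, dst) pairs)
print(nir_model.nodes.keys())
print(nir_model.edges)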
6. Run the model with Norse#
6.1 Import NIR model to Norse#
!pip install norse --quiet
import norse.torch as norse
nir_model = nir.read("nir_model.nir")
norse_model = norse.from_nir(nir_model, dt=0.0001) # dt is the simulation step width assumed by snntorch
norse_model
norse.from_nir(...) returns a GraphExecutor object. It is callable like an nn.Module.
In this case it contains:
Two Linear layers
Two CubaLIF layers, each composed of a leaky integrator and an LIF neuron
Identity layers for input and output
The order in which the layers are called can also be obtained:
print([elem.name for elem in norse_model.get_execution_order()])
6.2. Run the model with a single batch of data#
The graph executor can run a single forward step. Let’s write a function to apply the data for all time steps…
def apply(data):
    """
    apply an input data batch to the norse model
    """
    state = None
    hid_rec = []
    out = []

    for i, t in enumerate(data):
        z, state = norse_model(t.flatten(1), state)
        out.append(z)
        hid_rec.append(state)

    spk_out = torch.stack(out)
    # hid_rec = torch.stack(hid_rec)
    return spk_out, hid_rec
Apply it to a batch of data:
data, targets = next(iter(test_loader))
# data = data.to(device)
spk, hid = apply(data)
# count the number of spikes for each neuron and assess the winner
predictions = spk.sum(axis=0).argmax(axis=-1)
print(f"Predicted classes: {predictions}")
print(f"Actual classes: {targets}")
6.3 Measure accuracy for test dataset#
def measure_accuracy2(model, dataloader):
    with torch.no_grad():
        # model.eval() # not needed!
        running_length = 0
        running_accuracy = 0

        for data, targets in iter(dataloader):
            # data = data.to(device)
            # targets = targets.to(device)

            # forward-pass
            spk_rec, _ = model(data)
            spike_count = spk_rec.sum(0)  # batch x num_outputs
            _, max_spike = spike_count.max(1)

            # correct classes for one batch
            num_correct = (max_spike == targets).sum()

            # total accuracy
            running_length += len(targets)
            running_accuracy += num_correct

        accuracy = (running_accuracy / running_length)

    return accuracy.item()
measure_accuracy2(apply, test_loader)
#@title Run this block for a good time
import requests
from IPython.display import Image, display
def display_image_from_url(url):
    response = requests.get(url, stream=True)
    display(Image(response.content))
url = "http://www.clker.com/cliparts/7/8/a/0/1498553633398980412very-nice-borat.med.png"
display_image_from_url(url)
Conclusion#
That covers how to train a spiking neural network, how to convert it into the neuromorphic intermediate representation, and how to load it into another PyTorch-based framework.
There are many ways to alter this, e.g. for SNN training you could use different neuron models, surrogate gradients, or learnable beta and threshold values, or you could modify the fully-connected layers by replacing them with convolutions or whatever else you fancy.