{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "view-in-github" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "metadata": { "id": "8R08u5xKHBv6" }, "source": [ "# snnTorch to Norse (Sim2Sim)\n", "\n", "## NIR for deep Spiking Neural Networks - From snnTorch to Norse\n", "### Written by Jason Eshraghian and Bernhard Vogginger\n", "\n", "What you will learn:\n", "* Learn how spiking neurons are implemented as a recurrent network\n", "* Download event-based data and train a spiking neural network with it\n", "* Export it to the neuromorphic intermediate representation\n", "* Import it to Norse and run inference\n", "\n", "Install the latest PyPi distribution of snnTorch by clicking into the following cell and pressing `Shift+Enter`.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "H_3jC4pJ8xzO" }, "source": [ "## 1. Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l5rPCjp5OIhf" }, "outputs": [], "source": [ "!pip install snntorch --quiet\n", "!pip install tonic --quiet" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kh4NW-mc0JaY" }, "outputs": [], "source": [ "# imports\n", "import snntorch as snn\n", "\n", "import torch\n", "import torch.nn as nn\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms" ] }, { "cell_type": "markdown", "metadata": { "id": "4lvCNRHbOGW7" }, "source": [ "# 2. Handling Event-based Data with Tonic" ] }, { "cell_type": "markdown", "metadata": { "id": "nhJunMSEOVMk" }, "source": [ "## 2.1 PokerDVS Dataset\n", "\n", "The dataset used in this tutorial is POKERDVS by T. Serrano-Gotarredona and B. Linares-Barranco:\n", "\n", "```\n", "Serrano-Gotarredona, Teresa, and Bernabé Linares-Barranco. \"Poker-DVS and MNIST-DVS. Their history, how they were made, and other details.\" Frontiers in neuroscience 9 (2015): 481.\n", "```\n", "\n", "It is comprised of four classes, each being a suite of a playing card deck: clubs, spades, hearts, and diamonds. The data consists of 131 poker pip symbols, and was collected by flipping poker cards in front of a DVS128 camera." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mM4cKrtmOgPS" }, "outputs": [], "source": [ "import tonic\n", "\n", "poker_train = tonic.datasets.POKERDVS(save_to='./data', train=True)\n", "poker_test = tonic.datasets.POKERDVS(save_to='./data', train=False)\n", "\n", "events, target = poker_train[0]\n", "print(events)\n", "tonic.utils.plot_event_grid(events)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cGgrBCjqOhpo" }, "outputs": [], "source": [ "import tonic.transforms as transforms\n", "from tonic import DiskCachedDataset\n", "\n", "# time_window\n", "frame_transform = tonic.transforms.Compose([tonic.transforms.Denoise(filter_time=10000),\n", " tonic.transforms.ToFrame(\n", " sensor_size=tonic.datasets.POKERDVS.sensor_size,\n", " time_window=1000)\n", " ])\n", "\n", "batch_size = 8\n", "cached_trainset = DiskCachedDataset(poker_train, transform=frame_transform, cache_path='./cache/pokerdvs/train')\n", "cached_testset = DiskCachedDataset(poker_test, transform=frame_transform, cache_path='./cache/pokerdvs/test')\n", "\n", "train_loader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)\n", "test_loader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)\n", "\n", "data, labels = next(iter(train_loader))\n", "print(data.size())\n", "print(labels)" ] }, { "cell_type": "markdown", "metadata": { "id": "IVacJHwO6t4M" }, "source": [ "## 3. Define the SNN" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SmclKI7d62Oc" }, "outputs": [], "source": [ "num_inputs = 35*35*2\n", "num_hidden = 128\n", "num_outputs = 4" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CWdyYErq7zXe" }, "outputs": [], "source": [ "dtype = torch.float\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "markdown", "metadata": { "id": "hH8nHOsZ9Hxo" }, "source": [ "In the following code-block, note how the decay rate `beta` has two alternative definitions:\n", "* `beta1` is set to a global decay rate for all neurons in the first spiking layer.\n", "* `beta2` is randomly initialized to a vector of 10 different numbers. Each spiking neuron in the output layer (which not-so-coincidentally has 10 neurons) therefore has a unique, and random, decay rate." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7rEX7-U687zV" }, "outputs": [], "source": [ "# Define Network\n", "class Net(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", "\n", " alpha1 = 0.5\n", " beta1 = 0.9 # global decay rate for all leaky neurons in layer 1\n", " beta2 = torch.rand((num_outputs), dtype = torch.float) # independent decay rate for each leaky neuron in layer 2: [0, 1)\n", " threshold2 = torch.ones_like(beta2) # threshold parameter must have the same shape as beta for NIR\n", " alpha2 = torch.ones_like(beta2)*0.9\n", "\n", " # Initialize layers\n", " self.fc1 = nn.Linear(num_inputs, num_hidden)\n", " self.lif1 = snn.Synaptic(alpha=alpha1, beta=beta1) # not a learnable decay rate\n", " self.fc2 = nn.Linear(num_hidden, num_outputs)\n", " self.lif2 = snn.Synaptic(alpha=alpha2, beta=beta2, threshold=threshold2, learn_beta=True) # learnable decay rate\n", "\n", " def forward(self, x):\n", " syn1, mem1 = self.lif1.init_synaptic() # reset/init hidden states at t=0\n", " syn2, mem2 = self.lif2.init_synaptic() # reset/init hidden states at t=0\n", "\n", " spk2_rec = [] # record output spikes\n", " mem2_rec = [] # record output hidden states\n", "\n", " for step in range(x.size(0)): # loop over time\n", " cur1 = self.fc1(x[step].flatten(1))\n", " spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1)\n", " cur2 = self.fc2(spk1)\n", " spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)\n", "\n", " spk2_rec.append(spk2) # record spikes\n", " mem2_rec.append(mem2) # record membrane\n", "\n", " return torch.stack(spk2_rec), torch.stack(mem2_rec)\n", "\n", "# Load the network onto CUDA if available\n", "net = Net().to(device)" ] }, { "cell_type": "markdown", "metadata": { "id": "3le9vSo29ocU" }, "source": [ "The code in the `forward()` function will only be called once the input argument `x` is explicitly passed into `net`.\n", "\n", "* `fc1` applies a linear transformation to all input pixels from the POKERDVS dataset;\n", "* `lif1` integrates the weighted input over time, emitting a spike if the threshold condition is met;\n", "* `fc2` applies a linear transformation to the output spikes of `lif1`;\n", "* `lif2` is another spiking neuron layer, integrating the weighted spikes over time.\n", "\n", "A 'biophysical' interpretation is that `fc1` and `fc2` generate current injections that are fed into a set of $128$ and $10$ spiking neurons in `lif1` and `lif2`, respectively.\n", "\n", "> Note: the number of spiking neurons is automatically inferred by the dimensionality of the dimensions of the current injection value." ] }, { "cell_type": "markdown", "metadata": { "id": "L-ekyCxt78qB" }, "source": [ "# 4. Training the **SNN**" ] }, { "cell_type": "markdown", "metadata": { "id": "1zKiDeSM1onU" }, "source": [ "## 4.1 Accuracy Metric\n", "Below is a function that takes a batch of data, counts up all the spikes from each neuron (i.e., a rate code over the simulation time), and compares the index of the highest count with the actual target. If they match, then the network correctly predicted the target." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eytnSVQ8OtBj" }, "outputs": [], "source": [ "def measure_accuracy(model, dataloader):\n", " with torch.no_grad():\n", " model.eval()\n", " running_length = 0\n", " running_accuracy = 0\n", "\n", " for data, targets in iter(dataloader):\n", " data = data.to(device)\n", " targets = targets.to(device)\n", "\n", " # forward-pass\n", " spk_rec, _ = model(data)\n", " spike_count = spk_rec.sum(0) # batch x num_outputs\n", " _, max_spike = spike_count.max(1)\n", "\n", " # correct classes for one batch\n", " num_correct = (max_spike == targets).sum()\n", "\n", " # total accuracy\n", " running_length += len(targets)\n", " running_accuracy += num_correct\n", "\n", " accuracy = (running_accuracy / running_length)\n", "\n", " return accuracy.item()\n" ] }, { "cell_type": "markdown", "metadata": { "id": "k1Fzf-3S3WGp" }, "source": [ "## 4.2 Loss Definition\n", "The `nn.CrossEntropyLoss` function in PyTorch automatically handles taking the softmax of the output layer as well as generating a loss at the output." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mK0xRQQp3ypC" }, "outputs": [], "source": [ "loss = nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", "metadata": { "id": "wr4TTQ204ubI" }, "source": [ "## 4.3 Optimizer\n", "Adam is a robust optimizer that performs well on recurrent networks, so let's use that with a learning rate of $5\\times10^{-4}$." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "muBPATWo40pI" }, "outputs": [], "source": [ "optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))" ] }, { "cell_type": "markdown", "metadata": { "id": "IKTBRtsX-ZNy" }, "source": [ "## 4.4 One Iteration of Training\n", "Take the first batch of data and load it onto CUDA if available." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e323axy244wa" }, "outputs": [], "source": [ "data, targets = next(iter(train_loader))\n", "data = data.to(device)\n", "targets = targets.to(device)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZZ1zApYM-cRk" }, "source": [ "Pass the input data to the network." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J-w2w7ME-d2o" }, "outputs": [], "source": [ "spk_rec, mem_rec = net(data)\n", "print(mem_rec.size())" ] }, { "cell_type": "markdown", "metadata": { "id": "AIikA-Rh-fSa" }, "source": [ "The recording of the membrane potential is taken across:\n", "* 29 time steps\n", "* 8 samples of data\n", "* 4 output neurons\n", "\n", "We wish to calculate the loss at every time step, and sum these up together:\n", "\n", "\n", "$$\\mathcal{L}_{Total-CE} = \\sum_t\\mathcal{L}_{CE}[t]$$" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WHlczZEb-vqj" }, "outputs": [], "source": [ "# initialize the total loss value\n", "loss_val = torch.zeros((1), dtype=dtype, device=device)\n", "\n", "# sum loss at every step\n", "for step in range(mem_rec.size(0)):\n", " loss_val += loss(mem_rec[step], targets)\n", "\n", "print(f\"Training loss: {loss_val.item():.3f}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "InrjD9n--xCt" }, "source": [ "The loss is quite large, because it is summed over 29-ish time steps. 
The accuracy is also bad (it should be roughly around 25%) as the network is untrained:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LaDDSPLK9u7p" }, "outputs": [], "source": [ "measure_accuracy(net, train_loader)" ] }, { "cell_type": "markdown", "metadata": { "id": "0dvil_l3-09y" }, "source": [ "A single weight update is applied to the network as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7Okzx-3p-0Mt" }, "outputs": [], "source": [ "# clear previously stored gradients\n", "optimizer.zero_grad()\n", "\n", "# calculate the gradients\n", "loss_val.backward()\n", "\n", "# weight update\n", "optimizer.step()" ] }, { "cell_type": "markdown", "metadata": { "id": "iHbLCKsa-3Ao" }, "source": [ "Now, re-run the loss calculation and accuracy after a single iteration:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Cvx9Ux2--4Aw" }, "outputs": [], "source": [ "# calculate new network outputs using the same data\n", "spk_rec, mem_rec = net(data)\n", "\n", "# initialize the total loss value\n", "loss_val = torch.zeros((1), dtype=dtype, device=device)\n", "\n", "# sum loss at every step\n", "for step in range(mem_rec.size(0)):\n", " loss_val += loss(mem_rec[step], targets)\n", "\n", "print(f\"Training loss: {loss_val.item():.3f}\")\n", "measure_accuracy(net, train_loader)" ] }, { "cell_type": "markdown", "metadata": { "id": "7T3MCnxo-5fE" }, "source": [ "After only one iteration, the loss should have decreased and accuracy should have increased. Note how membrane potential is used to calculate the cross entropy loss, and spike count is used for the measure of accuracy. It is also possible to use the spike count in the loss ([see Tutorial 6 in the snnTorch docs](https://snntorch.readthedocs.io/en/latest/tutorials/index.html))" ] }, { "cell_type": "markdown", "metadata": { "id": "hDrQ7oMM--dU" }, "source": [ "## 4.5 Training Loop\n", "\n", "Let's combine everything into a training loop. We will train for one epoch (though feel free to increase `num_epochs`), exposing our network to each sample of data once." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UCnjeX4U_AVD" }, "outputs": [], "source": [ "num_epochs = 10\n", "loss_hist = []\n", "test_loss_hist = []\n", "counter = 0\n", "\n", "# Outer training loop\n", "for epoch in range(num_epochs):\n", " iter_counter = 0\n", " train_batch = iter(train_loader)\n", "\n", " # Minibatch training loop\n", " for data, targets in train_batch:\n", " data = data.to(device)\n", " targets = targets.to(device)\n", "\n", " # forward pass\n", " net.train()\n", " spk_rec, mem_rec = net(data)\n", "\n", " # initialize the loss & sum over time\n", " loss_val = torch.zeros((1), dtype=dtype, device=device)\n", " for step in range(mem_rec.size(0)):\n", " loss_val += loss(mem_rec[step], targets)\n", "\n", " # Gradient calculation + weight update\n", " optimizer.zero_grad()\n", " loss_val.backward()\n", " optimizer.step()\n", "\n", " # Store loss history for future plotting\n", " loss_hist.append(loss_val.item())\n", "\n", " # Test set\n", " with torch.no_grad():\n", " net.eval()\n", " test_data, test_targets = next(iter(test_loader))\n", " test_data = test_data.to(device)\n", " test_targets = test_targets.to(device)\n", "\n", " # Test set forward pass\n", " test_spk, test_mem = net(test_data)\n", "\n", " # Test set loss\n", " test_loss = torch.zeros((1), dtype=dtype, device=device)\n", " for step in range(test_mem.size(0)):\n", " test_loss += loss(test_mem[step], test_targets)\n", " test_loss_hist.append(test_loss.item())\n", "\n", " # Print train/test loss/accuracy\n", " # if counter % 50 == 0:\n", " print(f\"Iteration: {counter} \\t Accuracy: {measure_accuracy(net, test_loader)}\")\n", " counter += 1\n", " iter_counter +=1" ] }, { "cell_type": "markdown", "metadata": { "id": "MWXOCBNI_B47" }, "source": [ "If this was your first time training an SNN, then congratulations. I'm proud of you and I always believed in you." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4pAJ9_WsFeSB" }, "outputs": [], "source": [ "measure_accuracy(net, test_loader)" ] }, { "cell_type": "markdown", "metadata": { "id": "7BZpywK3_LZL" }, "source": [ "# 5. Export to NIR" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yJ2ZcBH-ARmt" }, "outputs": [], "source": [ "import nir" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qvtZS8kW_euK" }, "outputs": [], "source": [ "nir_model = snn.export_to_nir(net.cpu(), data.cpu())\n", "nir.write(\"nir_model.nir\", nir_model)" ] }, { "cell_type": "markdown", "metadata": { "id": "VFq1hIbkFeSC" }, "source": [ "# 6. 
Run the model with Norse" ] }, { "cell_type": "markdown", "metadata": { "id": "0-xAVat6FeSD" }, "source": [ "## 6.1 Import the NIR model into Norse" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "doW_7eauFeSD" }, "outputs": [], "source": [ "!pip install norse --quiet" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_ITH3yqMFeSD" }, "outputs": [], "source": [ "import norse.torch as norse" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "la-3vqRmFeSD" }, "outputs": [], "source": [ "nir_model = nir.read(\"nir_model.nir\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iq0JrUPJFeSD" }, "outputs": [], "source": [ "norse_model = norse.from_nir(nir_model, dt=0.0001)  # dt is the simulation time step assumed by snnTorch" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zI5ZfDJuKqCf" }, "outputs": [], "source": [ "norse_model" ] }, { "cell_type": "markdown", "metadata": { "id": "S56YqwRYFeSD" }, "source": [ "`norse.from_nir(..)` returns a `GraphExecutor` object. It is callable like an `nn.Module`.\n", "\n", "In this case it contains:\n", "- two linear layers\n", "- two CuBa LIF layers, each composed of a leaky integrator and a LIF neuron\n", "- identity layers for input and output\n", "\n", "The order in which the layers are called can also be obtained:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e6LvAiH6MbnX" }, "outputs": [], "source": [ "print([elem.name for elem in norse_model.get_execution_order()])" ] }, { "cell_type": "markdown", "metadata": { "id": "r6HSJq7LFeSD" }, "source": [ "## 6.2 Run the model with a single batch of data\n", "\n", "The graph executor runs one forward step at a time. Below, we first sanity-check a single step, then write a function that applies the data across all time steps."
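] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A hedged single-step check (the all-zero input and the expected output shape are illustrative assumptions): like Norse modules, the executor takes `(input, state)` and returns `(output, new_state)`, with `state=None` at the first step." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# one forward step of the GraphExecutor with a dummy input (hypothetical check)\n", "x0 = torch.zeros(1, num_inputs)  # a single time step, batch of 1\n", "z0, s0 = norse_model(x0, None)   # state=None initializes the hidden state\n", "print(z0.shape)  # expected: torch.Size([1, 4])"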
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pnqTzuWvFeSD" }, "outputs": [], "source": [ "def apply(data):\n", " \"\"\"\n", " apply an input data batch to the norse model\n", " \"\"\"\n", " state = None\n", " hid_rec = []\n", " out = []\n", "\n", " for i, t in enumerate(data):\n", " z, state = norse_model(t.flatten(1), state)\n", " out.append(z)\n", " hid_rec.append(state)\n", " spk_out = torch.stack(out)\n", " # hid_rec = torch.stack(hid_rec)\n", " return spk_out, hid_rec" ] }, { "cell_type": "markdown", "metadata": { "id": "QCL4EIZQFeSE" }, "source": [ "Apply to a batch of data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5Z6MVgU6FeSE" }, "outputs": [], "source": [ "data, targets = next(iter(test_loader))\n", "\n", "# data = data.to(device)\n", "\n", "spk, hid = apply(data)\n", "\n", "# count the number of spikes for each neuron and assess the winner\n", "predictions = spk.sum(axis=0).argmax(axis=-1)\n", "print(f\"Predicted classes: {predictions}\")\n", "print(f\"Actual classes: {targets}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "x3eGShtIFeSE" }, "source": [ "### 6.3 Measure accuracy for test dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w8pgFPB5FeSE" }, "outputs": [], "source": [ "def measure_accuracy2(model, dataloader):\n", " with torch.no_grad():\n", " # model.eval() # not needed!\n", " running_length = 0\n", " running_accuracy = 0\n", "\n", " for data, targets in iter(dataloader):\n", " # data = data.to(device)\n", " # targets = targets.to(device)\n", "\n", " # forward-pass\n", " spk_rec, _ = model(data)\n", " spike_count = spk_rec.sum(0) # batch x num_outputs\n", " _, max_spike = spike_count.max(1)\n", "\n", " # correct classes for one batch\n", " num_correct = (max_spike == targets).sum()\n", "\n", " # total accuracy\n", " running_length += len(targets)\n", " running_accuracy += num_correct\n", "\n", " accuracy = (running_accuracy / running_length)\n", "\n", " return accuracy.item()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PDkVh1maFeSE" }, "outputs": [], "source": [ "measure_accuracy2(apply, test_loader)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "NaTLN3HN_7R0" }, "outputs": [], "source": [ "#@title Run this block for a good time\n", "import requests\n", "from IPython.display import Image, display\n", "\n", "def display_image_from_url(url):\n", " response = requests.get(url, stream=True)\n", " display(Image(response.content))\n", "\n", "url = \"http://www.clker.com/cliparts/7/8/a/0/1498553633398980412very-nice-borat.med.png\"\n", "display_image_from_url(url)" ] }, { "cell_type": "markdown", "metadata": { "id": "3SlDUImz6Q2e" }, "source": [ "# Conclusion\n", "\n", "That covers how to train a spiking neural network, how to convert it into the neuromorphic intermediate representation, and how to load into another pytorch based framework.\n", "\n", "There are a lot of ways to alter this, e.g. for SNN training, by using different neuron models, surrogate gradients, learnable beta and threshold values, or modifying the fully-connected layers by replacing them with convolutions or whatever else you fancy." 
] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "include_colab_link": true, "provenance": [] }, "kernelspec": { "display_name": "Python [conda env:nir-demo] *", "language": "python", "name": "conda-env-nir-demo-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 0 }