Spyx is a JAX-based framework for spiking neural networks (SNNs) and deep learning that enables fully JIT-compiled optimization of models.

import spyx
import spyx.nn as snn

import jax
import jax.numpy as jnp

import nir
Note: if running the imports above raises `ModuleNotFoundError: No module named 'spyx'`, Spyx is not installed in the current environment. Install it (along with `nir`) before running the examples below.

Import a NIR graph into Spyx:

# Load the NIR graph from disk.
nir_graph = nir.read("saved_network.nir")

# Convert the NIR graph into a Spyx model and its parameters.
# `sample_batch` is a sample of your input data, used for shape information.
# It is not defined in this snippet, so create it beforehand
# (NOTE: presumably a batch shaped like the network's expected input -- confirm
# against your data pipeline).
# `dt` is used to scale the weights properly if the imported network was
# trained in a different simulator where dt is not necessarily 1.
SNN, params = spyx.nir.from_nir(nir_graph, sample_batch, dt=1)

# Use the returned model and parameters as you wish, e.g. a forward pass:
SNN.apply(params, sample_batch)

Export a network from Spyx to a NIR graph:

# Some operations may have rearranged the PyTree (dictionary) that stores the
# SNN weights, so this helper reorders the dict so it can be exported
# correctly. `init_params` and `optimized_params` are not defined in this
# snippet (NOTE: presumably the freshly initialized and the trained parameters,
# respectively -- confirm against the spyx.nir API).
export_params = spyx.nir.reorder_layers(init_params, optimized_params)

# Provide the params to export along with the input/output shapes and the
# desired time resolution `dt` (all must be defined beforehand). Passing `dt`
# lets you load the graph with the proper time step in other frameworks that
# allow smaller time intervals, whereas Spyx assumes every timestep is 1 to
# avoid units.
nir_graph = spyx.nir.to_nir(export_params, input_shape, output_shape, dt)

# Write the NIR graph to the desired file path.
nir.write("./spyx_shd.nir", nir_graph)