SNNAX -- Spiking Neural Networks in JAX

Jamie Lohoff, Jan Finkbeiner, Emre Neftci

TL;DR

SNNAX addresses the need for fast, flexible, and differentiable simulation of Spiking Neural Networks (SNNs) by building a JAX-based library on top of Equinox that offers PyTorch-like usability and XLA-powered performance. It provides modular connectivity, stateful neuron dynamics, and gradient-based learning built on JAX automatic differentiation (AD), including surrogate gradients and backpropagation through time (BPTT). The framework emphasizes a PyTree-centered user interface, layer abstractions, and graph-based execution that support both feed-forward and recurrent SNNs without heavy CUDA boilerplate. By combining stateful layers, advanced AD tools, and a robust interface, SNNAX aims to accelerate algorithmic exploration, rapid prototyping, and efficient deployment on GPUs/TPUs for neuromorphic research and hardware-in-the-loop experiments.
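The surrogate-gradient and BPTT ideas mentioned above can be illustrated with a minimal sketch in plain JAX. This is not SNNAX's or Equinox's API: the fast-sigmoid surrogate with slope 10, the leak factor `beta=0.9`, the soft reset, the rate-based loss, and all function names (`spike`, `lif_layer`, `loss_fn`) are assumptions chosen for illustration. The forward pass uses a hard Heaviside threshold, while `jax.custom_jvp` substitutes a smooth derivative, and `jax.grad` through `jax.lax.scan` performs backpropagation through time.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def spike(v):
    """Heaviside step in the forward pass: emit 1.0 where v > 0 (hypothetical helper)."""
    return (v > 0.0).astype(v.dtype)


@spike.defjvp
def spike_jvp(primals, tangents):
    (v,), (dv,) = primals, tangents
    # Assumed fast-sigmoid surrogate derivative 1 / (1 + k|v|)^2 with k = 10,
    # used in place of the Heaviside's zero/undefined derivative.
    surrogate = 1.0 / (1.0 + 10.0 * jnp.abs(v)) ** 2
    return spike(v), surrogate * dv


def lif_layer(weights, inputs, beta=0.9, threshold=1.0):
    """One LIF layer over time; inputs: [T, batch, n_in] -> spikes: [T, batch, n_out]."""
    def step(v, x_t):
        v = beta * v + x_t @ weights   # leaky integration of the synaptic current
        s = spike(v - threshold)       # spike where the membrane crosses threshold
        v = v - s * threshold          # soft reset by subtraction
        return v, s

    v0 = jnp.zeros((inputs.shape[1], weights.shape[1]))
    _, spikes = jax.lax.scan(step, v0, inputs)
    return spikes


def loss_fn(weights, inputs, target_rate):
    """Hypothetical loss: squared error between firing rates and a target rate."""
    rate = lif_layer(weights, inputs).mean(axis=0)  # mean spikes per neuron over time
    return jnp.mean((rate - target_rate) ** 2)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
weights = 0.5 * jax.random.normal(key_w, (20, 10))
inputs = jax.random.bernoulli(key_x, 0.3, (50, 4, 20)).astype(jnp.float32)

# Differentiating through lax.scan unrolls the time loop backwards (BPTT).
grads = jax.grad(loss_fn)(weights, inputs, 0.2)
```

Because the surrogate is attached to the spike function itself, any layer that calls `spike` inherits the smooth gradient without further changes.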

Abstract

Spiking Neural Network (SNN) simulators are essential tools for prototyping biologically inspired models and neuromorphic hardware architectures and for predicting their performance. For such a tool, ease of use and flexibility are critical, but so is simulation speed, especially given the complexity inherent in simulating SNNs. Here, we present SNNAX, a JAX-based framework for simulating and training such models with PyTorch-like intuitiveness and JAX-like execution speed. SNNAX models are easily extended and customized to fit the desired model specifications and target neuromorphic hardware. Additionally, SNNAX offers key features for optimizing the training and deployment of SNNs, such as flexible automatic differentiation and just-in-time compilation. We evaluate and compare SNNAX against other machine learning (ML) frameworks commonly used for programming SNNs. We provide key performance metrics, best practices, and documented examples for simulating SNNs in SNNAX, and we implement several benchmarks used in the literature.
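To make the just-in-time compilation point concrete, here is a minimal plain-JAX sketch (not SNNAX code) that compiles a two-layer LIF MLP ahead of time with `jax.jit(...).lower(...).compile()` and reuses the resulting XLA executable. The layer width of 2048 and batch size of 32 follow the Figure 1 caption; the input width, the 10 time steps, the weight initialization, the leak and threshold values, and the name `lif_mlp` are assumptions for illustration. Printed timings depend on the hardware and are not reproductions of the paper's results.

```python
import time

import jax
import jax.numpy as jnp


def lif_mlp(params, inputs, beta=0.9, threshold=1.0):
    """Two fully connected LIF layers over time; inputs: [T, batch, n_in]."""
    w1, w2 = params

    def lif(v, current):
        # Leaky integration, threshold crossing, soft reset by subtraction.
        v = beta * v + current
        s = (v > threshold).astype(current.dtype)
        return v - s * threshold, s

    def step(carry, x_t):
        v1, v2 = carry
        v1, s1 = lif(v1, x_t @ w1)
        v2, s2 = lif(v2, s1 @ w2)
        return (v1, v2), s2

    batch = inputs.shape[1]
    carry0 = (jnp.zeros((batch, w1.shape[1])), jnp.zeros((batch, w2.shape[1])))
    _, out_spikes = jax.lax.scan(step, carry0, inputs)
    return out_spikes.sum(axis=0)  # spike count per output neuron


n, batch, steps = 2048, 32, 10
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(k1, (n, n)) / jnp.sqrt(n),
          jax.random.normal(k2, (n, n)) / jnp.sqrt(n))
inputs = jax.random.bernoulli(k3, 0.1, (steps, batch, n)).astype(jnp.float32)

# Ahead-of-time: trace and compile once with XLA, then reuse the executable.
t0 = time.perf_counter()
compiled = jax.jit(lif_mlp).lower(params, inputs).compile()
compile_s = time.perf_counter() - t0

t0 = time.perf_counter()
counts = compiled(params, inputs).block_until_ready()
run_s = time.perf_counter() - t0
print(f"compile: {compile_s:.3f} s, run: {run_s:.4f} s")
```

Calling `jax.jit(lif_mlp)` directly would instead compile lazily on the first call; the explicit lowering step simply separates compile time from run time.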

Paper Structure

This paper contains 11 sections, 1 equation, and 2 figures.

Figures (2)

  • Figure 1: Execution time measurements of a two-layer MLP with blocks of 2048 LIF neurons for batch size 32.
  • Figure 2: Execution time measurements of a two-layer LIF-CNN with kernel size 3 on a 48x48 pixel image and batch size 32.
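A configuration like the one in Figure 2 can be sketched in plain JAX as follows. This is not the authors' benchmark code: the 2 input channels, the 16 and 32 feature maps, the 8 time steps, the weight scale, the `SAME` padding, and the names `conv3x3` and `lif_cnn` are assumptions. Only the 3x3 kernel, the 48x48 input, and batch size 32 come from the caption. The sketch also shows a common way to time JAX code: run one warm-up call so tracing and compilation are excluded, then call `block_until_ready()` so that asynchronous dispatch does not hide the actual compute time.

```python
import time

import jax
import jax.numpy as jnp


def conv3x3(x, kernel):
    """'SAME'-padded 3x3 convolution; x: [batch, C_in, H, W], kernel: [C_out, C_in, 3, 3]."""
    return jax.lax.conv_general_dilated(
        x, kernel, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"))


def lif_cnn(params, frames, beta=0.9, threshold=1.0):
    """Two convolutional LIF layers over time; frames: [T, batch, C_in, H, W]."""
    k1, k2 = params

    def lif(v, current):
        # Leaky integration, threshold crossing, soft reset by subtraction.
        v = beta * v + current
        s = (v > threshold).astype(current.dtype)
        return v - s * threshold, s

    def step(carry, x_t):
        v1, v2 = carry
        v1, s1 = lif(v1, conv3x3(x_t, k1))
        v2, s2 = lif(v2, conv3x3(s1, k2))
        return (v1, v2), s2

    batch, _, height, width = frames.shape[1:]
    carry0 = (jnp.zeros((batch, k1.shape[0], height, width)),
              jnp.zeros((batch, k2.shape[0], height, width)))
    _, out_spikes = jax.lax.scan(step, carry0, frames)
    return out_spikes.mean(axis=0)  # firing rate per output unit


steps, batch, c_in, c1, c2 = 8, 32, 2, 16, 32
key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (0.3 * jax.random.normal(key1, (c1, c_in, 3, 3)),
          0.3 * jax.random.normal(key2, (c2, c1, 3, 3)))
frames = jax.random.bernoulli(key3, 0.2, (steps, batch, c_in, 48, 48)).astype(jnp.float32)

forward = jax.jit(lif_cnn)
forward(params, frames).block_until_ready()  # warm-up call: trace and compile
t0 = time.perf_counter()
rates = forward(params, frames).block_until_ready()
elapsed = time.perf_counter() - t0
print(f"steady-state forward pass: {elapsed * 1e3:.1f} ms")
```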