Slax: A Composable JAX Library for Rapid and Flexible Prototyping of Spiking Neural Networks

Thomas M. Summe, Siddharth Joshi

TL;DR

The paper addresses the need for rapid, flexible experimentation with spiking neural network learning algorithms beyond backpropagation through time with surrogate gradients. It introduces Slax, a JAX/Flax-based library with modular LIF neurons, a connect API for recurrent architectures, and a suite of learning rules (OSTL, OTTT, OTPE, RTRL, FPTT) plus online/offline training and visualization tools. The authors demonstrate competitive performance relative to other SNN frameworks and emphasize easy integration with Flax and the broader JAX ecosystem. This framework enables researchers to compare learning rules, study gradient behavior, and accelerate prototyping of energy-efficient SNN algorithms, with practical impact on SNN research workflows.
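The connect API is only named here, so its real signature is not reproduced. The code below is a hypothetical pure-JAX sketch of what recurrent and skip connections amount to in an SNN. It wires two LIF layers so that layer 2's spikes feed back into layer 1 on the next time step, and the raw input also skips directly to layer 2. The function names, the `params` keys, the leak factor 0.9, and the threshold 1.0 are all assumptions, not Slax's API. The spike function has no surrogate derivative, so the sketch covers the forward pass only.

```python
import jax
import jax.numpy as jnp

BETA, THRESHOLD = 0.9, 1.0  # assumed leak factor and firing threshold


def lif(v, current):
    """Leaky integrate-and-fire update with soft reset; returns (new_v, spikes)."""
    v = BETA * v + current
    s = (v > THRESHOLD).astype(v.dtype)  # hard threshold: forward pass only
    return v - s * THRESHOLD, s


def two_layer_step(params, carry, x):
    """One time step of a hypothetical wiring:
    input -> L1 -> L2, L2 -> L1 (recurrent, one-step delay), input -> L2 (skip).
    `carry` holds both membrane potentials and L2's spikes from the previous step."""
    v1, v2, s2_prev = carry
    v1, s1 = lif(v1, x @ params["in_l1"] + s2_prev @ params["l2_l1"])
    v2, s2 = lif(v2, s1 @ params["l1_l2"] + x @ params["in_l2"])
    return (v1, v2, s2), s2


def run(params, xs):
    """Scan the step over the leading (time) axis: xs [T, n_in] -> L2 spikes [T, n2]."""
    n1 = params["in_l1"].shape[1]
    n2 = params["l1_l2"].shape[1]
    carry0 = (jnp.zeros(n1), jnp.zeros(n2), jnp.zeros(n2))
    _, spikes = jax.lax.scan(lambda c, x: two_layer_step(params, c, x), carry0, xs)
    return spikes


# Usage: 2 inputs, 8 neurons in layer 1, 4 neurons in layer 2, 10 time steps.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "in_l1": jax.random.normal(keys[0], (2, 8)),
    "l1_l2": jax.random.normal(keys[1], (8, 4)),
    "l2_l1": jax.random.normal(keys[2], (4, 8)),
    "in_l2": jax.random.normal(keys[3], (2, 4)),
}
spikes = run(params, jnp.ones((10, 2)))
print(spikes.shape)  # (10, 4)
```

Carrying layer 2's previous spikes in the scan state is what turns the feedback edge into a one-step-delayed recurrent connection.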

Abstract

Recent advances in algorithms for training spiking neural networks (SNNs) often leverage their unique dynamics. While backpropagation through time (BPTT) with surrogate gradients dominates the field, a rich set of alternatives occupies different points in the trade-off space of performance, biological plausibility, and complexity. Evaluating and comparing algorithms is currently a cumbersome and error-prone process, because each algorithm must be repeatedly re-implemented. We introduce Slax, a JAX-based library designed to accelerate SNN algorithm design that is compatible with the broader JAX and Flax ecosystem. Slax provides optimized implementations of diverse training algorithms, allowing direct performance comparison. Its toolkit includes methods to visualize and debug algorithms through loss landscapes, gradient similarities, and other metrics of model behavior during training.
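Gradient similarity is one of the diagnostics mentioned above. The following is a minimal sketch, assuming a single LIF layer, a sigmoid surrogate applied through the straight-through trick, and a spike-count loss. It computes the cosine similarity between the full BPTT gradient and a gradient that blocks flow through the carried membrane state. That truncated gradient is a crude stand-in for an approximate learning rule, not OSTL, OTTT, or any other rule Slax implements, and every name here (`spike`, `loss`, `cosine_similarity`) is hypothetical.

```python
import jax
import jax.numpy as jnp


def spike(u, k=5.0):
    """Heaviside step forward; sigmoid-derivative surrogate backward (straight-through trick)."""
    sg = jax.nn.sigmoid(k * u)
    return sg + jax.lax.stop_gradient((u > 0).astype(u.dtype) - sg)


def loss(w, xs, target, truncate=False, beta=0.9):
    """Squared error between per-neuron spike counts and a target for one LIF layer."""
    def step(v, x):
        if truncate:  # Python bool, resolved at trace time
            v = jax.lax.stop_gradient(v)  # block gradient flow through the carried state
        v = beta * v + x @ w
        s = spike(v - 1.0)                # assumed threshold of 1.0
        return v - s, s                   # soft reset by subtraction
    _, spikes = jax.lax.scan(step, jnp.zeros(w.shape[1]), xs)
    return jnp.mean((spikes.sum(axis=0) - target) ** 2)


def cosine_similarity(a, b, eps=1e-12):
    a, b = a.ravel(), b.ravel()
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + eps)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = 0.5 * jax.random.normal(key_w, (4, 3))   # 4 inputs -> 3 LIF neurons
xs = jax.random.uniform(key_x, (20, 4))      # 20 time steps of input
target = jnp.array([2.0, 4.0, 6.0])          # desired spike counts

g_bptt = jax.grad(loss)(w, xs, target)                  # full backpropagation through time
g_trunc = jax.grad(loss)(w, xs, target, truncate=True)  # temporally truncated gradient
print(float(cosine_similarity(g_bptt, g_trunc)))
```

A similarity near 1 would suggest the truncated rule follows BPTT closely on this model and input; tracking the value over training is one way to see where an approximation diverges.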

Paper Structure (9 sections, 11 figures, 1 table)

Figures (11)

  • Figure 1: Diagram of Slax components. Panel (a) shows the library's structure, while panel (b) shows its design and implemented components. Datasets and dataloaders may come from other packages such as tensorflow-datasets (TFDS), but Slax directly provides NeuroBench datasets and the Randman synthetic dataset. All composable and differentiable functions should use JAX-based libraries or pure Python; Slax additionally makes a set of neuron layers, surrogate derivatives, and NeuroBench models readily available to the user. The same applies to the training loop: users may define their own, and Slax provides custom learning rules and easy-to-use training loops for online and offline learning. Network evaluation is open-ended, but Slax provides utilities for loss landscapes and gradient comparison, which are useful for comparing learning rules.
  • Figure 2: The code above defines a Slax LIF neuron in the style used for neurons throughout the library. All neurons include extra logic for handling an explicit state, matching Flax RNNs, while inferring the shape of the neuron state.
  • Figure 3: This code achieves the same result as Figure 2 but instead follows the typical Flax RNN conventions. A wrapping function is then applied to the defined model to produce a Slax-compatible neuron. While this requires less code, the layer features must now be specified when defining the layer.
  • Figure 4: A diagram of an SNN with complex recurrent connections and its accompanying code. The connect function allows the user to specify recurrent and skip connections.
  • Figure 5: The code above defines a Flax SNN model. The left code block uses Slax LIF neurons, while the right code block uses a Flax from Figure 3. Both are functionally equivalent and are fully compatible with Flax's API for RNNs, such as , which loops the layer and its state through the time dimension of the input.
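Figures 2, 3, and 5 describe an LIF neuron with explicit state that is looped, together with that state, through the time dimension of the input. The code below is a minimal pure-JAX sketch of that pattern. It uses `jax.lax.scan` in place of Flax's RNN wrapper and assumes a fast-sigmoid surrogate derivative (slope 10), a leak factor of 0.9, a threshold of 1.0, and reset by subtraction. `heaviside`, `lif_cell`, and `run_lif` are hypothetical names, not Slax functions.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def heaviside(u):
    """Spike (1.0) where u > 0, else 0.0."""
    return (u > 0).astype(u.dtype)


@heaviside.defjvp
def heaviside_jvp(primals, tangents):
    (u,), (u_dot,) = primals, tangents
    # Fast-sigmoid surrogate derivative 1 / (1 + k|u|)^2 with assumed slope k = 10.
    surrogate = 1.0 / (1.0 + 10.0 * jnp.abs(u)) ** 2
    return heaviside(u), surrogate * u_dot


def lif_cell(v, x, beta=0.9, threshold=1.0):
    """One LIF step: (state, input) -> (new state, spikes), the signature lax.scan expects."""
    v = beta * v + x                 # leaky integration
    s = heaviside(v - threshold)     # fire when the membrane crosses threshold
    return v - s * threshold, s      # soft reset by subtraction


def run_lif(xs):
    """Loop the cell and its state through the leading (time) axis: [T, N] -> [T, N]."""
    v0 = jnp.zeros(xs.shape[1:], xs.dtype)  # state shape inferred from the input
    _, spikes = jax.lax.scan(lif_cell, v0, xs)
    return spikes


xs = 0.6 * jnp.ones((5, 3))
print(run_lif(xs)[:, 0])  # [0. 1. 0. 1. 0.]

# The surrogate lets gradients pass through the otherwise non-differentiable spikes.
grad = jax.grad(lambda c: run_lif(c * jnp.ones((5, 3))).sum())(0.6)
print(grad > 0)  # True
```

The `(state, input) -> (state, output)` signature mirrors the contract a Flax RNN cell exposes (`cell(carry, x) -> (carry, y)`), which is why a neuron written this way can be driven by a generic time loop.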