Slax: A Composable JAX Library for Rapid and Flexible Prototyping of Spiking Neural Networks
Thomas M. Summe, Siddharth Joshi
TL;DR
The paper addresses the need for rapid, flexible experimentation with spiking neural network (SNN) learning algorithms beyond backpropagation through time (BPTT) with surrogate gradients. It introduces Slax, a JAX/Flax-based library with modular leaky integrate-and-fire (LIF) neurons, a connect API for building recurrent architectures, a suite of learning rules (OSTL, OTTT, OTPE, RTRL, FPTT), online and offline training, and visualization tools. The authors report performance competitive with other SNN frameworks and emphasize easy integration with Flax and the broader JAX ecosystem. The framework lets researchers compare learning rules, study gradient behavior, and prototype energy-efficient SNN algorithms more quickly.
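The sketch below illustrates the kind of component described above, written in plain JAX rather than Slax's own API. It implements a LIF layer with a surrogate gradient and trains it with BPTT by differentiating through `jax.lax.scan`. The following are all illustrative assumptions, not details taken from the paper:

- the function names `spike_fn`, `lif_step`, `run_lif`, and `loss_fn`;
- the fast-sigmoid surrogate;
- the decay `beta=0.9`, the threshold of 1.0, and the subtractive reset;
- the firing-rate loss.

```python
import jax
import jax.numpy as jnp


# Heaviside spike nonlinearity: the forward pass emits 0/1 spikes.
@jax.custom_jvp
def spike_fn(v):
    return (v > 0.0).astype(v.dtype)


# Surrogate derivative (fast sigmoid, an assumed choice) replaces the
# zero-almost-everywhere derivative of the step function.
@spike_fn.defjvp
def spike_fn_jvp(primals, tangents):
    (v,), (dv,) = primals, tangents
    surrogate = 1.0 / (1.0 + 10.0 * jnp.abs(v)) ** 2
    return spike_fn(v), surrogate * dv


def lif_step(w, v, x, beta=0.9, threshold=1.0):
    # Leaky integration of the weighted input current.
    v = beta * v + x @ w
    s = spike_fn(v - threshold)
    # Subtractive reset after a spike.
    v = v - s * threshold
    return v, s


def run_lif(w, xs):
    # xs: [T, n_in] input spikes -> [T, n_out] output spikes.
    v0 = jnp.zeros(w.shape[1])
    _, spikes = jax.lax.scan(lambda v, x: lif_step(w, v, x), v0, xs)
    return spikes


def loss_fn(w, xs, target_rate):
    # Match each output neuron's mean firing rate to a target rate.
    spikes = run_lif(w, xs)
    return jnp.mean((spikes.mean(axis=0) - target_rate) ** 2)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = 0.5 * jax.random.normal(key_w, (20, 5))
xs = (jax.random.uniform(key_x, (50, 20)) < 0.2).astype(jnp.float32)
target = jnp.full((5,), 0.1)

# Differentiating through lax.scan propagates gradients backward in time (BPTT).
loss, grad_w = jax.value_and_grad(loss_fn)(w, xs, target)
w = w - 1.0 * grad_w  # one plain SGD step; the learning rate is arbitrary
spikes = run_lif(w, xs)
```

Because the loss is an ordinary JAX function, `jax.value_and_grad` yields BPTT gradients under the chosen surrogate. Online rules such as OTTT or OSTL would replace this backward pass with forward-in-time updates, which this sketch does not implement.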
Abstract
Recent advances in algorithms for training spiking neural networks (SNNs) often leverage their unique dynamics. While backpropagation through time (BPTT) with surrogate gradients dominates the field, a rich set of alternatives occupies different points in the trade-off space among performance, biological plausibility, and complexity. Evaluating and comparing these algorithms is currently a cumbersome and error-prone process that requires repeatedly re-implementing them. We introduce Slax, a JAX-based library designed to accelerate SNN algorithm design that is compatible with the broader JAX and Flax ecosystem. Slax provides optimized implementations of diverse training algorithms, allowing direct performance comparison. Its toolkit includes methods to visualize and debug algorithms through loss landscapes, gradient similarities, and other metrics of model behavior during training.
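The gradient-similarity diagnostic mentioned above can be approximated in a few lines. The sketch below does not use Slax. It compares two gradients of a small LIF layer's loss:

- the full BPTT gradient;
- a truncated gradient that detaches the membrane state at every step with `jax.lax.stop_gradient`.

It then reports their cosine similarity. The truncation is a crude stand-in for an online approximation, not an implementation of OSTL, OTTT, OTPE, or any other rule named in the paper. The names, surrogate, and hyperparameters are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


# Heaviside spike with an assumed fast-sigmoid surrogate derivative.
@jax.custom_jvp
def spike_fn(v):
    return (v > 0.0).astype(v.dtype)


@spike_fn.defjvp
def spike_fn_jvp(primals, tangents):
    (v,), (dv,) = primals, tangents
    return spike_fn(v), dv / (1.0 + 10.0 * jnp.abs(v)) ** 2


def loss_fn(w, xs, target, truncate, beta=0.9):
    def step(v, x):
        # truncate=True cuts the gradient path through past membrane states,
        # so only the current-step dependence on w contributes.
        v = jax.lax.stop_gradient(v) if truncate else v
        v = beta * v + x @ w
        s = spike_fn(v - 1.0)
        return v - s, s  # subtractive reset with threshold 1.0

    _, spikes = jax.lax.scan(step, jnp.zeros(w.shape[1]), xs)
    return jnp.mean((spikes.mean(axis=0) - target) ** 2)


def cosine_similarity(a, b, eps=1e-12):
    a, b = a.ravel(), b.ravel()
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + eps)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = 0.5 * jax.random.normal(key_w, (20, 5))
xs = (jax.random.uniform(key_x, (50, 20)) < 0.2).astype(jnp.float32)
target = jnp.full((5,), 0.1)

g_bptt = jax.grad(loss_fn)(w, xs, target, False)   # full BPTT gradient
g_trunc = jax.grad(loss_fn)(w, xs, target, True)   # temporally truncated
sim = float(cosine_similarity(g_bptt, g_trunc))
print(f"cosine similarity (BPTT vs truncated): {sim:.3f}")
```

Tracking this value over training shows how closely an approximate rule follows the BPTT gradient direction. This is the kind of comparison the abstract describes.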
