Spyx: A Library for Just-In-Time Compiled Optimization of Spiking Neural Networks

Kade M. Heckel, Thomas Nowotny

TL;DR

This work tackles the challenge of efficiently training and deploying Spiking Neural Networks (SNNs) by bridging Python-based workflows with JIT-compiled computation accelerated on GPUs and TPUs. It introduces Spyx, a JAX-based SNN library built on Haiku that pre-stages data in vRAM and compiles the entire training loop into a single high-performance program, while providing modular components for surrogate gradients, neuron models, losses, data handling, and export/import through the Neuromorphic Intermediate Representation (NIR). Key contributions include a flexible surrogate-gradient system, leaky integrate-and-fire (LIF) neuron implementations, data compression and on-GPU augmentation, and the NIR interface for cross-hardware deployment. Benchmark results show Spyx achieving runtimes competitive with PyTorch-based and CUDA-accelerated frameworks on the SHD and NMNIST datasets. The findings indicate that JAX-based JIT compilation can deliver near-CUDA performance for SNN training, enabling rapid experimentation and broader deployment across AI accelerators and neuromorphic hardware via NIR.
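The TL;DR names LIF neurons and surrogate gradients as core components. As a minimal sketch of how the two fit together, the following uses plain JAX rather than Spyx's own Haiku modules: `spike_fn`, `lif_step`, and `run_lif` are hypothetical names, and the arctan-style surrogate, decay `beta=0.9`, threshold `1.0`, and subtractive reset are illustrative assumptions, not values taken from the paper. The forward pass applies a hard threshold, `jax.custom_jvp` substitutes a smooth derivative so gradients can flow through spikes, and `jax.lax.scan` runs the recurrence over time inside one compiled program.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def spike_fn(x):
    # Forward pass: Heaviside step, 1.0 where the threshold-shifted potential is positive.
    return (x > 0.0).astype(x.dtype)


@spike_fn.defjvp
def spike_fn_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    out = (x > 0.0).astype(x.dtype)
    # Surrogate derivative (assumed arctan-style): 1 / (1 + (pi * x)^2) replaces
    # the step's true derivative, which is zero almost everywhere.
    surrogate = 1.0 / (1.0 + (jnp.pi * x) ** 2)
    return out, surrogate * x_dot


def lif_step(v, current, beta=0.9, threshold=1.0):
    # Leaky integration: decay the membrane potential and add the input current.
    v = beta * v + current
    s = spike_fn(v - threshold)
    # Subtractive reset; stop_gradient keeps the reset out of the backward pass (an assumption).
    v = v - jax.lax.stop_gradient(s) * threshold
    return v, s


def run_lif(weights, inputs):
    # inputs: [time, batch, n_in] spikes; weights: [n_in, n_out].
    currents = inputs @ weights  # [time, batch, n_out]
    v0 = jnp.zeros(currents.shape[1:], dtype=currents.dtype)
    # scan carries the membrane potential across time steps and stacks the spikes.
    _, spikes = jax.lax.scan(lif_step, v0, currents)
    return spikes  # [time, batch, n_out]


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = 0.5 * jax.random.normal(key_w, (16, 4))
x = (jax.random.uniform(key_x, (25, 8, 16)) < 0.2).astype(jnp.float32)


def spike_count(w):
    # Toy objective used only to show that gradients reach the weights.
    return run_lif(w, x).sum()


grads = jax.grad(spike_count)(w)
print(grads.shape)  # (16, 4)
```

Because the step function has zero derivative almost everywhere, `jax.grad` of this toy spike-count objective would be identically zero without the custom JVP rule; with the surrogate in place, the weights receive a usable gradient.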

Abstract

As the role of artificial intelligence becomes increasingly pivotal in modern society, the efficient training and deployment of deep neural networks have emerged as critical areas of focus. Recent advancements in attention-based large neural architectures have spurred the development of AI accelerators, facilitating the training of extensive, multi-billion-parameter models. Despite their effectiveness, these powerful networks often incur high execution costs in production environments. Neuromorphic computing, inspired by biological neural processes, offers a promising alternative. By utilizing temporally sparse computations, Spiking Neural Networks (SNNs) promise to enhance energy efficiency through a reduced, low-power hardware footprint. However, training SNNs can be challenging because their recurrent nature makes it harder to leverage the massive parallelism of modern AI accelerators. To facilitate the investigation of SNN architectures and dynamics, researchers have sought to bridge Python-based deep learning frameworks such as PyTorch or TensorFlow with custom-implemented compute kernels. This paper introduces Spyx, a new and lightweight SNN simulation and optimization library designed in JAX. By pre-staging data in the expansive vRAM of contemporary accelerators and employing extensive JIT compilation, Spyx allows SNN optimization to be executed as a unified, low-level program on NVIDIA GPUs or Google TPUs. This approach achieves optimal hardware utilization, surpassing the performance of many existing SNN training frameworks while maintaining considerable flexibility.
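The abstract's key mechanism is to stage the data in accelerator memory once and compile optimization into a single low-level program. The sketch below shows that general pattern in plain JAX; it is not Spyx's code. The names `loss_fn` and `train_epoch`, the toy linear model standing in for an SNN, the random data, and the plain SGD update are all illustrative assumptions. `jax.device_put` moves the dataset to the device up front, and `jax.lax.scan` inside `jax.jit` turns a full epoch of updates into one compiled call.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Toy model (assumption): a linear readout with softmax cross-entropy, standing in for an SNN.
    logits = x @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))


@jax.jit
def train_epoch(params, xs, ys, lr):
    # xs: [num_batches, batch, features], ys: [num_batches, batch], both already on device.
    def step(p, batch):
        x, y = batch
        loss, grads = jax.value_and_grad(loss_fn)(p, x, y)
        # Plain SGD for brevity (assumption); an optimizer library could be swapped in here.
        p = jax.tree_util.tree_map(lambda a, g: a - lr * g, p, grads)
        return p, loss

    # One compiled scan covers the whole epoch: no per-batch Python dispatch
    # and no host-to-device copies inside the loop.
    return jax.lax.scan(step, params, (xs, ys))


k_x, k_y = jax.random.split(jax.random.PRNGKey(0))
num_batches, batch, features, classes = 10, 32, 20, 5
# Pre-stage the full (toy) dataset in accelerator memory once, before training.
xs = jax.device_put(jax.random.normal(k_x, (num_batches, batch, features)))
ys = jax.device_put(jax.random.randint(k_y, (num_batches, batch), 0, classes))
params = {"w": jnp.zeros((features, classes)), "b": jnp.zeros((classes,))}

for epoch in range(3):
    params, losses = train_epoch(params, xs, ys, 0.1)
print(losses.shape)  # (10,)
```

Compiling at epoch granularity is one design choice; if no per-epoch host interaction (logging, checkpointing) is needed, the Python loop over epochs could itself be folded into an outer `scan`.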

Paper Structure (15 sections, 1 equation, 2 tables)