
jaxsnn: Event-driven Gradient Estimation for Analog Neuromorphic Hardware

Eric Müller, Moritz Althaus, Elias Arnold, Philipp Spilger, Christian Pehle, Johannes Schemmel

TL;DR

Problem: Gradient-based training for neuromorphic hardware is hindered when dense, time-grid-based ML frameworks have to handle asynchronous spike data and time-continuous dynamics.

Approach: jaxsnn, a JAX-based library, supports event-driven computation on spike representations with Autograd. It provides a differentiable EventProp implementation and a vectorized, PyTree-friendly simulator that advances from event to event, using a differentiable root solver to find the next spike time. It also integrates forward-pass execution on hardware (BrainScaleS-2) and a hardware-mock mode.

Contributions: event-driven gradient estimation in a flexible framework, direct compatibility with neuromorphic backends during the forward pass, and validation on the Yin-Yang dataset with strong accuracy.

Significance: the work bridges neuromorphic hardware and contemporary ML tooling, enabling efficient, flexible training of spiking neural networks on analog hardware.
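The TL;DR mentions a differentiable root solver for the next spike time. As a minimal sketch of that idea, and not jaxsnn's actual solver, the JAX code below assumes a current-based leaky integrate-and-fire neuron driven by a single input spike of weight `w` at $t = 0$, with hypothetical constants `TAU_SYN`, `TAU_MEM` and `V_TH`. It finds the threshold crossing with a fixed number of Newton steps; because those steps are ordinary JAX operations, `jax.grad` differentiates straight through the solver.

```python
import jax
import jax.numpy as jnp

# Illustrative assumptions (not taken from the paper): current-based LIF,
#   TAU_MEM * dV/dt = -V + I,  TAU_SYN * dI/dt = -I,
# with V(0) = 0 and I(0) = w directly after one input spike at t = 0.
TAU_SYN = 1.0  # time measured in units of the synaptic time constant
TAU_MEM = 2.0  # hypothetical membrane time constant
V_TH = 1.0     # hypothetical firing threshold


def membrane(t, w):
    """Closed-form membrane voltage V(t) for the single-input case."""
    a = w * TAU_SYN / (TAU_SYN - TAU_MEM)
    return a * (jnp.exp(-t / TAU_SYN) - jnp.exp(-t / TAU_MEM))


def next_spike_time(w, n_iter=20):
    """Solve V(t) = V_TH with unrolled Newton steps.

    Starting at t = 0 is safe here: before its peak, V is rising and
    concave, so Newton approaches the first crossing from the left.
    Assumes w is large enough for V to reach V_TH at all.
    """
    f = lambda t: membrane(t, w) - V_TH
    df = jax.grad(f)  # dV/dt, obtained by autodiff
    t = jnp.zeros_like(w)
    for _ in range(n_iter):
        t = t - f(t) / df(t)
    return t


# Gradient of the spike time w.r.t. the input weight, through the solver.
dt_dw = jax.grad(next_spike_time)
```

With `TAU_MEM = 2 * TAU_SYN` the crossing also has a closed form, which makes the sketch easy to check. Unrolling the iterations is the simplest way to obtain gradients; an alternative is implicit differentiation (for example via `jax.custom_jvp`) using $\mathrm{d}t^*/\mathrm{d}w = -(\partial V/\partial w)/(\partial V/\partial t)$ evaluated at the root.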

Abstract

Traditional neuromorphic hardware architectures rely on event-driven computation, where the asynchronous transmission of events, such as spikes, triggers local computations within synapses and neurons. While machine learning frameworks are commonly used for gradient-based training, their emphasis on dense data structures poses challenges for processing asynchronous data such as spike trains, a problem that is particularly pronounced for typical tensor data structures. In this context, we present jaxsnn, a novel library built on top of JAX that departs from conventional machine learning frameworks by providing flexibility in the data structures used and in the handling of time, while maintaining Autograd functionality and composability. Our library facilitates the simulation of spiking neural networks and gradient estimation, with a focus on compatibility with time-continuous neuromorphic backends, such as the BrainScaleS-2 system, during the forward pass. This approach opens avenues for more efficient and flexible training of spiking neural networks, bridging the gap between traditional neuromorphic architectures and contemporary machine learning frameworks.
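The abstract stresses flexible data structures, explicit handling of time, and preserved Autograd composability. The sketch below uses hypothetical names throughout (it is not jaxsnn's API) to show the pattern described in the caption of Figure 3: a step function maps a state to a new state plus an event, and `jax.lax.scan` stacks the per-step events into a pytree whose leaves gain a new leading axis. To stay short it assumes exponentially decaying synaptic currents and only processes input spikes; threshold crossings and output spikes are omitted.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

TAU_SYN = 1.0  # time in units of the synaptic time constant


class State(NamedTuple):
    t: jax.Array        # time of the last processed event (scalar)
    current: jax.Array  # synaptic current per neuron, shape (n_neurons,)
    k: jax.Array        # index of the next input event


class Event(NamedTuple):
    time: jax.Array     # event time
    idx: jax.Array      # index of the input neuron that emitted it


def make_step(t_in, src, weights):
    """Hypothetical step function over a sorted list of input spikes.

    t_in: (n_events,) spike times, src: (n_events,) source indices,
    weights: (n_neurons, n_inputs).
    """
    def step(state, _):
        t_next = t_in[state.k]                        # 1. next event time
        decay = jnp.exp(-(t_next - state.t) / TAU_SYN)
        current = state.current * decay               # 2. dynamics up to t_next-
        current = current + weights[:, src[state.k]]  # 3. jump at t_next
        event = Event(time=t_next, idx=src[state.k])
        return State(t=t_next, current=current, k=state.k + 1), event
    return step


def simulate(weights, t_in, src):
    init = State(t=jnp.zeros((), jnp.float32),
                 current=jnp.zeros(weights.shape[0], jnp.float32),
                 k=jnp.zeros((), jnp.int32))
    # scan adds a leading axis of length n_events to every leaf of Event
    return jax.lax.scan(make_step(t_in, src, weights), init, None,
                        length=t_in.shape[0])


# Usage: three input spikes from two inputs onto two neurons.
t_in = jnp.array([0.5, 1.0, 2.0])
src = jnp.array([0, 1, 0])
weights = jnp.array([[1.0, -0.5], [0.2, 0.8]])
final, events = simulate(weights, t_in, src)  # events.time, events.idx: shape (3,)
grads = jax.grad(lambda w: simulate(w, t_in, src)[0].current.sum())(weights)
```

Because `State` and `Event` are `NamedTuple`s, JAX treats them as pytrees, so the same code carries any number of fields through `scan` and `grad` without flattening them into one dense tensor.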

Paper Structure

This paper contains 4 sections, 3 figures, 1 table, 1 algorithm.

Figures (3)

  • Figure 2: Visualization of the two-dimensional Yin-Yang dataset [kriener2021yin]. The dataset has three classes (black, white and gray) and is not linearly separable. Each two-dimensional data point can be described by the spike times of two input neurons; in our experiment, we add mirrored inputs and a bias spike to obtain a five-dimensional input. Time is given in units of the synaptic time constant.
  • Figure 3: (a) Data flow within the step function (cf. the loop body of the step algorithm for multiple neurons): given the initial state $S_{t_n}$ and the input data, the next event time $t_{n+1}$ is determined, the continuous dynamics are advanced to this time, yielding $S_{t_{n+1}^-}$, and the transition (discontinuity) is applied, yielding $S_{t_{n+1}}$ and an output spike. (b) Each iteration of the step function takes a state $S_t$ and returns an updated state $S_{t+1}$ and an event $E_{t+1}$. Scanning over the step function returns the final state $S_\text{last}$ and a pytree $\bm{E}$: a new axis is created and mapped down to all leaves of $\bm{E}$, so each of the event's two fields becomes an array of its previous type.
  • Figure 4: Hardware training: the forward pass is delegated to the BrainScaleS-2 system [pehle2022brainscales2_nopreprint_nourl, mueller2022scalable_noeprint], which returns a list of spikes. An additional step then numerically computes the synaptic current $I_\text{spike}$ at each spike time $t_\text{spike}$. Given a loss function, the EventProp algorithm [wunderlich2021event] uses these spikes to compute weight updates $\Delta W$ for the parameters $W$.
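The captions of Figures 2 and 4 describe the data handling around the hardware call: a Yin-Yang point becomes five input spike times (two coordinates, their mirrored copies, and a bias spike), and after the hardware run the synaptic current $I_\text{spike}$ is computed numerically at each reported spike time $t_\text{spike}$. The sketch below assumes exponentially decaying, current-based synapses and a linear time encoding. The function names and the defaults `t_early`, `t_late` and `t_bias` are hypothetical placeholders, the hardware output is replaced by a hand-written array rather than any BrainScaleS-2 API, and EventProp itself is not implemented.

```python
import jax.numpy as jnp

TAU_SYN = 1.0  # time in units of the synaptic time constant (cf. Figure 2)


def encode_yinyang(x, y, t_early=0.15, t_late=2.0, t_bias=0.9):
    """Hypothetical encoding of (x, y) in [0, 1]^2 into 5 spike times:
    x, y and the mirrored 1 - x, 1 - y are mapped linearly onto
    [t_early, t_late], followed by a fixed bias spike."""
    coords = jnp.array([x, y, 1.0 - x, 1.0 - y])
    times = t_early + coords * (t_late - t_early)
    return jnp.concatenate([times, jnp.array([t_bias])])


def current_at_spikes(t_out, t_in, weights):
    """Synaptic current of each output neuron at its own spike time.

    Assumes I_k(t) = sum_j W[k, j] * exp(-(t - t_j) / TAU_SYN) over inputs
    with t_j < t. t_out: (n_out,) spike times from the forward pass
    (jnp.inf for "no spike"), t_in: (n_in,), weights: (n_out, n_in).
    """
    dt = t_out[:, None] - t_in[None, :]  # (n_out, n_in) time since each input
    causal = dt > 0
    # Inner where keeps exp() finite (and gradients NaN-free) for masked entries.
    kernel = jnp.where(causal, jnp.exp(-jnp.where(causal, dt, 0.0) / TAU_SYN), 0.0)
    return jnp.sum(weights * kernel, axis=1)


# Usage with a mocked "hardware" result (arbitrary numbers, not measured data):
t_in = encode_yinyang(0.3, 0.8)
weights = jnp.full((3, 5), 0.5)
t_out = jnp.array([1.2, jnp.inf, 2.5])  # second neuron did not fire
i_spike = current_at_spikes(t_out, t_in, weights)
```

Only the spike times cross the hardware boundary in this sketch; everything needed for the backward pass is recomputed in software from those times, which is what lets the forward pass run on an analog system.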