Training event-based neural networks with exact gradients via Differentiable ODE Solving in JAX

Lukas König, Manuel Kuhn, David Kappel, Anand Subramoney

TL;DR

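Eventax is a JAX framework that trains spiking neural networks with exact gradients by combining differentiable numerical ODE solvers with event-based spike handling. This removes the usual trade-off between surrogate-gradient methods, which support arbitrary neuron models but bias the gradient, and exact-gradient methods, which are limited to analytically solvable neurons. Eventax prioritises modelling flexibility, supporting a wide range of neuron models, loss functions, and network architectures.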

Abstract

Existing frameworks for gradient-based training of spiking neural networks face a trade-off: discrete-time methods using surrogate gradients support arbitrary neuron models but introduce gradient bias and constrain spike-time resolution, while continuous-time methods that compute exact gradients require analytical expressions for spike times and state evolution, restricting them to simple neuron types such as leaky integrate-and-fire (LIF). We introduce the Eventax framework, which resolves this trade-off by combining differentiable numerical ODE solvers with event-based spike handling. Built in JAX, our framework uses Diffrax ODE solvers to compute gradients that are exact with respect to the forward simulation for any neuron model defined by ODEs. It also provides a simple API in which users specify only the neuron dynamics, spike conditions, and reset rules. Eventax prioritises modelling flexibility, supporting a wide range of neuron models, loss functions, and network architectures, all of which can be easily extended. We demonstrate Eventax on multiple benchmarks, including Yin-Yang and MNIST, using diverse neuron models such as leaky integrate-and-fire (LIF), quadratic integrate-and-fire (QIF), exponential integrate-and-fire (EIF), Izhikevich, and Event-based Gated Recurrent Unit (EGRU), with both time-to-first-spike and state-based loss functions. These experiments show its utility for prototyping and testing event-based architectures trained with exact gradients. We also apply the framework to more complex neuron types by implementing a multi-compartment neuron that uses a model of dendritic spikes in human layer 2/3 cortical pyramidal neurons for computation. Code is available at https://github.com/efficient-scalable-machine-learning/eventax.
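
To make the three user-specified pieces named in the abstract concrete, the following is a minimal plain-JAX sketch of how neuron dynamics, a spike condition, and a reset rule might be written for a LIF neuron. The class `LIFSketch`, its method names, and its parameters (`tau`, `v_th`, `v_reset`) are hypothetical illustrations and are not Eventax's actual API.

```python
from typing import NamedTuple

import jax.numpy as jnp


class LIFSketch(NamedTuple):
    """Hypothetical LIF neuron description (not the Eventax interface)."""

    tau: float = 10.0      # membrane time constant (assumed value)
    v_th: float = 1.0      # firing threshold (assumed value)
    v_reset: float = 0.0   # membrane potential after a spike (assumed value)

    def dynamics(self, t, v, i_in):
        # Right-hand side of the ODE: dv/dt = (i_in - v) / tau
        return (i_in - v) / self.tau

    def spike_condition(self, t, v):
        # Scalar field whose sign change marks a spike: zero at v == v_th
        return v - self.v_th

    def reset(self, v):
        # Reset every neuron at or above threshold, leave the rest untouched
        return jnp.where(self.spike_condition(0.0, v) >= 0.0, self.v_reset, v)


neuron = LIFSketch()
v = jnp.array([0.2, 1.0, 1.3])
dv = neuron.dynamics(0.0, v, 1.5)   # time derivative of each membrane potential
v_after = neuron.reset(v)           # the two neurons at/above threshold are reset
```

Keeping the three pieces as separate pure functions is what lets a generic ODE solver and event handler be reused across neuron models: only these functions change from one model to the next.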

Paper Structure

This paper contains 28 sections, 10 equations, 6 figures, 4 tables.

Figures (6)

  • Figure 1: The NeuronModel interface. Users define custom neuron models by implementing: initial state, dynamics, spike condition, input spike handling, and post-spike reset.
  • Figure 2: Schematic of the multi-compartment neuron model and corresponding ODEs (Equations \ref{eq:DLIF:vs_act}-\ref{eq:DLIF:M2}). $X_1$ and $X_2$ are the inputs.
  • Figure 3: Diffrax event handling. The ODE solver advances the state $y$ until the event condition $g(t, y)$ changes sign. A root-finder then locates the exact event time $t_{\text{event}}$ between the last two solver steps, and the state is integrated from the last step to this exact event time. Gradients are backpropagated directly through the solver steps, while the dependence of $t_{\text{event}}$ on the model parameters is handled via the implicit function theorem.
  • Figure 4: Usage example of EventPropJax: training a LIF network on a TTFS task. EventPropJax is fully compatible with JAX and therefore works with JAX optimizer libraries, e.g. Optax. We can wrap any neuron with the AMOS wrapper to restrict every neuron to a single spike per trial.
  • Figure 5: Delayed-XOR task encoding (0, 1)
  • ...and 1 more figure
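
The mechanism described in the Figure 3 caption can be illustrated with a small pure-JAX sketch. It does not use Diffrax or Eventax: it assumes a single LIF neuron with constant input current, a fixed-step RK4 integrator, and bisection as the root-finder, and all names (`lif_rhs`, `integrate`, `event_condition`, `find_event_time`, `spike_time`, `I`, `tau`, `theta`) are hypothetical. The root search itself is not differentiated. Instead, one Newton correction with a stopped-gradient denominator gives the event time the implicit-function-theorem derivative $\partial t_{\text{event}} / \partial p = -(\partial g / \partial p) / (\partial g / \partial t)$, while $\partial g / \partial p$ is backpropagated through the solver steps.

```python
import jax
import jax.numpy as jnp


def lif_rhs(v, params):
    # LIF membrane dynamics with constant input: dv/dt = (I - v) / tau
    return (params["I"] - v) / params["tau"]


def integrate(t, params, n_steps=200):
    # Fixed-step RK4 from v(0) = 0 to time t; differentiable in t and params.
    h = t / n_steps

    def step(v, _):
        k1 = lif_rhs(v, params)
        k2 = lif_rhs(v + 0.5 * h * k1, params)
        k3 = lif_rhs(v + 0.5 * h * k2, params)
        k4 = lif_rhs(v + h * k3, params)
        return v + h * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0, None

    v0 = jnp.zeros_like(params["I"] * h)
    v_final, _ = jax.lax.scan(step, v0, None, length=n_steps)
    return v_final


def event_condition(t, params):
    # g(t, p): changes sign from negative to positive when v crosses theta
    return integrate(t, params) - params["theta"]


def find_event_time(params, t_max=50.0, n_iter=40):
    # Bisection on [0, t_max]; assumes exactly one upward crossing inside it.
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        crossed = event_condition(mid, params) > 0.0
        return jnp.where(crossed, lo, mid), jnp.where(crossed, mid, hi)

    lo, hi = jax.lax.fori_loop(
        0, n_iter, body, (jnp.float32(0.0), jnp.float32(t_max))
    )
    return 0.5 * (lo + hi)


def spike_time(params):
    # 1) Locate the root without tracking gradients through the search.
    t0 = jax.lax.stop_gradient(find_event_time(jax.lax.stop_gradient(params)))
    # 2) One Newton correction: the value stays (numerically) at t0, but its
    #    gradient w.r.t. params is -(dg/dp) / (dg/dt), as given by the IFT.
    g = event_condition(t0, params)
    dg_dt = jax.lax.stop_gradient(jax.grad(event_condition)(t0, params))
    return t0 - g / dg_dt


params = {
    "I": jnp.float32(1.5),      # constant input current (assumed value)
    "tau": jnp.float32(10.0),   # membrane time constant (assumed value)
    "theta": jnp.float32(1.0),  # firing threshold (assumed value)
}
t_star, grads = jax.value_and_grad(spike_time)(params)
```

For this particular neuron the spike time has the closed form $\tau \ln(I / (I - \theta))$, so `t_star` and `grads` can be checked against it. For neuron models without a closed form, only the numerical route sketched here remains.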