Efficient Event-based Delay Learning in Spiking Neural Networks

Balázs Mészáros, James C. Knight, Thomas Nowotny

TL;DR

The paper tackles the training of spiking neural networks (SNNs) with learnable synaptic delays, which extend their temporal memory beyond the limits set by synaptic and membrane time constants. It extends the exact-gradient EventProp framework to heterogeneous, trainable delays by deriving event-based gradients with respect to both weights and delays, using adjoint variables and a time-invariant variant of the loss. Implemented in mlGeNN, the approach is evaluated on a sequence detection task and on the Yin-Yang, SHD, SSC and Braille datasets. It achieves strong accuracy with far fewer parameters, uses less than half the memory of the previous state-of-the-art delay-learning method and runs up to 26x faster. The work shows that learnable delays can be integrated efficiently into recurrent SNNs, which is relevant both for temporally rich learning tasks and for deployment on neuromorphic hardware.

Abstract

Spiking Neural Networks (SNNs) compute using sparse communication and are attracting increased attention as a more energy-efficient alternative to traditional Artificial Neural Networks (ANNs). While standard ANNs are stateless, spiking neurons are stateful and hence intrinsically recurrent, making them well-suited for spatio-temporal tasks. However, the duration of this intrinsic memory is limited by synaptic and membrane time constants. Delays are a powerful additional mechanism and, in this paper, we propose a novel event-based training method for SNNs with delays, grounded in the EventProp formalism which enables the calculation of exact gradients with respect to weights and delays. Our method supports multiple spikes per neuron and, to the best of our knowledge, is the first delay learning algorithm to be applied to recurrent SNNs. We evaluate our method on a simple sequence detection task, as well as the Yin-Yang, Spiking Heidelberg Digits, Spiking Speech Commands and Braille letter reading datasets, demonstrating that our algorithm can optimise delays from suboptimal initial conditions and enhance classification accuracy compared to architectures without delays. We also find that recurrent delays are particularly beneficial in small networks. Finally, we show that our approach uses less than half the memory of the current state-of-the-art delay-learning method and is up to 26x faster.
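The point that intrinsic memory is bounded by the synaptic and membrane time constants, and that delays extend it, can be made concrete with a small numerical sketch. It assumes a single non-spiking leaky integrator driven by an exponentially decaying synaptic current, with time constants TAU_M = 10 ms and TAU_S = 5 ms chosen purely for illustration, and uses the closed-form postsynaptic response. The helpers `kernel` and `max_voltage` are hypothetical and are not part of mlGeNN. Two input spikes 40 ms apart barely interact, but delaying the first one by 40 ms makes both responses coincide and raises the peak voltage from about 0.26 to 0.5 (in units of the weight).

```python
import math

# Assumed membrane / synaptic time constants in ms (must differ for this formula).
TAU_M, TAU_S = 10.0, 5.0

def kernel(s):
    """Unit-weight membrane response (R = 1) to a spike that arrived s ms ago.

    Closed-form solution of TAU_M dV/dt = -V + I with I(s) = exp(-s / TAU_S).
    """
    if s < 0.0:
        return 0.0
    return TAU_S / (TAU_M - TAU_S) * (math.exp(-s / TAU_M) - math.exp(-s / TAU_S))

def max_voltage(spike_times, delays, weights, t_end=100.0, dt=0.01):
    """Peak of V(t) on a time grid for a non-spiking neuron (no threshold, no reset).

    Input spike i reaches the neuron at spike_times[i] + delays[i].
    """
    arrivals = [ts + d for ts, d in zip(spike_times, delays)]
    grid = (k * dt for k in range(int(round(t_end / dt)) + 1))
    return max(sum(w * kernel(t - a) for a, w in zip(arrivals, weights)) for t in grid)

# Two input spikes 40 ms apart, i.e. several time constants.
spikes, weights = [0.0, 40.0], [1.0, 1.0]
v_single = max_voltage([0.0], [0.0], [1.0])
v_no_delay = max_voltage(spikes, [0.0, 0.0], weights)
v_aligned = max_voltage(spikes, [40.0, 0.0], weights)  # delay the first spike by 40 ms
print(f"{v_single:.3f} {v_no_delay:.3f} {v_aligned:.3f}")  # 0.250 0.259 0.500
```

Because the neuron is linear below threshold, coincident arrivals simply add; the spiking, resetting hidden neurons used in the paper are not modelled here.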

Paper Structure

This paper contains 15 sections, 50 equations, 8 figures and 5 tables.

Figures (8)

  • Figure 1: Illustration of the original EventProp formalism without delays. In a minimal example, a network has input neurons, one hidden layer and an output layer. Input spikes cause instantaneous jumps in the hidden $I$ variable (blue lines), which drives the $V$ variable (orange lines). When $V$ reaches the firing threshold, it is reset and spikes are emitted, which cause instantaneous jumps in the $I$ variable of the output neurons. This forward pass is followed by a backward pass, where the "blame" of each weight for the eventual loss is calculated. The adjoint variables $\lambda_V$ and $\lambda_I$ are proxies of this blame. The calculation occurs backwards in time. Here we illustrate the use of a readout and loss that are based on the maximum voltage of the output neurons. Accordingly, the loss causes a jump in the output $\lambda_V$ variable at the time when the maximum output voltage occurred in the forward pass. This blame is then transported as jumps of $\lambda_V$ of hidden neurons at the times when hidden spikes had occurred. Gradient updates to the hidden-to-output weights occur at the same times. Note that the plots are for illustrative purposes only and not to scale, except that matching pairs of $\lambda_V$ and $\lambda_I$ are drawn at the same scale.
  • Figure 2: Illustration of the extended EventProp algorithm for SNNs with delays. In essence, the forward pass works in the same way as for networks without delays (Fig. 1), except that the jump in postsynaptic $I$ variables occurs with a delay $d$. In the backward pass, blame is still transported backwards at the original spike times, but using the values of the postsynaptic adjoint variables at the correspondingly delayed time (in forward time). Gradients with respect to weights and delays are accumulated at the same saved spike times and are based on the same delayed quantities.
  • Figure 3: Illustration of the learning updates in the sequence detection task. A) Voltage $V$ and current $I$ in the forward pass. B) Adjoint variables $\lambda_V$ and $\lambda_I$. The loss is injected into $\lambda_V$ at $t_{\text{max}}$, when the output voltage reached its maximum in the forward pass. This then propagates into $\lambda_I$ and eventually leads to delay updates at the saved spike times $t_1, t_2$. For neuron 3, the update to $d_{3,2}$ at $t_2$ is negative and the update to $d_{3,1}$ at $t_1$ is positive, meaning that (subject to delays remaining non-negative) the excitatory postsynaptic potential (EPSP) from neuron 2 is moved earlier and the EPSP from neuron 1 later, bringing them closer together. For neuron 4 the opposite holds: the updates move the EPSPs apart. This is exactly what is needed to increase the maximal voltage of output neuron 3 and decrease that of output neuron 4. When the other input class is active, all roles are inverted, again leading to the correct delay updates.
  • Figure 4: Left: The Yin-Yang (YY) dataset, with the temporal encoding of an example data point (highlighted by a blue dot). Right: We generate separate training, validation and test sets with 5000, 1000 and 1000 examples, respectively, and report the test performance of the model that performed best on the validation set. We compare feedforward networks with and without delays. The purple crosses show the results reported by [goltz2024delgrad]. The points show average accuracy and the error bars show the standard deviation over 8 runs; all individual results are also shown as smaller data points.
  • Figure 5: Left: An example of a speaker saying "five" from the Spiking Heidelberg Digits (SHD) dataset. Right: Because the SHD dataset has no validation set, we stop training early when training accuracy has not improved for 15 epochs and report the corresponding test accuracy. Ff denotes a model with 2 feedforward hidden layers and Rec a model with one recurrently connected hidden layer; we implemented models with 128, 256, 512 and 1024 hidden neurons. We also show state-of-the-art (SOTA) results by [baronig2024advancing] and previous EventProp results by [nowotny2022loss]. The triangles show results of other delay-learning methods that appear to have used the test set for validation [hammouamrilearning, sun2023learnable, deckers2024co]. The points show average accuracy and the error bars show the standard deviation over 8 runs; all individual results are also shown as smaller data points.
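The delay updates described for Figure 3 can be checked on a toy model. The sketch below assumes a single non-spiking output neuron whose readout is its maximum voltage, as in Figures 1 and 3, and computes the exact derivative of that maximum with respect to each delay in closed form. It does not implement the event-based adjoint (EventProp) computation from the paper; since EventProp yields exact gradients, it should give the same derivatives for this tiny case. All function names, time constants and the learning rate are hypothetical choices for illustration.

```python
import math

TAU_M, TAU_S = 10.0, 5.0  # assumed membrane / synaptic time constants (ms), TAU_M != TAU_S

def kernel(s):
    """Unit-weight membrane response to a spike that arrived s ms ago."""
    if s < 0.0:
        return 0.0
    return TAU_S / (TAU_M - TAU_S) * (math.exp(-s / TAU_M) - math.exp(-s / TAU_S))

def kernel_deriv(s):
    """dK/ds, taken as zero before the spike arrives."""
    if s < 0.0:
        return 0.0
    return TAU_S / (TAU_M - TAU_S) * (math.exp(-s / TAU_S) / TAU_S - math.exp(-s / TAU_M) / TAU_M)

def voltage(t, arrivals, weights):
    """Non-spiking readout voltage: linear sum of delayed postsynaptic responses."""
    return sum(w * kernel(t - a) for a, w in zip(arrivals, weights))

def v_max_and_delay_grads(spike_times, delays, weights, t_end=100.0, dt=0.01):
    """Return (t_max, V_max, [dV_max/dd_i]) for a non-spiking output neuron.

    At an interior maximum dV/dt = 0, so t_max can be held fixed when
    differentiating: dV_max/dd_i = -w_i * K'(t_max - t_i - d_i).
    """
    arrivals = [ts + d for ts, d in zip(spike_times, delays)]
    grid = (k * dt for k in range(int(round(t_end / dt)) + 1))
    t_max = max(grid, key=lambda t: voltage(t, arrivals, weights))
    grads = [-w * kernel_deriv(t_max - a) for a, w in zip(arrivals, weights)]
    return t_max, voltage(t_max, arrivals, weights), grads

# Two input spikes 10 ms apart onto one output neuron (cf. neuron 3 in Fig. 3).
spikes, weights = [0.0, 10.0], [1.0, 1.0]
delays = [0.0, 0.0]
_, v_start, g = v_max_and_delay_grads(spikes, delays, weights)
print(g)  # positive for the earlier input, negative for the later one

# Gradient ascent on V_max, clipping delays at zero so they stay non-negative.
lr = 100.0
for _ in range(30):
    _, _, g = v_max_and_delay_grads(spikes, delays, weights)
    delays = [max(0.0, d + lr * gi) for d, gi in zip(delays, g)]
_, v_end, _ = v_max_and_delay_grads(spikes, delays, weights)
print(delays, v_start, v_end)
```

The earlier input's delay grows while the later input's delay stays clipped at zero, so the two arrivals converge and the peak voltage rises. Reversing the sign of the update (descending on V_max, as wanted for neuron 4 in Fig. 3) instead pushes the arrival times apart.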