
Bullet Trains: Parallelizing Training of Temporally Precise Spiking Neural Networks

Todd Morrill, Christian Pehle, Anthony Zador

Abstract

Continuous-time, event-native spiking neural networks (SNNs) operate strictly on spike events, treating spike timing and ordering as the representation rather than an artifact of time discretization. This viewpoint aligns with biological computation and with the native resolution of event sensors and neuromorphic processors, while enabling compute and memory that scale with the number of events. However, two challenges hinder practical, end-to-end trainable event-based SNN systems: 1) exact charge-fire-reset dynamics impose inherently sequential processing of input spikes, and 2) precise spike times must be solved for without time bins. We address both. First, we use parallel associative scans to consume multiple input spikes at once, yielding up to 44x speedups over sequential simulation while retaining exact hard-reset dynamics. Second, we implement differentiable spike-time solvers that compute spike times to machine precision without discrete-time approximations or restrictive analytic assumptions. We demonstrate the viability of training SNNs using our solutions on four event-based datasets on GPUs.
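
To make the sequential bottleneck concrete, here is a minimal sketch of exact event-driven simulation for one current-based LIF neuron. It assumes the dynamics $dV/dt = -V/\tau_{\text{mem}} + I$ and $dI/dt = -I/\tau_{\text{syn}}$ with $\tau_{\text{mem}} \neq \tau_{\text{syn}}$, and that an input spike of weight $w$ adds $w$ to $I$; these dynamics, the time constants, and the names `propagator` and `simulate_sequential` are illustrative assumptions, not the paper's code. Between events the state $\mathbf{s} = (V, I)$ evolves in closed form, so consuming one input spike is an affine map $\mathbf{s}_k = M(\Delta_k)\,\mathbf{s}_{k-1} + \mathbf{b}_k$, yet the loop still visits spikes one at a time. Threshold crossing and hard reset are deliberately omitted.

```python
import math

# Hypothetical time constants in seconds (assumptions, not from the paper).
TAU_MEM = 20e-3  # membrane time constant
TAU_SYN = 5e-3   # synaptic time constant


def propagator(dt, tau_mem=TAU_MEM, tau_syn=TAU_SYN):
    """2x2 matrix M(dt) taking (V, I) at time t to (V, I) at t + dt with no input.

    Closed-form solution of dV/dt = -V/tau_mem + I, dI/dt = -I/tau_syn,
    valid only when tau_mem != tau_syn.
    """
    a, b = 1.0 / tau_mem, 1.0 / tau_syn
    ea, eb = math.exp(-a * dt), math.exp(-b * dt)
    return ((ea, (eb - ea) / (a - b)),
            (0.0, eb))


def apply_affine(M, bias, s):
    """Return M @ s + bias for 2-vectors stored as tuples."""
    return (M[0][0] * s[0] + M[0][1] * s[1] + bias[0],
            M[1][0] * s[0] + M[1][1] * s[1] + bias[1])


def simulate_sequential(spike_times, weights, s0=(0.0, 0.0)):
    """Consume sorted input spikes one at a time; returns the state just after each.

    Spike k is the affine map s_k = M(t_k - t_{k-1}) s_{k-1} + (0, w_k):
    exact decay since the previous event, then a jump of w_k in the current.
    No threshold or reset is applied in this sketch.
    """
    states, s, t_prev = [], s0, 0.0
    for t, w in zip(spike_times, weights):
        s = apply_affine(propagator(t - t_prev), (0.0, w), s)
        states.append(s)
        t_prev = t
    return states
```

Because every step has the form $M\mathbf{s} + \mathbf{b}$, a composition of steps is again affine, which is the property an associative scan can exploit.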

Paper Structure (39 sections, 1 theorem, 37 equations, 13 figures, 5 tables, 1 algorithm)

Key Result

Lemma 2.1

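Let $\mathbf{s}_1 = M_1 \mathbf{s}_0 + \mathbf{b}_1$ and $\mathbf{s}_2 = M_2 \mathbf{s}_1 + \mathbf{b}_2$. Then the operator Combine defined as $\textit{Combine}((M_2, \mathbf{b}_2), (M_1, \mathbf{b}_1)) = (M_2 M_1, M_2 \mathbf{b}_1 + \mathbf{b}_2)$ is associative, i.e., $\textit{Combine}(\textit{Combine}((M_3, \mathbf{b}_3), (M_2, \mathbf{b}_2)), (M_1, \mathbf{b}_1)) = \textit{Combine}((M_3, \mathbf{b}_3), \textit{Combine}((M_2, \mathbf{b}_2), (M_1, \mathbf{b}_1)))$.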
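
Lemma 2.1 is what allows the per-spike affine maps to be combined by a parallel prefix scan. The sketch below, in plain Python, implements Combine as stated in the lemma, a sequential reference fold, and a Hillis-Steele log-depth inclusive scan. The choice of Hillis-Steele, the helper names (`scan_log_depth`, `random_affine`, `affine_close`, `apply_affine`), and the random 2x2 test maps are our assumptions, not the paper's GPU implementation. Each round's list comprehension is independent across indices, so on parallel hardware the scan needs about $\log_2 n$ rounds instead of $n - 1$ sequential steps; the result is correct only because Combine is associative.

```python
import math
import random


def combine(later, earlier):
    """Combine((M2, b2), (M1, b1)) = (M2 M1, M2 b1 + b2), as in Lemma 2.1."""
    (M2, b2), (M1, b1) = later, earlier
    n = len(M2)
    M = [[sum(M2[i][k] * M1[k][j] for k in range(n)) for j in range(n)] for i in range(n)]
    b = [sum(M2[i][k] * b1[k] for k in range(n)) + b2[i] for i in range(n)]
    return M, b


def scan_sequential(elems):
    """Reference fold: prefix[k] = Combine(elems[k], prefix[k-1])."""
    out = [elems[0]]
    for e in elems[1:]:
        out.append(combine(e, out[-1]))
    return out


def scan_log_depth(elems):
    """Hillis-Steele inclusive scan; each round reads only the previous round."""
    out = list(elems)
    offset = 1
    while offset < len(out):
        out = [combine(out[i], out[i - offset]) if i >= offset else out[i]
               for i in range(len(out))]
        offset *= 2
    return out


def apply_affine(elem, s):
    """Return M @ s + b for elem = (M, b)."""
    M, b = elem
    return [sum(M[i][k] * s[k] for k in range(len(s))) + b[i] for i in range(len(s))]


def random_affine(rng, n=2):
    """Hypothetical test helper: random n x n matrix and n-vector in [-1, 1]."""
    M = [[rng.uniform(-1.0, 1.0) for _ in range(n)] for _ in range(n)]
    b = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return M, b


def affine_close(x, y, tol=1e-9):
    """Elementwise closeness of two (M, b) pairs."""
    (Mx, bx), (My, by) = x, y
    close = lambda p, q: math.isclose(p, q, rel_tol=tol, abs_tol=tol)
    return (all(close(p, q) for rx, ry in zip(Mx, My) for p, q in zip(rx, ry))
            and all(close(p, q) for p, q in zip(bx, by)))


# Usage: every prefix applied to s0 reproduces the sequential trajectory.
rng = random.Random(0)
elems = [random_affine(rng) for _ in range(11)]
prefixes = scan_log_depth(elems)
s0 = [0.3, -0.2]
states, s = [], s0
for e in elems:
    s = apply_affine(e, s)
    states.append(s)
```

Applying prefix $k$ to $\mathbf{s}_0$ yields $\mathbf{s}_k$ for all $k$ at once. Hillis-Steele does $O(n \log n)$ total work; work-efficient (Blelloch-style) scans trade extra rounds for $O(n)$ work.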

Figures (13)

  • Figure 1: Our parallel method achieves up to 44x speedups over serial processing of spike events on several event-based datasets while retaining exact hard-reset dynamics. The plot shows results for the Spiking Heidelberg Digits (SHD) dataset.
  • Figure 2: Our contributions visualized: the state trajectory of a LIF neuron consuming input spikes in parallel and producing an output spike with our Newton-Raphson solver. Panel A shows the neuron jumping from the initial state $\mathbf{s}_0$ at time 0 (composed of membrane potential and synaptic current) directly to all future states $\mathbf{s}_1 = f_1(\mathbf{s}_0), \mathbf{s}_2 = f_2(\mathbf{s}_0), \ldots$ in parallel using associative scans. We use lightweight analytical checks to determine whether an output spike will occur in the interval between any consecutive pair of input spikes by computing, in parallel, $t_{V_{\max}}(\mathbf{s})$, the time of maximum voltage starting from the interval's left endpoint, and $V(t_{V_{\max}})$, the voltage at that time (i.e., the maximum voltage). The voltage is hard reset at the spike time. Panel B shows our iterative Newton-Raphson solver finding the output spike time $t^{\star}$ to machine precision.
  • Figure 3: Discrete-time binning causes quantization errors and loses spike ordering. Here, two spikes occurring at different times (0.0021 s and 0.0027 s) both round up to 0.003 s, making them indistinguishable to downstream neurons.
  • Figure 4: An example input queue feeding a neuron and producing an output queue. The input chunk size is 4 spikes, which means 4 input spikes are consumed in parallel. After consuming 3 input spikes in chunk 1, the neuron produces an output spike. The work done to consume the 4th input spike in chunk 1 is discarded, and the next chunk resumes processing input spikes immediately after the output spike time.
  • Figure 5: Spike time as a function of synaptic weight for different time-constant configurations. The discrete-time approximation with $\Delta t = 1\,\text{ms}$ diverges from the ground-truth spike time at nearly all weight values, while our method closely tracks the ground-truth spike time.
  • ...and 8 more figures
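
The analytical check and spike-time solve described for Figure 2 can be sketched for a single inter-spike interval $[0, T]$, assuming current-based LIF dynamics $dV/dt = -V/\tau_{\text{mem}} + I$, $dI/dt = -I/\tau_{\text{syn}}$ with $\tau_{\text{mem}} \neq \tau_{\text{syn}}$, threshold $\theta = 1$ in arbitrary units, and $V(0) < \theta$. Setting $dV/dt = 0$ gives at most one critical time, so the maximum over $[0, T]$ lies at $0$, $T$, or that time; if it reaches $\theta$, the first crossing lies in $[0, t_{V_{\max}}]$. We then run Newton-Raphson on $V(t) - \theta$ inside that bracket, falling back to bisection when a step leaves it. The bisection safeguard, parameter values, and function names are our assumptions for a robust sketch; the paper's solver may differ.

```python
import math
import sys

# Hypothetical parameters (assumptions, not from the paper).
TAU_MEM, TAU_SYN, THETA = 20e-3, 5e-3, 1.0
A, B = 1.0 / TAU_MEM, 1.0 / TAU_SYN  # decay rates; must differ
EPS = sys.float_info.epsilon


def voltage(t, v0, i0):
    """V(t) from state (v0, i0) at t = 0 with no further input."""
    ea, eb = math.exp(-A * t), math.exp(-B * t)
    return v0 * ea + i0 * (eb - ea) / (A - B)


def dvoltage(t, v0, i0):
    """Time derivative of voltage() at t."""
    ea, eb = math.exp(-A * t), math.exp(-B * t)
    return -A * v0 * ea + i0 * (A * ea - B * eb) / (A - B)


def voltage_max(v0, i0, T):
    """Analytical check: (t_Vmax, V(t_Vmax)) over [0, T].

    dV/dt = 0 reduces to exp((B - A) t) = B c / (A (c - v0)) with
    c = i0 / (A - B), which has at most one solution.
    """
    candidates = [0.0, T]
    c = i0 / (A - B)
    denom = A * (c - v0)
    if denom != 0.0:
        ratio = B * c / denom
        if ratio > 0.0:
            t_crit = math.log(ratio) / (B - A)
            if 0.0 < t_crit < T:
                candidates.append(t_crit)
    t_best = max(candidates, key=lambda t: voltage(t, v0, i0))
    return t_best, voltage(t_best, v0, i0)


def spike_time(v0, i0, T, max_iter=100):
    """First t in [0, T] with V(t) = THETA, or None if V stays below THETA."""
    t_peak, v_peak = voltage_max(v0, i0, T)
    if v_peak < THETA:
        return None  # cheap rejection: no output spike in this interval
    lo, hi = 0.0, t_peak
    t = 0.5 * (lo + hi)
    for _ in range(max_iter):
        f = voltage(t, v0, i0) - THETA
        if f == 0.0:
            return t
        if f < 0.0:
            lo = t  # root lies to the right
        else:
            hi = t  # root lies to the left
        df = dvoltage(t, v0, i0)
        t_new = t - f / df if df != 0.0 else 0.5 * (lo + hi)
        if not lo < t_new < hi:
            t_new = 0.5 * (lo + hi)  # bisection fallback
        if abs(t_new - t) <= 2.0 * EPS * abs(t_new):
            return t_new  # step below float resolution
        t = t_new
    return t
```

For example, with these parameters and $V(0) = 0$, the peak occurs at $\ln 4 / 150 \approx 9.24$ ms; `spike_time(0.0, 500.0, 0.05)` returns a crossing before that peak, while `spike_time(0.0, 100.0, 0.05)` returns `None` because the peak voltage (about 0.31) stays below threshold.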

Theorems & Definitions (2)

  • Lemma 2.1
  • Proof of Lemma 2.1