
Smooth Exact Gradient Descent Learning in Spiking Neural Networks

Christian Klos, Raoul-Martin Memmesheimer

TL;DR

This work advances gradient-based learning for spiking neural networks by introducing exact gradient descent that operates on continuously evolving spiking dynamics. It introduces pseudospikes to propagate learning signals beyond trial ends and develops two pseudospike schemes for quadratic integrate-and-fire (QIF) neurons and, informally, for comparable leaky integrate-and-fire (LIF) neurons, enabling smooth spike-time optimization as well as spike addition and removal. The analysis proves that spike times depend smoothly on initial conditions, input weights, and input spike times, and that gradient continuity is maintained even as spike orders change. The approach is demonstrated on recurrent and deep networks, including MNIST-style tasks, showing precise control over spike timings and promising applications for neuromorphic training with exact gradients.

Abstract

Gradient descent prevails in artificial neural network training, but seems inept for spiking neural networks as small parameter changes can cause sudden, disruptive (dis-)appearances of spikes. Here, we demonstrate exact gradient descent based on continuously changing spiking dynamics. These are generated by neuron models whose spikes vanish and appear at the end of a trial, where they cannot influence subsequent dynamics. This also enables gradient-based spike addition and removal. We illustrate our scheme with various tasks and setups, including recurrent and deep, initially silent networks.
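The central claim above, that spikes appear at the end of a trial and then move smoothly into it as a parameter changes, can be illustrated numerically. The following is a minimal sketch, not the authors' code: it assumes a dimensionless QIF neuron $dV/dt = V^2 + I(t)$ with an input current $I(t) = w e^{-t}$ arriving at $t=0$ (all time constants set to 1), integrated in the phase representation $V = \tan(\theta/2)$, in which a spike is simply $\theta$ crossing $\pi$. The helper names `theta_rhs` and `spike_times`, the initial voltage $V(0) = -1$ and the step size are hypothetical choices; the paper's exact model and parameter values (cf. Figure S3) are not reproduced.

```python
import math


def theta_rhs(theta, t, w):
    """Phase velocity of the QIF neuron dV/dt = V**2 + I(t) with V = tan(theta / 2).

    Hypothetical setup: one input at t = 0 creates the current I(t) = w * exp(-t),
    with all time constants set to 1.
    """
    current = w * math.exp(-t)
    return (1.0 - math.cos(theta)) + (1.0 + math.cos(theta)) * current


def spike_times(w, T=10.0, v0=-1.0, dt=1e-3):
    """Spike times in [0, T]; a spike is a crossing of theta = pi + 2*pi*k."""
    theta = 2.0 * math.atan(v0)  # phase corresponding to the initial voltage v0
    t = 0.0
    next_spike_phase = math.pi
    times = []
    for _ in range(int(round(T / dt))):
        # Classical fourth-order Runge-Kutta step for theta.
        k1 = theta_rhs(theta, t, w)
        k2 = theta_rhs(theta + 0.5 * dt * k1, t + 0.5 * dt, w)
        k3 = theta_rhs(theta + 0.5 * dt * k2, t + 0.5 * dt, w)
        k4 = theta_rhs(theta + dt * k3, t + dt, w)
        new_theta = theta + dt * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0
        # At theta = pi + 2*pi*k the phase velocity is 2 > 0, so spikes are
        # always upward crossings; locate them by linear interpolation.
        while new_theta >= next_spike_phase:
            frac = (next_spike_phase - theta) / (new_theta - theta)
            times.append(t + frac * dt)
            next_spike_phase += 2.0 * math.pi
        theta, t = new_theta, t + dt
    return times


if __name__ == "__main__":
    for w in (0.0, 10.0, 20.0):
        print(f"w = {w:4.1f}: spike times {spike_times(w)}")
```

With $w = 0$ the voltage follows $V(t) = -1/(1+t)$ and stays negative, so `spike_times(0.0)` returns an empty list; larger $w$ makes a first spike enter at the trial end and shift to earlier times. Working with $\theta$ instead of $V$ avoids integrating through the divergence of $V$ at a spike, since $d\theta/dt = 2$ at $\theta = \pi$.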

Paper Structure

This paper contains 25 sections, 66 equations and 5 figures.

Figures (5)

  • Figure S1: First type of pseudodynamics and pseudospikes. Panels (a-c) show simulation results for a single neuron (neuron 1) that receives a single input spike; panels (d-f) show a second neuron (neuron 2) that is connected to neuron 1 and receives its output spikes. (a) Schematics of neuron 1, highlighting that it has a single input connection with weight $w$ and a single output connection. (b) Ordinary and pseudodynamics of neuron 1 for two different weight values. Left, blue traces: $w=5w_{\min}$, where $w_{\min}$ is the weight at which an ordinary spike appears at infinity (cf. main text Fig. 1). During the ordinary dynamics (white background), the input current (lower panel) due to the input spike is strong enough to induce an ordinary spike (upper panel, light blue vertical tick). Right, orange traces: $w=0$; the neuron does not generate ordinary spikes. During the pseudodynamics (gray background), the input current is set to a constant, suprathreshold value. This value depends on the input current at the trial end, $I_1(T)$, and is therefore different for the blue and orange traces. The pseudodynamics start at $V_1(T)$ and generate pseudospikes. They continue until the desired number of spikes is generated, which we here assume to be three. (c) Times of the three spikes of neuron 1 and their derivatives with respect to the input weight, as functions of the input weight. The spike times and their derivatives are continuous, which means that gradient descent can be used to smoothly shift spike times into the trial. Vertical lines correspond to similarly colored examples shown in (b). (d) Schematics of neuron 2, highlighting that it receives a connection with weight $w_{21}$ from neuron 1. (e) Same as (b) but for neuron 2. The value of the input drive during the pseudodynamics depends on $I_2(T)$ and on the first pseudospike time of the presynaptic neuron 1 (dependency highlighted in red); it is therefore different for the blue and orange traces. (f) Same as (c) but for neuron 2. The spike times are continuous and mostly smooth. Discontinuities of the derivatives of pseudospike times (insets) appear when a spike of the presynaptic neuron 1 crosses the trial end $T$. Gradient descent can be used to shift spike times into the trial even if neither neuron 2 nor neuron 1 spikes during the trial.
  • Figure S2: Second type of pseudodynamics and pseudospikes. The figure shows the results of simulations in a basic two-layer network with two hidden neurons and one output neuron. One input at the beginning of the trial inhibits hidden neuron $2$; a second input, slightly later, excites both hidden neurons with weight $w$. Hidden neuron $1$ excites the output neuron, hidden neuron $2$ inhibits it. (a) Voltage traces of the output and the two hidden neurons for increasing $w$, plotted with increasing color intensity. The pseudodynamics with $d=2$ take place within $(T_1,T_1+1/d]$ and $(T_2,T_2+1/d]$ in the hidden and the output neurons, respectively. Solid, dashed and dash-dotted vertical gray lines indicate $T_1$, $T_1+1/d=T_2$ and $T_2+1/d$, respectively. (b) Spike times as a function of $w$ (blue, orange: first, second spike time of the different neurons). For increasing $w$ there are transitions from an active pseudospike to an ordinary spike and simultaneously from an inactive to an active pseudospike, first in hidden neuron $1$, then in hidden neuron $2$. The insets show closeups of the curves around the corresponding weight values ($w\approx 2.47, 3.43$, solid gray vertical lines; spike time axis magnifications differ). The spiking of the hidden neurons and its temporal change trigger similar transitions in the output neuron. Dotted and solid vertical lines indicate weight values of traces displayed in (a). (c) Like (b), but for the gradient of the spike times with respect to $w$. The curves in (b,c) are continuous because the spike times are smooth in $w$. This holds in particular at the transitions between inactive and active pseudospikes and between active pseudospikes and ordinary spikes.
  • Figure S3: Spike times of a QIF with a single exponentially decaying input arriving at $t=0$. (a) The output spike times $t_\text{sp}$ of the QIF form continuous curves without kinks in $(w,t)$-space (blue, red, purple: first, second, third output spike time), which start at $T$ or at $w_\text{min}$ and end at $w_\text{max}$ ($T=10$, i.e. ten times the membrane time constant, $w_\text{min}=-8.5, w_\text{max}=60$). They are the graphs of smooth functions $t_\text{sp}(w)$. (b) Derivative of the output spike times with respect to $w$ (blue, red, purple: derivative of first, second, third output spike time). $\frac{\partial t_\text{sp}}{\partial w}$ is continuous. All derivative graphs start at finite values of $\partial t_\text{sp}/\partial w$, since the trial duration $T$ is finite. Starting points with $w>w_\text{min}$ correspond to points where $t_\text{sp}(w)$ starts to fall below $T$. Near these points, the derivatives assume large negative values. (c) Example traces $\phi(t)$ for different values of $w$ (from left to right: $w=-5, 16.3, 16.7, 57$, highlighted by light gray vertical dotted lines in (a)) show first one and then a second and third spike. Spikes appear at the end of the trial and then shift to earlier times with increasing $w$.
  • Figure S4: Change of output spike times when an input time changes. (a) The output spike times $t_\text{sp}$ (blue, red: first, second output spike time) are smooth functions of the input spike time $t_i$. There are no jumps or kinks in the graphs, even when $t_i$ crosses other input spike times $t_j$ (gray dashed vertical lines: $t_i=t_j$), when it crosses output spike times (orange diagonal: $t_\text{sp}=t_i$, gray circles: crossing points of actual output spike times with $t_i$) or when the output spike times cross other input spike times $t_j$ (gray dashed horizontal lines: $t_\text{sp}=t_j$, partially crossed by blue curve). (b) The derivative $\partial t_\text{sp}/\partial t_i$ confirms the smoothness of the function $t_\text{sp}(t_i)$: it is continuous also at points where $t_i$ crosses other $t_j$ (gray dashed vertical lines) or where it agrees with actual output spike times (gray vertical lines). Inset: magnification of the range where derivatives are small, highlighting in particular the zero derivative when $t_i$ is larger than $t_\text{sp}$. The curves start and end at $t_i$ values where $t_\text{sp}(t_i)$ enters or exits the trial. (c) Example traces of $\phi(t)$ (upper panels) and $I(t)$ (lower panels) at different salient $t_i$ values (highlighted by light gray dotted lines in (a); gray dashed vertical lines in (c): $t_j$, orange vertical line: $t_i$): at the crossing of $t_i$ and a $t_j$ (trace one: $t_i=1$), shortly before and after a fast change in the first spike time preceding the entry of the second spike time into the trial (traces two and three: $t_i=2.1, 2.22$) and close to the crossing of the second output spike time and $t_i$ (last trace: $t_i=6.92$).
  • Figure S5: Change of output spike times when the strength of one of multiple inputs changes. (a) The output spike times $t_\text{sp}$ (blue, red: first, second output spike time) are smooth functions of the strength $w_i$ of the input arriving at $t_i$ (orange horizontal line: $t_\text{sp}=t_i$). There are no jumps or kinks in the graphs, even when $t_\text{sp}$ crosses input spike times (gray dashed horizontal lines: $t_\text{sp}=t_j$, partially crossed by blue curve). (b) The derivative $\partial t_\text{sp}/\partial w_i$ confirms this smoothness. It is continuous also at values of $w_i$ where the output spike times cross input spike times (gray vertical lines; inset: magnification of the region around the crossing with smallest $w_i$). (c) Example traces of $\phi(t)$ at $w_i$ values around the fast change and the first crossing of the first $t_\text{sp}$ with a $t_j$ ($w_i=-3.2,-3.041,-2.957,-2.7$, highlighted by light gray dotted lines in (a); gray dashed vertical lines: $t_j$, orange vertical line: $t_i$).
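Figure S1 describes the first type of pseudodynamics: after the trial end $T$, the input is replaced by a constant suprathreshold drive, and the neuron continues from $V(T)$ until it has emitted the desired number of spikes. For a QIF neuron this step has a closed form. The sketch below assumes the dimensionless QIF equation $dV/dt = V^2 + I$ with a constant drive $I > 0$ during the pseudodynamics. The rule by which the paper derives that drive from $I(T)$ (and, for neuron 2, from presynaptic pseudospike times) is not reproduced, so `i_pseudo` is a hypothetical free parameter, and `pseudospike_times` and `d_first_pseudospike_d_vend` are illustrative names. With $s = \sqrt{I}$, the solution $V(t) = s \tan(s(t-T) + \arctan(V(T)/s))$ diverges whenever the tangent's argument reaches $\pi/2 + k\pi$.

```python
import math


def pseudospike_times(v_end, i_pseudo, T, n_spikes):
    """First n_spikes pseudospike times after the trial end T (hypothetical helper).

    Assumes QIF pseudodynamics dV/dt = V**2 + i_pseudo with a constant drive
    i_pseudo > 0, starting from V(T) = v_end. The solution
    V(t) = s * tan(s * (t - T) + atan(v_end / s)), s = sqrt(i_pseudo),
    diverges (spikes) whenever the tangent's argument reaches pi/2 + k*pi.
    """
    if i_pseudo <= 0:
        raise ValueError("pseudodynamics need a suprathreshold (positive) drive")
    s = math.sqrt(i_pseudo)
    first = T + (math.pi / 2 - math.atan(v_end / s)) / s
    period = math.pi / s  # time from reset at -inf to the next divergence
    return [first + k * period for k in range(n_spikes)]


def d_first_pseudospike_d_vend(v_end, i_pseudo):
    """Analytic derivative of the first pseudospike time with respect to V(T)."""
    return -1.0 / (i_pseudo + v_end ** 2)


if __name__ == "__main__":
    # T = 10, V(T) = 0, drive 1: prints the values T + pi/2, T + 3*pi/2, T + 5*pi/2
    print(pseudospike_times(0.0, 1.0, 10.0, 3))
```

In this sketch the first pseudospike time $T + (\pi/2 - \arctan(V(T)/s))/s$ is a smooth function of $V(T)$ and $I$, with derivative $-1/(I + V(T)^2)$ with respect to $V(T)$, so a neuron that stays silent during the trial still has well-defined, differentiable spike times.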