
SpikingSSMs: Learning Long Sequences with Sparse and Parallel Spiking State Space Models

Shuaijie Shen, Chao Wang, Renzhuo Huang, Yan Zhong, Qinghai Guo, Zhichao Lu, Jianguo Zhang, Luziwei Leng

TL;DR

The paper introduces SpikingSSMs, a framework that combines leaky integrate-and-fire (LIF) spiking neurons with state space models (SSMs) to learn long-range dependencies with sparse, parallelizable computation. A lightweight Surrogate Dynamic Network (SDN) is proposed to approximate the neuron dynamics. It enables parallel training with orders-of-magnitude speedups and adds no extra parameters at inference. Learnable thresholds further improve sparsity and accuracy by exploiting their equivalence with input scaling. Empirically, SpikingSSMs reach competitive results on long-sequence benchmarks (LRA) and achieve state-of-the-art or near state-of-the-art performance on WikiText-103 with far fewer parameters and high sparsity. This indicates strong potential as energy-efficient backbones for large-scale language modeling.
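The equivalence between a learnable threshold and input scaling can be checked with a short sketch. It assumes a discrete-time LIF neuron u_t = τ·u_{t-1} + x_t with hard reset to 0 and a threshold θ > 0. The function name lif_hard_reset, τ = 0.5, and the random inputs are illustrative choices, not the paper's code. Under these assumptions, dividing every input by θ divides the whole membrane trajectory by θ. Firing at threshold θ on x is therefore the same as firing at threshold 1 on x/θ.

```python
import random


def lif_hard_reset(inputs, tau=0.5, threshold=1.0):
    """Discrete-time LIF neuron with hard reset to 0 (assumed form).

    Returns a list of 0/1 spikes, one per time step.
    """
    u = 0.0
    spikes = []
    for x_t in inputs:
        u = tau * u + x_t          # leaky integration of the input current
        if u >= threshold:         # fire once the potential reaches the threshold
            spikes.append(1)
            u = 0.0                # hard reset to the reset potential 0
        else:
            spikes.append(0)
    return spikes


random.seed(0)
xs = [random.uniform(-0.5, 1.5) for _ in range(64)]

theta = 2.0  # stands in for a learned threshold (must be > 0)
spikes_theta = lif_hard_reset(xs, threshold=theta)
spikes_scaled = lif_hard_reset([x / theta for x in xs], threshold=1.0)
print(spikes_theta == spikes_scaled)  # True
```

Because θ = 2 is a power of two, the two runs agree bit for bit. For other thresholds they agree up to floating-point rounding at the threshold boundary. Under these assumptions, a learnable threshold therefore behaves like a learnable scale on the neuron's input.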

Abstract

Known for their low energy consumption, spiking neural networks (SNNs) have gained considerable attention over the past decades. While SNNs are increasingly competitive with artificial neural networks (ANNs) on vision tasks, they are rarely used for long sequence tasks, despite their intrinsic temporal dynamics. In this work, we develop spiking state space models (SpikingSSMs) for long sequence learning by leveraging the sequence learning abilities of state space models (SSMs). Inspired by dendritic neuron structure, we hierarchically integrate neuronal dynamics with the original SSM block while realizing sparse synaptic computation. Furthermore, to resolve the conflict between event-driven neuronal dynamics and parallel computing, we propose a lightweight surrogate dynamic network. This network accurately predicts the after-reset membrane potential and is compatible with learnable thresholds. As a result, it accelerates training by orders of magnitude compared with conventional iterative methods. On the Long Range Arena benchmark, SpikingSSM achieves performance competitive with state-of-the-art SSMs while reaching an average network sparsity of 90%. On language modeling, our network significantly surpasses existing spiking large language models (spikingLLMs) on the WikiText-103 dataset with only a third of the model size. This demonstrates its potential as a backbone architecture for low-computation-cost LLMs.
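A small sketch shows the conflict between event-driven dynamics and parallel computing. It assumes the same discrete-time LIF form u_t = τ·u_{t-1} + x_t with hard reset to 0, with τ = 0.5 and threshold 1. Both values and the function names are hypothetical, chosen for illustration. Without reset, the potential is a linear filter of the inputs, so every time step is an independent closed-form sum that could be computed in parallel, for example as a convolution. With reset, each step depends on whether earlier steps spiked, so the loop must run sequentially. This is the dependency the SDN is trained to approximate. The SDN itself is not reproduced here.

```python
import random


def potential_no_reset_parallel(inputs, tau=0.5):
    """u_t = sum_{k<=t} tau**(t-k) * x_k, computed independently for every t.

    No state is carried between iterations of the outer comprehension,
    so all time steps could be evaluated in parallel.
    """
    return [sum(tau ** (t - k) * inputs[k] for k in range(t + 1))
            for t in range(len(inputs))]


def potential_iterative(inputs, tau=0.5, threshold=float("inf")):
    """Step-by-step potential; a finite threshold enables hard reset to 0.

    Returns, for each step, the potential that is compared with the threshold.
    """
    u, out = 0.0, []
    for x_t in inputs:
        u = tau * u + x_t
        out.append(u)
        if u >= threshold:
            u = 0.0  # reset: later steps now depend on this spike
    return out


random.seed(1)
xs = [random.uniform(0.0, 1.0) for _ in range(100)]
free_par = potential_no_reset_parallel(xs)            # parallel, no reset
free_seq = potential_iterative(xs)                    # sequential, no reset
true_seq = potential_iterative(xs, threshold=1.0)     # sequential, with reset
print(all(abs(a - b) < 1e-9 for a, b in zip(free_par, free_seq)))  # True
print(sum(u >= 1.0 for u in free_seq), sum(u >= 1.0 for u in true_seq))
```

The second line prints the spike counts without and with reset. Reset can only lower the potential, so the no-reset count is never smaller. This matches the qualitative gap between the top and middle panels of Figure 2.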

Paper Structure

This paper contains 40 sections, 27 equations, 7 figures, and 12 tables.

Figures (7)

  • Figure 1: Architecture of SpikingSSM. (a) Forward computation graph of SpikingSSM in one layer. Operation $r$ denotes the reset mechanism. The learnable parameter $\theta$ denotes parameters that influence the spiking function $f$, such as the threshold. (b) Comparison of different SSMs. The original SSM outputs floating-point numbers. SpikingSSM replaces the non-linear function of the original SSM with an LIF neuron, adding neuronal dynamics at a higher level of the hierarchy. SAF denotes the spiking activation function. The left panel shows the computation stage of the different variables and their corresponding dimensions, with $D, N, L$ denoting the model dimension, the hidden dimension of the SSM, and the sequence length, respectively.
  • Figure 2: Comparison of membrane potential samples produced by different methods under the same input. The membrane potential predicted by the SDN (bottom) accurately approximates the ground truth produced by the spiking neuron (middle). Without reset, the membrane potential produces significantly more spikes (top). The two black dashed lines denote the reset potential and the spiking threshold, which are set to 0 and 1, respectively. Red points mark the moments when spikes are generated, i.e., when the membrane potential surpasses the threshold. Note that for the spiking neuron, the membrane potential is reset to 0 immediately once it surpasses the threshold.
  • Figure 3: Training of the SDN. The MSE loss and spiking accuracy on the test set are plotted. Note that the SDN already achieves sufficiently high accuracy after the first training epoch.
  • Figure 4: Spiking rate across all layers of SpikingSSMs on the sCIFAR10 and the WikiText-103 datasets.
  • Figure 5: The forward and backward computational graphs of three training methods. The black lines in forward graphs and the red lines in backward graphs denote data flows; the gray lines are not data flows, although they have corresponding parts in the forward graphs.
  • ...and 2 more figures
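The data flow in Figure 1(b) can be sketched as a toy layer. Binary input spikes for each of the D channels pass through an N-dimensional SSM that outputs floating-point values. An LIF neuron then replaces the usual nonlinearity and turns those values back into spikes. The sketch relies on several simplifications. It uses a real-valued diagonal SSM discretized with zero-order hold, a hard-reset LIF neuron with τ = 0.5 and threshold 1, and random parameters. It omits any channel-mixing layer. None of these choices is taken from the paper, and all function names are hypothetical.

```python
import math
import random


def diagonal_ssm(u, A, B, C, dt):
    """Toy diagonal SSM for one channel, zero-order-hold discretized.

    u: list of L inputs; A, B, C: lists of N values (A < 0 for stability).
    Returns a list of L floating-point outputs.
    """
    A_bar = [math.exp(dt * a) for a in A]                          # exp(dt * A)
    B_bar = [(ab - 1.0) / a * b for ab, a, b in zip(A_bar, A, B)]  # ZOH input term
    state = [0.0] * len(A)
    y = []
    for u_t in u:
        state = [ab * s + bb * u_t for ab, s, bb in zip(A_bar, state, B_bar)]
        y.append(sum(c * s for c, s in zip(C, state)))             # project state to output
    return y


def lif_spikes(currents, tau=0.5, threshold=1.0):
    """Hard-reset LIF neuron applied in place of the SSM's nonlinearity."""
    v, spikes = 0.0, []
    for i_t in currents:
        v = tau * v + i_t
        fired = v >= threshold
        spikes.append(int(fired))
        if fired:
            v = 0.0
    return spikes


def spiking_ssm_layer(x, params, dt=0.1):
    """x: D channels of L binary spikes -> D channels of L output spikes."""
    return [lif_spikes(diagonal_ssm(x_d, A, B, C, dt))
            for x_d, (A, B, C) in zip(x, params)]


random.seed(0)
D, N, L = 4, 8, 32
x = [[int(random.random() < 0.2) for _ in range(L)] for _ in range(D)]
params = [([-random.uniform(0.5, 1.5) for _ in range(N)],
           [random.gauss(0, 1) for _ in range(N)],
           [random.gauss(0, 1) for _ in range(N)]) for _ in range(D)]
out = spiking_ssm_layer(x, params)
rate = sum(map(sum, out)) / (D * L)  # fraction of (channel, step) pairs that spiked
print(len(out), len(out[0]))  # 4 32
print(f"sparsity: {1 - rate:.2f}")
```

Since both the layer's input and its output are binary spikes, the next layer only has to accumulate contributions at time steps where a spike occurred. This is the source of the sparse synaptic computation the paper refers to. The sparsity printed here comes from random toy parameters and says nothing about the rates reported in Figure 4.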