High-Performance Temporal Reversible Spiking Neural Networks with $O(L)$ Training Memory and $O(1)$ Inference Cost

JiaKui Hu, Man Yao, Xuerui Qiu, Yuhong Chou, Yuxuan Cai, Ning Qiao, Yonghong Tian, Bo XU, Guoqi Li

TL;DR

The paper tackles the memory and energy bottlenecks of multi-timestep spiking neural networks by introducing Temporal Reversible SNNs (T-RevSNN). By turning off temporal dynamics for most neurons and enabling reversible temporal transfer only at key spike layers, T-RevSNN achieves $O(L)$ training memory and $O(1)$ inference cost, while maintaining strong accuracy on ImageNet-1k and neuromorphic datasets. The approach combines multi-level temporal-reversible forward information transfer, input encoding grouping, and ConvNeXt-style SNN blocks with a ReZero-enhanced residual design. Compared with state-of-the-art CNN-based SNNs and Transformer-based baselines, T-RevSNN offers significant improvements in training memory (up to $8.6 \times$), training time (up to $2.0 \times$), and inference energy (up to $1.6 \times$), making large-scale, energy-efficient SNNs more practical.

Abstract

Multi-timestep simulation of brain-inspired Spiking Neural Networks (SNNs) boosts memory requirements during training and increases inference energy cost. Current training methods cannot solve the training and inference dilemmas simultaneously. This work proposes a novel Temporal Reversible architecture for SNNs (T-RevSNN) that jointly addresses the training and inference challenges by altering the forward propagation of SNNs. We turn off the temporal dynamics of most spiking neurons and design multi-level temporal reversible interactions at the temporal turn-on spiking neurons, resulting in $O(L)$ training memory. Building on this temporal reversibility, we redesign the input encoding and network organization of SNNs to achieve $O(1)$ inference energy cost. We then fine-tune the internal units and residual connections of the basic SNN block to ensure that sparse temporal information interaction remains effective. T-RevSNN achieves excellent accuracy on ImageNet, while memory efficiency, training speed, and inference energy efficiency improve by $8.6 \times$, $2.0 \times$, and $1.6 \times$, respectively. This work is expected to break the technical bottleneck of sharply increasing memory cost and training time for large-scale SNNs while maintaining high performance and low inference energy cost. Source code and models are available at: https://github.com/BICLab/T-RevSNN.
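
The scaling claims in the abstract can be illustrated with a toy counting model. The sketch below is not taken from the T-RevSNN code base: the function names `stored_states` and `encoder_passes` are hypothetical, and the model only counts how many membrane-potential snapshots would be kept for the backward pass and how many times the input would be encoded, assuming one snapshot per layer per stored timestep.

```python
def stored_states(num_layers: int, timesteps: int, temporal_reversible: bool) -> int:
    """Count membrane-potential snapshots kept for the backward pass (toy model)."""
    if temporal_reversible:
        # Assumption: only the last timestep's state is kept per layer,
        # because earlier states can be recovered during the backward pass.
        return num_layers
    # A vanilla multi-timestep SNN keeps every layer's state at every timestep.
    return num_layers * timesteps


def encoder_passes(timesteps: int, encode_once: bool) -> int:
    """Count how many times the input image is encoded (toy model)."""
    return 1 if encode_once else timesteps


if __name__ == "__main__":
    num_layers = 18  # illustrative depth, not a value from the paper
    for t in (1, 2, 4, 8):
        vanilla = stored_states(num_layers, t, temporal_reversible=False)
        reversible = stored_states(num_layers, t, temporal_reversible=True)
        print(f"T={t}: vanilla={vanilla} states, reversible={reversible} states")
```

In this model the vanilla count grows linearly with the number of timesteps while the reversible count does not, which is the qualitative difference between $O(L \times T)$ and $O(L)$.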

Paper Structure (17 sections, 12 equations, 5 figures, 7 tables)

Figures (5)

  • Figure 1: Illustration of the temporal forward pass of a vanilla SNN and of T-RevSNN. (a) Vanilla SNNs unfold along the temporal dimension, reusing all parameters at each timestep. All spiking neurons carry out temporal dynamics. In image classification, the image is fed in repeatedly at every timestep. Thus, the memory and inference costs of vanilla SNNs are $\mathcal{O}(L \times T)$ and $\mathcal{O}(T)$, respectively. (b) In T-RevSNN, only the key spiking neurons (red spiking neurons in the figure) pass temporal information. Coupled with the multi-level temporal reversible design, the memory cost of T-RevSNN is $\mathcal{O}(L)$. Moreover, the image is encoded only once. The encoded features are divided into $T$ groups, one used as the input for each timestep. Correspondingly, the entire SNN is divided into $T$ independent sub-networks, which share parameters and transfer temporal information only at the key spiking neurons. Thus, the inference cost of T-RevSNN is $\mathcal{O}(1)$.
  • Figure 2: The left and right sub-figures show the cosine similarity between the spatial gradients computed by the baseline and those computed in Case 1 and Case 2, respectively. In Case 1, we retain only the temporal gradients of the last layer of spiking neurons in each stage; in Case 2, we discard only those gradients. We train a spiking ResNet-18 on CIFAR-10. $T$ and $\tau$ are the timestep and decay, respectively. The temporal gradients of the final layer of each stage are more significant (Case 1, left sub-figure, high cosine similarity) than those of spiking neurons in preceding stages (Case 2, right sub-figure, low cosine similarity). Due to space constraints, details of the cosine similarity calculation are given in the supplementary material.
  • Figure 3: Temporal-reversible connection in T-RevSNN.
  • Figure 4: Basic SNN block, following the ConvNeXt style (Liu et al., 2022).
  • Figure 5: Illustration of forward and backward propagation. (a) Existing training methods do not change the forward propagation of SNNs. (b) The backward propagation of STBP unfolds simultaneously along the temporal and spatial dimensions. (c) The backward propagation of OTTT/SLTT unfolds only along the spatial dimension. (d) The backward propagation of S-RevSNN unfolds along the temporal and spatial dimensions but is reversible in the spatial dimension. (e) and (f) show the forward and backward propagation of the temporal turn-off spiking neurons in the proposed method, respectively. (g) and (h) show the forward and backward propagation of the temporal turn-on spiking neurons in T-RevSNN, which are essentially consistent with those in (a) and (b). The difference is that the backward propagation in (h) is reversible, so only the membrane potentials of the last timestep need to be stored.
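
Two ideas from the captions above, encoding the input once and splitting it into $T$ groups (Figure 1) and recovering earlier membrane potentials from the last one (Figure 5h), can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the paper's formulation: it assumes a purely additive state update $u_t = u_{t-1} + f(x_t)$, which is invertible by subtraction, and the helper names `split_into_groups`, `drive`, `forward`, and `reconstruct_states` are hypothetical.

```python
from typing import List

Vector = List[float]


def split_into_groups(features: Vector, timesteps: int) -> List[Vector]:
    """Split features that were encoded once into one group per timestep."""
    if len(features) % timesteps != 0:
        raise ValueError("feature length must be divisible by the number of timesteps")
    size = len(features) // timesteps
    return [features[i * size:(i + 1) * size] for i in range(timesteps)]


def drive(group: Vector) -> float:
    """Hypothetical stand-in for the input a sub-network sends to a key neuron."""
    return sum(group)


def forward(groups: List[Vector]) -> float:
    """Run the additive update over all timesteps and keep only the last state."""
    u = 0.0
    for group in groups:
        u = u + drive(group)  # assumed update rule: u_t = u_{t-1} + f(x_t)
    return u


def reconstruct_states(u_last: float, groups: List[Vector]) -> List[float]:
    """Recover u_1..u_T from u_T alone by inverting the additive update."""
    states = [u_last]
    for group in reversed(groups[1:]):
        states.append(states[-1] - drive(group))  # u_{t-1} = u_t - f(x_t)
    states.reverse()
    return states


if __name__ == "__main__":
    features = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
    groups = split_into_groups(features, timesteps=4)
    u_last = forward(groups)
    print(reconstruct_states(u_last, groups))
```

The point of the sketch is that `forward` keeps a single value while `reconstruct_states` still yields every intermediate state, so no per-timestep storage is needed when the update is invertible. A real spiking neuron also involves firing and reset, which this toy update leaves out.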