AT-SNN: Adaptive Tokens for Vision Transformer on Spiking Neural Network

Donghwa Kang, Youngmoon Lee, Eun-Kyu Lee, Brent Kang, Jinkyu Lee, Hyeongboo Baek

TL;DR

This work addresses the high energy cost of SNN-based vision transformers by extending adaptive computation time (ACT) to a two-dimensional halting framework across timesteps and encoder blocks, together with a token-merge mechanism to reduce token counts. It introduces AT-SNN, formulates token-level halting scores $h^{l,t}_{k}$ and accumulators $H_k(L',T')$, and optimizes a combined loss $\mathcal{L}_{overall}$ to encourage early, accurate halting. Implemented on Spikformer, AT-SNN achieves higher accuracy with fewer tokens and lower energy consumption than state-of-the-art CNN- and transformer-based SNN methods across CIFAR-10, CIFAR-100, and TinyImageNet, while providing interpretable token-processing heatmaps. The approach advances practical deployment of energy-efficient SNN ViTs by balancing computation across temporal and spatial dimensions and highlighting the benefit of temporally aware token processing.
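The token-level halting described above can be sketched in a few lines. This is a minimal sketch under stated assumptions, not the authors' implementation. It assumes halting scores are accumulated per token in timestep-major order (all blocks at $t=1$, then all blocks at $t=2$, and so on), i.e. $H_k(L',T') = \sum_{t=1}^{T'-1}\sum_{l=1}^{L} h_k^{l,t} + \sum_{l=1}^{L'} h_k^{l,T'}$. It also assumes a token is masked once $H_k \ge 1$, which is the threshold stated in the Figure 4 caption (classic ACT often uses $1-\epsilon$ instead). The function name `token_processing_counts` and the nested-list layout `h[t][l][k]` are hypothetical, and timestep-level halting is omitted. The per-token counts it returns are the quantity the Figure 5 heatmaps visualize. Their sum is a rough proxy for the token-proportional energy cost.

```python
from typing import List


def token_processing_counts(h: List[List[List[float]]], threshold: float = 1.0) -> List[int]:
    """Hypothetical sketch of token-level ACT halting across timesteps and blocks.

    h[t][l][k] is the halting score of token k produced by block l at timestep t.
    Returns, per token, how many (timestep, block) steps processed it.
    """
    num_tokens = len(h[0][0])
    acc = [0.0] * num_tokens        # running accumulator H_k
    halted = [False] * num_tokens   # True once a token is masked
    counts = [0] * num_tokens
    # Assumption: accumulate in timestep-major order (all blocks at t=1, then t=2, ...).
    for scores_t in h:
        for scores_tl in scores_t:
            for k, score in enumerate(scores_tl):
                if halted[k]:
                    continue                # masked: no compute, no further accumulation
                counts[k] += 1              # this block processes token k at this timestep
                acc[k] += score
                if acc[k] >= threshold:
                    halted[k] = True        # masked from the next block onward
    return counts


h = [
    [[0.6, 0.1], [0.6, 0.1]],  # t=1: scores from blocks 1 and 2
    [[0.6, 0.1], [0.6, 0.1]],  # t=2: scores from blocks 1 and 2
]
print(token_processing_counts(h))  # [2, 4]
```

In this toy model with two timesteps and two blocks, a token scoring 0.6 per block halts after two blocks at $t=1$. A token scoring 0.1 runs through all four (timestep, block) steps. The result is 6 token-steps instead of 8.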

Abstract

In the training and inference of spiking neural networks (SNNs), direct training and lightweight computation methods have been developed independently, both aimed at reducing power consumption. However, only a few approaches apply the two mechanisms together, and because those approaches were originally designed for convolutional neural networks (CNNs), they fail to fully exploit the advantages of SNN-based vision transformers (ViTs). In this paper, we propose AT-SNN, which dynamically adjusts the number of tokens processed during inference in directly trained SNN-based ViTs, where power consumption is proportional to the number of tokens. We first show that adaptive computation time (ACT), previously limited to RNNs and ViTs, can be applied to SNN-based ViTs, and we extend it to selectively discard less informative spatial tokens. We also propose a new token-merge mechanism based on token similarity, which further reduces the number of tokens while improving accuracy. We implement AT-SNN on Spikformer and show that it achieves higher energy efficiency and accuracy than state-of-the-art approaches on the image classification benchmarks CIFAR-10, CIFAR-100, and TinyImageNet. For example, on CIFAR-100 our approach uses up to 42.4% fewer tokens than the best-performing existing method while maintaining higher accuracy.
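A similarity-based token merge can be illustrated with a greedy sketch. The paper specifies its exact merge policy in a separate algorithm that is not reproduced here, so everything below is an assumption. The sketch uses cosine similarity as the similarity measure. In each of $\gamma$ rounds it merges the most similar pair of live tokens by averaging the source into the destination. It then zeroes the merged-away token, as the Figure 4 caption describes. `merge_tokens` and `cosine` are hypothetical names. Unlike AT-SNN, this sketch does not keep merge decisions fixed across timesteps.

```python
import math
from typing import List, Optional, Tuple

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity; zero vectors are treated as dissimilar (0.0)."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def merge_tokens(tokens: List[Vector], gamma: int) -> List[Vector]:
    """Hypothetical greedy merge: each of gamma rounds merges the most similar live pair."""
    out = [[float(x) for x in v] for v in tokens]
    alive = [True] * len(out)
    for _ in range(gamma):
        best: Optional[Tuple[float, int, int]] = None  # (similarity, src, dst)
        for i in range(len(out)):
            for j in range(len(out)):
                if i == j or not (alive[i] and alive[j]):
                    continue
                sim = cosine(out[i], out[j])
                if best is None or sim > best[0]:
                    best = (sim, i, j)
        if best is None:  # fewer than two live tokens remain
            break
        _, src, dst = best
        # Assumption: the destination becomes the plain average of the pair.
        out[dst] = [(x + y) / 2.0 for x, y in zip(out[dst], out[src])]
        out[src] = [0.0] * len(out[src])  # merged-away token is zeroed (cf. Figure 4)
        alive[src] = False
    return out


tokens = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
print(merge_tokens(tokens, gamma=1))
```

With $\gamma = 1$, the first two tokens are almost parallel, so the first token is folded into the second and zeroed. The third token is untouched, leaving two non-zero tokens. The sketch uses a plain pairwise average for simplicity. A size-weighted average would be a natural variant if the same destination can absorb several tokens.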

Paper Structure (18 sections, 11 equations, 10 figures, 1 table, 1 algorithm)

Figures (10)

  • Figure 1: Accuracy comparison of lightweight computation methods in direct training on CIFAR-100.
  • Figure 2: Comparison of model architecture and halting-score accumulation paths among RNN, ViT, and SNN-based ViT when ACT is applied.
  • Figure 3: Estimation error and cosine similarity of tokens between consecutive blocks for (a) 12-layer ViT and (b) 4-layer SNN-based ViT (Spikformer) on CIFAR-100.
  • Figure 4: Token-level merging and masking example of AT-SNN. At the first timestep $t=1$, the input $x$ passes through the spiking patch splitting (SPS) module, generating a token set $\mathcal{T}^{l, t}_{1:K^{0}}$. In the first block $\mathcal{B}^1$, with $\gamma = 2$ for nine tokens, the merge module merges the first and third tokens into the second token, and halting scores $h_k^{1,1}$ are accumulated during inference. In subsequent blocks, tokens are merged according to each block's $\gamma$ value, and tokens whose accumulated halting score $H(l,t)$ reaches one or greater are masked. From the second timestep onward, the same operations are repeated on the same input $x$. Halting scores are accumulated according to Eq. \ref{eq:M_score}, and the merged tokens within a block remain the same across timesteps owing to the merge module's policy (detailed in Algo. \ref{algo:merge} and Sec. \ref{subsec:eval_ablation}). The vectors of merged or masked tokens are set to zero, and no further halting score is accumulated for them. For ease of implementation, a masked token can also be considered a candidate for merging. For simplicity, this example omits timestep-level halting-score accumulation.
  • Figure 5: Original images (odd-numbered columns) and heatmaps (even-numbered columns) showing how many blocks process each token over four timesteps on TinyImageNet. Brighter colors indicate that a token is processed more. AT-SNN halts earlier on tokens that carry little visual information.