Exploring Token Pruning in Vision State Space Models
Zheng Zhan, Zhenglun Kong, Yifan Gong, Yushu Wu, Zichong Meng, Hangyu Zheng, Xuan Shen, Stratis Ioannidis, Wei Niu, Pu Zhao, Yanzhi Wang
TL;DR
This work tackles the challenge of accelerating vision state space models (SSMs) by introducing token pruning tailored to SSMs. It shows that directly applying ViT-style pruning causes large accuracy drops because it disrupts the sequential scan. To address this, it introduces pruning-aware hidden state alignment together with an SSM-specific token importance metric to guide pruning. The resulting framework achieves substantial computational savings with minimal performance loss, with strong ImageNet and COCO results (e.g., a 41.6% FLOPs reduction at 81.7% top-1 accuracy for pruned PlainMamba-L3), and ablations confirm the importance of both the alignment and the clipping in the importance metric. These contributions offer practical acceleration for vision SSMs, deepen understanding of SSM scan behavior, and can guide future research in this direction.
Abstract
State Space Models (SSMs) retain linear computational complexity, in contrast to the quadratic complexity of attention modules in transformers, and have been applied to vision tasks as a powerful new type of vision foundation model. Inspired by the observation that the final prediction of vision transformers (ViTs) relies only on a subset of the most informative tokens, we take the novel step of improving the efficiency of SSM-based vision models through token pruning. However, directly applying existing token pruning techniques designed for ViTs fails to deliver good performance, even with extensive fine-tuning. To address this issue, we revisit the unique computational characteristics of SSMs and find that naive pruning disrupts the sequential token positions. This insight motivates us to design a novel and general token pruning method specifically for SSM-based vision models. We first introduce a pruning-aware hidden state alignment method that stabilizes the neighborhood of the remaining tokens and thereby preserves performance. In addition, based on our detailed analysis, we propose a token importance evaluation method tailored to SSMs to guide the pruning. With an efficient implementation and practical acceleration techniques, our method delivers real speedups. Extensive experiments demonstrate that our approach achieves significant computation reduction with minimal impact on performance across different tasks. Notably, we achieve 81.7% accuracy on ImageNet with a 41.6% reduction in FLOPs for pruned PlainMamba-L3. Furthermore, our work provides deeper insight into the behavior of SSM-based vision models for future research.
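To make the claim that naive pruning "disrupts the sequential token positions" concrete, here is a toy sketch. It assumes a scalar linear recurrence h_t = a*h_{t-1} + b*x_t as a stand-in for an SSM scan. Real vision SSMs such as PlainMamba use input-dependent, multi-dimensional parameters, so this is only an intuition aid. The function names and the `gap_aware_scan` variant are hypothetical illustrations, not the paper's pruning-aware hidden state alignment or its implementation.

```python
def linear_scan(xs, a=0.9, b=1.0, c=1.0):
    """Toy SSM scan (assumption): h_t = a*h_{t-1} + b*x_t, y_t = c*h_t."""
    h = 0.0
    ys = []
    for x in xs:
        h = a * h + b * x
        ys.append(c * h)
    return ys


def naive_prune(xs, keep_idx):
    """ViT-style pruning: drop tokens and compact the sequence."""
    return [xs[i] for i in sorted(keep_idx)]


def gap_aware_scan(xs, positions, a=0.9, b=1.0, c=1.0):
    """Hypothetical variant: decay the hidden state by a**gap, where gap is
    the distance between consecutive kept tokens in the ORIGINAL sequence."""
    h = 0.0
    prev = None
    ys = []
    for x, p in zip(xs, positions):
        steps = 1 if prev is None else p - prev
        h = (a ** steps) * h + b * x
        ys.append(c * h)
        prev = p
    return ys


tokens = [1.0, 0.0, 0.0, 2.0, 0.0, 3.0]
keep = [0, 3, 5]  # the dropped tokens are zero, so they add no input themselves

full = linear_scan(tokens)
naive = linear_scan(naive_prune(tokens, keep))
gap_aware = gap_aware_scan(naive_prune(tokens, keep), keep)

for pos, y_naive, y_gap in zip(keep, naive, gap_aware):
    print(pos, round(full[pos], 4), round(y_naive, 4), round(y_gap, 4))
```

Even though the dropped tokens contribute nothing themselves, compacting the sequence changes the retained outputs. Token 0 now sits one step before token 3 instead of three, so its contribution decays by a rather than a**3. Accounting for the original gaps recovers the full-sequence outputs in this toy case. Under the assumed toy model, this illustrates why the relative positions of the remaining tokens matter for SSM scans.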
