ST$^3$: Accelerating Multimodal Large Language Model by Spatial-Temporal Visual Token Trimming
Jiedong Zhuang, Lu Lu, Ming Dai, Rui Hu, Jian Chen, Qiang Liu, Haoji Hu
TL;DR
This work tackles the computational bottleneck of multimodal large language models (MLLMs) by analyzing visual-token attention. It introduces ST$^3$, a retraining-free framework that trims visual tokens across layers (Progressive Visual Token Pruning, PVTP) and over time during generation (Visual Token Annealing, VTA). By exploiting the 'lazy layer' phenomenon and the distribution of attention over visual tokens, ST$^3$ achieves roughly a $2\times$ speedup while using about $30\%$ of the original KV cache memory, and it maintains accuracy across diverse multimodal tasks. The approach outperforms existing training-free pruning methods, reduces FLOPs by more than half on representative models, and serves as a plug-and-play solution for efficient inference in real-world applications. Extensive ablations and analyses of QK heredity, visual-token overlap, and attenuation choices offer practical guidance for deployment and for future work on MLLM token pruning.
Abstract
Multimodal large language models (MLLMs) enhance their perceptual capabilities by integrating visual and textual information. However, processing the massive number of visual tokens incurs a significant computational cost. Existing analyses of MLLM attention mechanisms remain shallow, leading to coarse-grained token pruning strategies that fail to balance speed and accuracy effectively. In this paper, we conduct a comprehensive investigation of MLLM attention mechanisms using LLaVA. We find that many visual tokens, and part of the attention computation, are redundant during decoding. Based on this insight, we propose Spatial-Temporal Visual Token Trimming ($\textbf{ST}^{3}$), a framework that accelerates MLLM inference without retraining. $\textbf{ST}^{3}$ consists of two primary components: 1) Progressive Visual Token Pruning (\textbf{PVTP}), which eliminates inattentive visual tokens layer by layer, and 2) Visual Token Annealing (\textbf{VTA}), which dynamically reduces the number of visual tokens in each layer as the number of generated tokens grows. Together, these techniques deliver around $\mathbf{2\times}$ faster inference while using only about $\mathbf{30\%}$ of the KV cache memory of the original LLaVA, and they maintain consistent performance across various datasets. Crucially, $\textbf{ST}^{3}$ can be seamlessly integrated into existing pre-trained MLLMs, providing a plug-and-play solution for efficient inference.
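The abstract does not specify the exact pruning schedule or annealing rule, so the following is only a minimal Python sketch of how the two components could interact, under stated assumptions; it is not the authors' implementation. It assumes (i) PVTP ranks the surviving visual tokens by the attention they receive and keeps a fraction that falls linearly across layers after a hypothetical `start_layer`, and (ii) VTA scales that per-layer budget down linearly as more tokens are generated. All names and parameters here (`pvtp_keep_ratios`, `vta_scale`, `prune_inattentive`, `visual_budget`, `start_layer`, `min_ratio`, `anneal_steps`, `floor_ratio`) are hypothetical. The count of 576 visual tokens corresponds to LLaVA-1.5's 24 x 24 image patch grid.

```python
from typing import List, Sequence


def pvtp_keep_ratios(num_layers: int, start_layer: int, min_ratio: float) -> List[float]:
    """Hypothetical PVTP schedule: keep every visual token in layers before
    `start_layer`, then shrink the kept fraction linearly so the last layer
    keeps `min_ratio` of the original visual tokens."""
    span = max(num_layers - 1 - start_layer, 1)
    ratios = []
    for layer in range(num_layers):
        if layer < start_layer:
            ratios.append(1.0)
        else:
            progress = min((layer - start_layer) / span, 1.0)
            ratios.append(1.0 - progress * (1.0 - min_ratio))
    return ratios


def vta_scale(step: int, anneal_steps: int, floor_ratio: float) -> float:
    """Hypothetical VTA schedule: as the number of generated tokens `step`
    grows, scale the visual-token budget down linearly to `floor_ratio`
    over `anneal_steps` decoding steps, then hold it constant."""
    progress = min(step / anneal_steps, 1.0)
    return 1.0 - progress * (1.0 - floor_ratio)


def prune_inattentive(token_ids: Sequence[int], attention: Sequence[float], keep: int) -> List[int]:
    """Keep the `keep` surviving visual tokens that received the most
    attention, preserving their original order in the sequence."""
    if len(token_ids) != len(attention):
        raise ValueError("token_ids and attention must have the same length")
    ranked = sorted(range(len(token_ids)), key=lambda i: attention[i], reverse=True)
    return [token_ids[i] for i in sorted(ranked[:keep])]


def visual_budget(num_visual: int, layer_ratio: float, step_scale: float) -> int:
    """Number of visual tokens one layer keeps at one decoding step (at least 1)."""
    return max(1, round(num_visual * layer_ratio * step_scale))


if __name__ == "__main__":
    num_visual = 576  # LLaVA-1.5: 24 x 24 image patches
    ratios = pvtp_keep_ratios(num_layers=32, start_layer=2, min_ratio=0.1)
    for step in (0, 64, 128):
        scale = vta_scale(step, anneal_steps=128, floor_ratio=0.5)
        budgets = [visual_budget(num_visual, r, scale) for r in ratios]
        print(f"step {step}: layer 0 keeps {budgets[0]}, last layer keeps {budgets[-1]}")

    # Toy attention received by six surviving visual tokens in one layer.
    survivors = prune_inattentive([10, 11, 12, 13, 14, 15],
                                  [0.05, 0.40, 0.01, 0.30, 0.02, 0.22], keep=3)
    print(survivors)  # [11, 13, 15]
```

Because `prune_inattentive` operates only on the survivors of the previous layer, a token dropped at one layer never reappears deeper in the network, which is what makes the pruning progressive. The linear schedules are a simplifying choice for illustration; the real speedup and KV cache savings depend on the schedule the paper actually uses.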
