ST$^3$: Accelerating Multimodal Large Language Model by Spatial-Temporal Visual Token Trimming

Jiedong Zhuang, Lu Lu, Ming Dai, Rui Hu, Jian Chen, Qiang Liu, Haoji Hu

TL;DR

This work tackles the computational bottleneck of multimodal large language models by analyzing visual-token attention and introducing ST$^3$, a retraining-free framework that trims visual tokens across layers (PVTP) and over time during generation (VTA). By exploiting the 'lazy layer' phenomenon and the observed attention distributions, ST$^3$ achieves around a $2\times$ speedup while using roughly $30\%$ of the original KV cache, and it maintains accuracy across diverse multimodal tasks. The approach is shown to outperform existing training-free pruning methods and to reduce FLOPs by more than half on representative models, providing a plug-and-play solution for efficient inference in real-world applications. The work also includes extensive ablations and analyses of QK heredity, visual-token overlap, and attenuation choices, offering practical guidance for deployment and for future work on MLLM token pruning.
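The 'lazy layer' analysis (Figures 3 and 4(c)) compares the attention-score distributions of neighboring decoder layers using cosine similarity. Below is a minimal pure-Python sketch of that measurement, not the authors' code. It assumes each layer's attention has already been reduced to a single vector: the attention that one query token (for example, the last one) pays to every key position, averaged over heads. That reduction and the helper names `cosine_similarity` and `adjacent_layer_similarity` are hypothetical.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length attention vectors."""
    if len(a) != len(b):
        raise ValueError("attention vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # degenerate case: an all-zero vector has no direction
    return dot / (norm_a * norm_b)


def adjacent_layer_similarity(per_layer_attn: List[Sequence[float]]) -> List[float]:
    """Entry l - 1 compares layer l with layer l - 1 (hypothetical helper).

    per_layer_attn[l] is assumed to be one query's head-averaged attention
    over all key positions in decoder layer l.
    """
    return [
        cosine_similarity(per_layer_attn[l - 1], per_layer_attn[l])
        for l in range(1, len(per_layer_attn))
    ]


# Toy attention rows for three layers over four key positions (made-up numbers).
toy_layers = [
    [0.7, 0.1, 0.1, 0.1],
    [0.6, 0.2, 0.1, 0.1],
    [0.6, 0.2, 0.1, 0.1],
]
similarities = adjacent_layer_similarity(toy_layers)
print([round(s, 4) for s in similarities])
```

A value near 1, as in the second pair here, means the deeper layer attends almost exactly like the one before it. This is presumably the property that makes it reasonable to reuse a previous layer's attention when deciding which visual tokens to prune.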

Abstract

Multimodal large language models (MLLMs) enhance their perceptual capabilities by integrating visual and textual information. However, processing the massive number of visual tokens incurs a significant computational cost. Existing analyses of MLLM attention mechanisms remain shallow, leading to coarse-grained token pruning strategies that fail to balance speed and accuracy effectively. In this paper, we conduct a comprehensive investigation of MLLM attention mechanisms using LLaVA. We find that many visual tokens, and part of the attention computation, are redundant during the decoding process. Based on this insight, we propose Spatial-Temporal Visual Token Trimming ($\textbf{ST}^{3}$), a framework that accelerates MLLM inference without retraining. $\textbf{ST}^{3}$ consists of two primary components: 1) Progressive Visual Token Pruning (PVTP), which eliminates inattentive visual tokens across layers, and 2) Visual Token Annealing (VTA), which dynamically reduces the number of visual tokens in each layer as the number of generated tokens grows. Together, these techniques deliver around $\mathbf{2\times}$ faster inference while using only about $\mathbf{30\%}$ of the KV cache memory of the original LLaVA, and they maintain consistent performance across various datasets. Crucially, $\textbf{ST}^{3}$ can be seamlessly integrated into existing pre-trained MLLMs, providing a plug-and-play solution for efficient inference.
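VTA shrinks the visual-token budget as generation proceeds, and the Figure 5 caption states that a cosine function controls the decay. This section does not give the exact formula or its endpoints. The sketch below therefore assumes a standard half-cosine decay from a starting budget `n_start` to a floor `n_end` over `total_steps` generation steps. All three parameters and the function name `vta_keep_count` are hypothetical. The example starts from 576, the number of visual tokens LLaVA-1.5 produces per image.

```python
import math


def vta_keep_count(step: int, total_steps: int, n_start: int, n_end: int) -> int:
    """Visual tokens to keep at a generation step (hypothetical cosine schedule).

    Returns n_start at step 0, decays smoothly, and stays at n_end once
    step >= total_steps.
    """
    if total_steps <= 0 or step >= total_steps:
        return n_end
    progress = max(step, 0) / total_steps                      # in [0, 1)
    cos_factor = 0.5 * (1.0 + math.cos(math.pi * progress))    # 1.0 -> 0.0
    return n_end + round((n_start - n_end) * cos_factor)


# Budget over the first 120 generated tokens, starting from LLaVA-1.5's
# 576 visual tokens and annealing to an assumed floor of 64 by step 100.
schedule = [vta_keep_count(t, 100, 576, 64) for t in range(120)]
print(schedule[0], schedule[50], schedule[99], schedule[119])  # prints: 576 320 64 64
```

A cosine curve keeps the budget nearly flat for the first few steps and close to the floor near the end, so most of the reduction happens mid-generation. How such a budget combines with the per-layer pruning of PVTP is left open in this sketch.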

Paper Structure (19 sections, 10 equations, 12 figures, 14 tables, 3 algorithms)

Figures (12)

  • Figure 1: Comparison of various models on the ScienceQA_Img dataset [SQA], with circle size representing FLOPs. Our method achieves the highest accuracy among 13B-parameter models while requiring the fewest FLOPs and the lowest decoding latency, even outperforming the smaller LLaVA-1.5-7B.
  • Figure 2: Our method compared with existing visual token pruning methods. (a) LLaVA [llava1.5] keeps all visual tokens. (b) FastV [fastv] prunes a fixed number of tokens in deeper layers; this layer-oblivious paradigm overlooks how attention patterns vary across layers. (c) VTW [VTW] prunes all visual tokens in the latter half of the layers, permanently losing visual information in deeper layers. In addition, all three methods keep the same number of visual tokens throughout the entire generation process, which requires a substantial KV cache memory budget. (d) Our method progressively prunes inattentive visual tokens as layers go deeper, while dynamically reducing tokens during generation. It maximizes inference efficiency by cutting visual tokens as far as the model's dependence on them allows.
  • Figure 3: Illustration of the similarity between the attention scores of all layers in LLaVA-1.5-7B. High similarity scores are distributed around the diagonal, indicating adjacent layers have more similar attention patterns.
  • Figure 4: (a) Attention weights in various layers; visual attention stays persistently low after layer 3. (b) How the visual attention weight changes with the length of the generated text sequence. (c) Cosine similarity of the attention scores between each layer and its preceding layer; a value close to 1 indicates that the attention score distributions of the two layers are nearly identical. (d) Visual attention: the number of high-attention tokens decreases as layers deepen.
  • Figure 5: Overview of the ST$^3$ framework. Preprocessing first converts the inputs from the various modalities into tokens, which are concatenated and fed into the LLM's decoder layers. Progressive Visual Token Pruning (PVTP) gradually prunes away non-critical visual tokens throughout the decoding forward pass. The top-right panel details Visual Token Pruning: it extracts the visual-token attention from the attention matrix of the previous decoder layer and retains the TopK most important tokens. The predicted token then serves as the input for the next generation step. Visual Token Annealing (VTA) uses a cosine function to control the decay of the visual-token KV cache [kvcache] at each generation step.
  • ...and 7 more figures
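The Visual Token Pruning step in Figure 5 scores visual tokens using the attention matrix of the previous decoder layer and keeps the TopK. PVTP repeats this with fewer tokens as depth increases. The sketch below is a simplified, hypothetical pure-Python rendering of that idea, not the paper's implementation. It assumes that a visual token's score is the head-averaged attention the last query token paid to it in the previous layer. It also takes the per-layer `keep_schedule` as a given input, because this section does not say how the per-layer budgets are chosen. The helper names `select_top_visual` and `progressive_prune` are illustrative.

```python
from typing import Dict, List, Sequence


def select_top_visual(
    prev_layer_row: Dict[int, float],
    visual_positions: Sequence[int],
    keep: int,
) -> List[int]:
    """Keep the `keep` visual positions with the highest attention.

    prev_layer_row maps key position -> attention weight that the last query
    token paid to it in the previous decoder layer (head-averaged, assumed).
    Ties keep their original order; survivors are returned in sequence order.
    """
    ranked = sorted(
        visual_positions, key=lambda p: prev_layer_row.get(p, 0.0), reverse=True
    )
    return sorted(ranked[: max(keep, 0)])


def progressive_prune(
    rows_per_layer: Sequence[Dict[int, float]],
    visual_positions: Sequence[int],
    keep_schedule: Sequence[int],
) -> List[List[int]]:
    """Visual positions that survive into each successive layer.

    Each step selects only from the previous survivors, so the candidate set
    only shrinks and a pruned token never comes back.
    """
    current = list(visual_positions)
    kept_per_layer: List[List[int]] = []
    for row, keep in zip(rows_per_layer, keep_schedule):
        current = select_top_visual(row, current, keep)
        kept_per_layer.append(current)
    return kept_per_layer


# Toy example: position 0 is a prefix token, 1-6 are visual tokens, 7 is text.
# Later rows only contain keys that are still present (made-up numbers).
toy_rows = [
    {0: 0.30, 1: 0.05, 2: 0.20, 3: 0.02, 4: 0.15, 5: 0.08, 6: 0.10, 7: 0.10},
    {0: 0.40, 2: 0.05, 4: 0.25, 5: 0.02, 6: 0.12, 7: 0.16},
    {0: 0.50, 4: 0.10, 6: 0.30, 7: 0.10},
]
kept = progressive_prune(toy_rows, [1, 2, 3, 4, 5, 6], keep_schedule=[4, 2, 1])
print(kept)  # [[2, 4, 5, 6], [4, 6], [6]]
```

Each layer chooses only from the tokens that survived the previous one, which mirrors the progressive behavior shown in Figure 2(d). Returning survivors in sequence order preserves their relative order in the input.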