Saliency-driven Dynamic Token Pruning for Large Language Models

Yao Tao, Yehui Tang, Yun Wang, Mingjian Zhu, Hailin Hu, Yunhe Wang

TL;DR

Saliency-driven Dynamic Token Pruning (SDTP) is a token pruning framework that gradually and dynamically prunes redundant tokens based on the input context. By hierarchically pruning 65% of the input tokens, it reduces FLOPs by 33% to 47% and achieves a speedup of up to 1.75× during inference, while maintaining comparable performance.

Abstract

Despite the recent success of large language models (LLMs), long-sequence inference remains particularly challenging for them because of the quadratic computational complexity of the attention mechanism. Inspired by the interpretability theory of feature attribution in neural network models, we observe that not all tokens contribute equally. Based on this observation, we propose a novel token pruning framework, Saliency-driven Dynamic Token Pruning (SDTP), which gradually and dynamically prunes redundant tokens based on the input context. Specifically, a lightweight saliency-driven prediction module estimates the importance score of each token from its hidden state; this module is added to different layers of the LLM to hierarchically prune redundant tokens. Furthermore, a ranking-based optimization strategy is proposed to minimize the ranking divergence between the saliency score and the predicted importance score. Extensive experiments show that our framework generalizes to various models and datasets. By hierarchically pruning 65\% of the input tokens, our method reduces FLOPs by 33\% $\sim$ 47\% and achieves a speedup of up to 1.75$\times$ during inference, while maintaining comparable performance. We further demonstrate that SDTP can be combined with KV cache compression methods for further compression.
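
The pipeline described in the abstract (score each token from its hidden state, keep the highest-scoring tokens, and train the scorer so its ranking agrees with the saliency ranking) can be illustrated with a minimal sketch. Everything named here is an assumption for illustration, not the authors' implementation: `predict_importance` stands in for the lightweight prediction module with a single linear unit, `prune_tokens` uses a plain top-k rule, and `pairwise_ranking_loss` is a generic logistic pairwise loss used as a stand-in for the paper's ranking-based objective.

```python
import math
import random


def predict_importance(hidden_states, weights, bias=0.0):
    """Hypothetical scorer: one linear unit mapping each hidden state to a scalar."""
    return [sum(w * h for w, h in zip(weights, hs)) + bias for hs in hidden_states]


def prune_tokens(hidden_states, scores, keep_ratio):
    """Keep the top `keep_ratio` fraction of tokens by score, preserving token order."""
    n_keep = max(1, math.ceil(len(hidden_states) * keep_ratio))
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n_keep]
    kept = sorted(top)  # restore original sequence order
    return [hidden_states[i] for i in kept], kept


def pairwise_ranking_loss(saliency, predicted):
    """Generic logistic pairwise loss (an assumed stand-in objective).

    For every pair where token i is more salient than token j, the loss is
    log(1 + exp(-(predicted[i] - predicted[j]))), so it is small when the
    predicted scores order the pair the same way as the saliency scores.
    """
    total, pairs = 0.0, 0
    for i in range(len(saliency)):
        for j in range(len(saliency)):
            if saliency[i] > saliency[j]:
                diff = predicted[i] - predicted[j]
                # Numerically stable softplus(-diff).
                total += max(-diff, 0.0) + math.log1p(math.exp(-abs(diff)))
                pairs += 1
    return total / pairs if pairs else 0.0


# Toy usage: 10 tokens with 4-dimensional hidden states (random, illustrative only).
random.seed(0)
hidden = [[random.gauss(0, 1) for _ in range(4)] for _ in range(10)]
weights = [0.5, -0.25, 0.1, 0.3]  # hypothetical scorer parameters
scores = predict_importance(hidden, weights)
kept_states, kept_idx = prune_tokens(hidden, scores, keep_ratio=0.35)
print("kept token positions:", kept_idx)
```

In a real model the saliency targets would come from a feature-attribution method computed on the LLM itself, and the scorer would be trained by gradient descent; the sketch only shows how the three pieces fit together.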

Paper Structure

This paper contains 24 sections, 8 equations, 5 figures, 6 tables.

Figures (5)

  • Figure 1: Gradient-based saliency scores.
  • Figure 2: Salient token sparsity.
  • Figure 3: The overall architecture of our proposed method. Our proposed token pruning module is inserted between the transformer blocks and learns which tokens to prune. Pruning these tokens reduces the computation required in the following layers.
  • Figure 4: Effect of first pruning layer.
  • Figure 5: Effect of multi-stage pruning.
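
The Figure 3 caption notes that pruned tokens reduce computation in later layers, and Figures 4 and 5 concern where pruning starts and how many stages it uses. A rough back-of-the-envelope sketch of that effect follows. The layer count, the pruning positions, and the equal per-stage keep ratio are hypothetical values chosen only so that 35% of tokens survive at the end (that is, 65% pruned, as in the abstract); the cost model also assumes per-layer cost proportional to the number of remaining tokens, which ignores the quadratic attention term and the overhead of the pruning modules.

```python
def token_fraction_per_layer(num_layers, prune_layers, keep_ratio_per_stage):
    """Fraction of the original tokens processed by each layer.

    Pruning is assumed to happen just before every layer index in
    `prune_layers`, each time keeping `keep_ratio_per_stage` of the tokens.
    Returns the per-layer fractions and their mean over all layers.
    """
    frac = 1.0
    per_layer = []
    for layer in range(num_layers):
        if layer in prune_layers:
            frac *= keep_ratio_per_stage
        per_layer.append(frac)
    return per_layer, sum(per_layer) / num_layers


# Hypothetical setup: 32 layers, 10 pruning stages starting at layer 4.
stages = set(range(4, 24, 2))
ratio = 0.35 ** (1 / len(stages))  # per-stage keep ratio so that 35% remain at the end
per_layer, mean_frac = token_fraction_per_layer(32, stages, ratio)
print(f"final token fraction: {per_layer[-1]:.2f}")
print(f"mean token fraction across layers: {mean_frac:.2f}")
```

Under this simplified model, starting the pruning later or using fewer stages raises the mean token fraction, which is the trade-off the two ablation figures examine.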