
Reinforcement Learning-based Token Pruning in Vision Transformers: A Markov Game Approach

Chenglong Lu, Shen Liang, Xuewei Wang, Wei Wang

TL;DR

This work tackles the high computational cost of Vision Transformers by learning a data-adaptive token pruning policy using reinforcement learning. It formulates token pruning as a sequential decision process and deploys MAPPO-based pruning layers after each Transformer block, coordinated through a Markov Game to preserve inter-layer dependencies. Key contributions include a MAPPO token-pruning architecture with per-token agents, a curated reward design balancing efficiency and accuracy, and a Markov Game trajectory that captures cross-layer dynamics; the method achieves up to 44% faster inference on ImageNet-1k with only about 0.4% accuracy loss (reducible to ~0.1% with fine-tuning). The approach demonstrates superior efficiency–accuracy trade-offs compared with state-of-the-art token pruning methods, highlighting the practical potential of RL-driven adaptivity in ViTs.
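
The sequential, cross-layer structure described above can be illustrated with a toy rollout. The sketch below is hypothetical and is not the authors' implementation: `rollout_pruning` and `keep_prob` are illustrative names, a fixed keep probability stands in for the learned MAPPO policy, and token 0 is assumed to be the class token, which is never pruned. It shows only the shape of a trajectory. At each block, every surviving token's agent decides whether to keep or prune that token, and a token pruned at one block is unavailable to all later blocks, which is why decisions made at different layers are coupled.

```python
import numpy as np


def rollout_pruning(num_tokens, num_blocks, keep_prob, rng):
    """Simulate one episode of sequential per-token keep/prune decisions.

    keep_prob(block, token) is a hypothetical stand-in for a learned per-token
    policy and must return a probability in [0, 1].
    Token 0 is treated as the class token and is never pruned (assumption).
    """
    alive = np.ones(num_tokens, dtype=bool)  # every token starts active
    alive_counts = []
    for block in range(num_blocks):
        # A real model would run Transformer block `block` on the alive tokens here.
        for token in np.flatnonzero(alive[1:]) + 1:  # skip the class token
            # Each surviving token's agent keeps it with probability keep_prob.
            if rng.random() >= keep_prob(block, token):
                alive[token] = False  # pruned tokens stay pruned in later blocks
        alive_counts.append(int(alive.sum()))
    return alive, alive_counts


rng = np.random.default_rng(0)
alive, counts = rollout_pruning(197, 12, lambda block, token: 0.9, rng)
print(counts)  # non-increasing: later decisions apply only to survivors
```

In the actual method the keep probability would depend on each token's features and be trained with rewards; the fixed value here only exposes the monotone, layer-by-layer structure of the trajectory.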

Abstract

The computational cost of Vision Transformers (ViTs) scales quadratically with the number of tokens, which calls for effective token pruning policies. Most existing policies are handcrafted and lack adaptivity to varying inputs. Moreover, they fail to account for the sequential nature of token pruning across multiple layers. In this work, we exploit Reinforcement Learning (RL), for the first time to the best of our knowledge, to learn a pruning policy in a data-adaptive manner. Formulating token pruning as a sequential decision-making problem, we model it as a Markov Game and utilize Multi-Agent Proximal Policy Optimization (MAPPO), where each agent makes an individualized pruning decision for a single token. We also develop reward functions that enable these agents to collaborate and compete simultaneously, balancing efficiency and accuracy. On the ImageNet-1k dataset, our method improves inference speed by up to 44% with a negligible accuracy drop of only 0.4%. The source code is available at https://github.com/daashuai/rl4evit.
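
To make the dependence on token count concrete, the following sketch counts multiply-accumulates (MACs) for a standard ViT block using the common approximation 12ND² + 2N²D for N tokens of width D with a 4× MLP. The function names, the DeiT-S-like shapes (12 blocks, width 384, 196 patches plus one class token), and the pruning schedule are assumptions for illustration, not figures from the paper. The N² term comes from the attention scores and the attention-weighted sum; the remaining terms grow linearly in N, so every pruned token reduces cost in all subsequent blocks.

```python
def block_macs(n_tokens, dim, mlp_ratio=4):
    """Approximate multiply-accumulates for one standard ViT encoder block.

    Counts only the matrix multiplications; softmax, LayerNorm, GELU and
    biases are ignored.
    """
    qkv = 3 * n_tokens * dim * dim               # Q, K, V projections
    scores = n_tokens * n_tokens * dim           # Q @ K^T (summed over heads)
    weighted = n_tokens * n_tokens * dim         # attention weights @ V
    proj = n_tokens * dim * dim                  # output projection
    mlp = 2 * mlp_ratio * n_tokens * dim * dim   # two MLP linear layers
    return qkv + scores + weighted + proj + mlp


def vit_macs(tokens_per_block, dim):
    """Total MACs when block i processes tokens_per_block[i] tokens."""
    return sum(block_macs(n, dim) for n in tokens_per_block)


# DeiT-S-like shapes (assumed): 12 blocks, width 384, 196 patches + 1 class token.
full = vit_macs([197] * 12, 384)
# Hypothetical schedule: keep all tokens for 3 blocks, then about half of them.
pruned = vit_macs([197] * 3 + [99] * 9, 384)
print(f"full: {full / 1e9:.2f} GMACs, pruned: {pruned / 1e9:.2f} GMACs")
# prints "full: 4.54 GMACs, pruned: 2.78 GMACs"
```

MAC counts are only a proxy: the speedups reported above are measured inference speed, which also depends on hardware and implementation.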

Paper Structure

This paper contains 14 sections, 5 equations, 5 figures, 2 tables, and 1 algorithm.

Figures (5)

  • Figure 1: Comparisons between our RL4EViT and several state-of-the-art token pruning methods on the ImageNet-1k dataset imagenet, using DeiT-S and DeiT-B as backbones. RL4EViT achieves the best trade-off between accuracy and inference speed.
  • Figure 2: An illustration of our Reinforcement Learning for Efficient Vision Transformers (RL4EViT) method for token pruning. For a given ViT, each Transformer block is followed by a token pruning layer, where we utilize Multi-Agent Proximal Policy Optimization (MAPPO) Yu2022 to prune the output tokens of the Transformer block. The global reward across all token decision layers is obtained via a Markov Game markovgame1994 to preserve important sequential information. Note that in practice, instead of directly removing the pruned tokens, we set their attention weights to -1000 to facilitate mini-batch-based training.
  • Figure 3: The proportion of preserved tokens and the model complexity as functions of $\alpha / \beta$, with DeiT-B as the backbone ViT.
  • Figure 4: Top-1 accuracies with and without ViT fine-tuning.
  • Figure 5: Visualization of token pruning results obtained via our RL4EViT and DynamicViT rao2021dynamicvit after Transformer blocks 3, 6, and 9 in DeiT-B touvron2021training. The pruned tokens are shown in white. As shown, thanks to its more data-adaptive nature, RL4EViT conducts most of the pruning at the earliest possible stage (block 3), leaving only the most important tokens behind. By contrast, DynamicViT can only prune a fixed number of tokens in each block, which limits its flexibility and thus its performance.
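
Figure 2 notes that pruned tokens are masked with -1000 rather than removed. The sketch below assumes this value is added to the pre-softmax attention logits of pruned key tokens in a single attention head; `masked_attention` and its arguments are hypothetical names, and NumPy stands in for the training framework. Because exp(-1000) underflows to zero, pruned tokens receive (numerically) zero attention weight, while every image in the mini-batch keeps the same tensor shape.

```python
import numpy as np


def masked_attention(q, k, v, keep, mask_value=-1000.0):
    """Single-head attention in which pruned tokens are masked, not removed.

    q, k, v: arrays of shape (batch, tokens, dim).
    keep: boolean array of shape (batch, tokens); False marks a pruned token.
    """
    logits = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])  # (batch, tokens, tokens)
    # Add a large negative value to the logits of every pruned *key* token.
    logits = logits + np.where(keep[:, None, :], 0.0, mask_value)
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ v, weights


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 5, 8)) for _ in range(3))
# Each image keeps a different number of tokens, yet shapes stay (2, 5, ...).
keep = np.array([[True, True, False, True, False],
                 [True, False, False, False, True]])
out, weights = masked_attention(q, k, v, keep)
print(np.round(weights[0, 0], 3))  # entries 2 and 4 (pruned keys) are ~0
```

Masking keeps tensor shapes fixed, which is what allows mini-batch training with a different number of kept tokens per image; by itself, it does not reduce the number of operations.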