DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification

Yongming Rao, Wenliang Zhao, Benlin Liu, Jiwen Lu, Jie Zhou, Cho-Jui Hsieh

TL;DR

DynamicViT tackles the efficiency challenge of vision transformers by introducing dynamic token sparsification. It adds lightweight prediction modules that prune tokens hierarchically based on local and global token features, and it is trained with attention masking and multiple loss terms to preserve accuracy. The method achieves substantial speedups, reducing FLOPs by 31%–37% and increasing throughput by over 40%, while keeping the accuracy drop within 0.5% across several backbones on ImageNet, and it scales well to larger models. This token-level sparsification offers a hardware-friendly, data-dependent route to competitive accuracy/complexity trade-offs for vision transformers and related architectures.
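
To make "guided by local and global token features" concrete, here is a minimal plain-Python sketch of how a prediction module might score tokens. It is not the released implementation. As simplifying assumptions, each token's local feature is a linear projection (`W_local`), the global feature is the mean of local features over tokens that are still kept, and a linear scorer (`w_score`) maps the concatenated features to a keep probability. All names here are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def sigmoid(t: float) -> float:
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def matvec(W: List[Vector], x: Vector) -> Vector:
    """Dense matrix-vector product; W has shape (out_dim, in_dim)."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def keep_probabilities(tokens: List[Vector], mask: List[float],
                       W_local: List[Vector], w_score: Vector) -> List[float]:
    """Hypothetical, simplified prediction module (not the released code).

    tokens  : per-token features produced by the previous block
    mask    : 1.0 for tokens still kept, 0.0 for tokens pruned at earlier stages
    W_local : assumed linear projection giving each token's local feature
    w_score : assumed weights mapping [local, global] to a keep logit
    """
    local = [matvec(W_local, x) for x in tokens]
    # Global feature: mean of local features over the tokens that are still kept.
    kept = max(sum(mask), 1e-6)
    dim = len(local[0])
    glob = [sum(m * z[d] for m, z in zip(mask, local)) / kept for d in range(dim)]
    # Each token is scored from its local feature concatenated with the global one.
    # A 2-way softmax over (keep, drop) logits (s, 0) reduces to sigmoid(s).
    return [sigmoid(sum(w * f for w, f in zip(w_score, z + glob))) for z in local]
```

Because already-pruned tokens carry a zero mask, they do not contribute to the global feature. As a result, a token's score depends only on its own feature and on the tokens still in play.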

Abstract

Attention is sparse in vision transformers. We observe that the final prediction in vision transformers is based only on a subset of the most informative tokens, which is sufficient for accurate image recognition. Based on this observation, we propose a dynamic token sparsification framework that prunes redundant tokens progressively and dynamically based on the input. Specifically, we devise a lightweight prediction module to estimate the importance score of each token given the current features. The module is added to different layers to prune redundant tokens hierarchically. To optimize the prediction module in an end-to-end manner, we propose an attention masking strategy that differentiably prunes a token by blocking its interactions with other tokens. Owing to the nature of self-attention, the unstructured sparse tokens are still hardware friendly, which makes it easy for our framework to achieve actual speed-up. By hierarchically pruning 66% of the input tokens, our method reduces FLOPs by 31%–37% and improves throughput by over 40%, while the accuracy drop is within 0.5% for various vision transformers. Equipped with the dynamic token sparsification framework, DynamicViT models achieve very competitive complexity/accuracy trade-offs compared to state-of-the-art CNNs and vision transformers on ImageNet. Code is available at https://github.com/raoyongming/DynamicViT
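
The attention masking strategy can be illustrated with a small sketch. It assumes a gate G_ij that equals 1 when i = j and equals token j's keep decision D_j otherwise, applied inside the softmax. This blocks every interaction with a pruned token while still letting each token attend to itself. The function name is hypothetical, and the code is plain Python rather than the authors' released implementation.

```python
import math
from typing import List


def masked_attention_weights(logits: List[List[float]],
                             keep: List[float]) -> List[List[float]]:
    """Row-wise softmax in which pruned tokens are blocked out (illustrative sketch).

    logits[i][j] : raw attention score from query i to key j
    keep[j]      : keep decision for token j; 1.0 = kept, 0.0 = pruned
                   (values in between act as a soft gate)
    """
    n = len(keep)
    weights = []
    for i in range(n):
        # Gate G_ij: every token may attend to itself; otherwise key j is
        # weighted by its keep decision, so pruned keys receive no attention.
        gate = [1.0 if j == i else keep[j] for j in range(n)]
        # Subtracting the row maximum avoids overflow and cancels in the ratio.
        m = max(logits[i])
        scores = [math.exp(logits[i][j] - m) * gate[j] for j in range(n)]
        total = sum(scores)  # includes the self term, whose gate is always 1.0
        weights.append([s / total for s in scores])
    return weights
```

No token is physically removed, so every tensor keeps its shape during training, and the gate can take values between 0 and 1. At inference, the pruned tokens can simply be dropped from the sequence, which is where the actual speed-up comes from.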

Paper Structure

This paper contains 25 sections, 14 equations, 8 figures, 4 tables.

Figures (8)

  • Figure 1: Illustration of our main idea. CNN models usually leverage the structural downsampling strategy to build hierarchical architectures, as shown in (a). The unstructured and data-dependent downsampling method in (b) can better exploit the sparsity in the input data. Thanks to the nature of the self-attention operation, the unstructured token set is also easy to accelerate through parallel computing. (c) visualizes the impact of each spatial location on the final prediction in the DeiT-S model touvron2020deit using the visualization method proposed in chefer2020transformer. These results demonstrate that the final prediction in vision transformers is based only on a subset of the most informative tokens, which suggests that a large proportion of tokens can be removed without hurting performance.
  • Figure 2: The overall framework of the proposed approach. The proposed prediction module is inserted between the transformer blocks to selectively prune less informative tokens, conditioned on features produced by the previous layer. As a result, fewer tokens are processed in the subsequent layers.
  • Figure 3: Model complexity (FLOPs) and top-1 accuracy trade-offs on ImageNet. We compare DynamicViT with state-of-the-art image classification models. Our models achieve better trade-offs than various vision transformers as well as carefully designed CNN models.
  • Figure 4: Comparison of our dynamic token sparsification method with model width scaling. We train DynamicViT based on DeiT models with embedding dimensions varying from 192 to 384 and a fixed keep ratio $\rho=0.7$. Dynamic token sparsification is more efficient than the commonly used model width scaling.
  • Figure 5: Visualization of the progressively sparsified tokens. We show the original input image and the sparsification results after each of the three stages, where the masked regions correspond to discarded tokens. Our method gradually focuses on the most representative regions in the image, which suggests that DynamicViT offers better interpretability.
  • ...and 3 more figures
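
Figure 4 fixes $\rho=0.7$ and Figure 5 shows three pruning stages. The sketch below works through the resulting token budget under stated assumptions. Each stage is assumed to keep a fraction $\rho$ of the tokens that reach it, so $\rho^3 \approx 0.343$ of the tokens survive and roughly 66% are pruned, matching the figure quoted in the abstract. The remaining settings are illustrative assumptions, not taken from this page: 196 patch tokens (a 224×224 input split into 16×16 patches, as in DeiT-S), 12 blocks, and pruning stages placed before blocks 4, 7 and 10. Per-block cost is also assumed to scale linearly with the token count. Function names are hypothetical.

```python
from typing import List


def token_schedule(num_tokens: int, rho: float, num_stages: int) -> List[int]:
    """Tokens remaining after each stage, assuming stage s keeps rho**s of the input."""
    return [int(num_tokens * rho ** s) for s in range(num_stages + 1)]


def relative_block_cost(depth: int, prune_before: List[int], rho: float) -> float:
    """Average token fraction per block, relative to the unpruned model.

    Assumes each block's cost scales linearly with its token count (ignoring the
    quadratic attention term and the prediction modules' own overhead).
    prune_before lists the 1-indexed blocks that a pruning stage precedes.
    """
    total, fraction = 0.0, 1.0
    for block in range(1, depth + 1):
        if block in prune_before:
            fraction *= rho  # each stage keeps rho of the tokens that reach it
        total += fraction
    return total / depth


# Hypothetical DeiT-S-like setting: 196 patch tokens, 12 blocks, stages before 4, 7, 10.
schedule = token_schedule(196, 0.7, 3)            # [196, 137, 96, 67]
pruned = 1 - 0.7 ** 3                             # 0.657, i.e. about 66% of tokens pruned
cost = relative_block_cost(12, [4, 7, 10], 0.7)   # about 0.633 of the original cost
```

Under these assumptions, the estimated reduction is about 37%, which falls inside the 31%–37% FLOPs range reported in the abstract. Treat this only as a back-of-envelope sanity check, not a measurement, since real FLOPs also depend on the attention term and the predictor overhead.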