
Token Pooling in Vision Transformers

Dmitrii Marin, Jen-Hao Rick Chang, Anurag Ranjan, Anish Prabhu, Mohammad Rastegari, Oncel Tuzel

TL;DR

This work identifies that the primary computational burden of vision transformers lies in their fully-connected layers and shows that softmax-attention behaves like a high-dimensional low-pass filter, producing redundant token representations. It introduces Token Pooling, a nonuniform, data-aware token downsampling operator that minimizes reconstruction error via clustering (K-Means, K-Medoids, and weighted variants) and offers a better cost-accuracy trade-off than prior grid- or score-based methods. In extensive experiments on DeiT architectures with ImageNet-1k, Token Pooling substantially reduces FLOPs while maintaining or improving top-1 accuracy (applied to DeiT, it reaches the same top-1 accuracy with 42% fewer computations), outperforming PoWER-BERT and DynamicViT across budgets. The method is simple to integrate and generalizable to other transformer-based vision models, suggesting broad applicability to efficient vision transformers and downstream tasks.
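The clustering step named above can be illustrated with a minimal weighted K-Means sketch in Python/NumPy. It pools N tokens of width d into K cluster centers by alternating nearest-center assignment and weighted-mean updates, a procedure that never increases the weighted reconstruction error for positive weights. The function name `token_pool_kmeans`, the random initialization, the iteration count, and the optional per-token `weights` (one plausible stand-in for the weighted variants, e.g. significance scores) are assumptions for illustration, not the authors' implementation.

```python
import numpy as np


def token_pool_kmeans(tokens, k, weights=None, n_iters=10, seed=0):
    """Pool N tokens of shape (N, d) into k tokens with (weighted) K-Means.

    Hypothetical helper for illustration only; weights must be positive.
    Returns (centers, assignment, weighted reconstruction error).
    """
    tokens = np.asarray(tokens, dtype=float)
    n = tokens.shape[0]
    if weights is None:
        weights = np.ones(n)
    rng = np.random.default_rng(seed)
    # Initialize the centers from k distinct input tokens.
    centers = tokens[rng.choice(n, size=k, replace=False)].copy()
    for _ in range(n_iters):
        # Squared Euclidean distance from every token to every center: (N, k).
        dists = ((tokens[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
        assign = dists.argmin(axis=1)
        for j in range(k):
            mask = assign == j
            if mask.any():  # empty clusters keep their previous center
                w = weights[mask]
                # The weighted mean minimizes the weighted squared error of cluster j.
                centers[j] = (w[:, None] * tokens[mask]).sum(0) / w.sum()
    # Final assignment and the weighted reconstruction error it induces.
    dists = ((tokens[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
    assign = dists.argmin(axis=1)
    error = float((weights * dists[np.arange(n), assign]).sum())
    return centers, assign, error


# Usage on random stand-in tokens (196 patch tokens of width 384).
tokens = np.random.default_rng(1).normal(size=(196, 384))
pooled, assign, err = token_pool_kmeans(tokens, k=64)
print(pooled.shape)  # prints (64, 384)
```

The pooled centers replace the original tokens as input to the next block, so every subsequent layer processes K instead of N tokens.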

Abstract

Despite their recent success in many applications, the high computational requirements of vision transformers limit their use in resource-constrained settings. While many existing methods reduce the quadratic complexity of attention, self-attention is not the major computational bottleneck in most vision transformers; e.g., more than 80% of the computation is spent on fully-connected layers. To reduce the computational cost of all layers, we propose a novel token downsampling method, called Token Pooling, that efficiently exploits redundancies in the images and intermediate token representations. We show that, under mild assumptions, softmax-attention acts as a high-dimensional low-pass (smoothing) filter. Thus, its output contains redundancy that can be pruned to achieve a better trade-off between computational cost and accuracy. Our technique accurately approximates a set of tokens by minimizing the reconstruction error caused by downsampling, and we solve this optimization problem via cost-efficient clustering. We rigorously analyze our method and compare it with prior downsampling methods. Our experiments show that Token Pooling significantly improves the cost-accuracy trade-off over state-of-the-art downsampling methods. Token Pooling is a simple and effective operator that can benefit many architectures. Applied to DeiT, it achieves the same ImageNet top-1 accuracy using 42% fewer computations.
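As a rough check of the fully-connected-layer claim, the multiply-accumulate (MAC) counts of one standard ViT encoder block can be tallied directly. This sketch assumes DeiT-S-like settings (width 384, 196 patch tokens plus one class token, MLP ratio 4) and counts only matrix multiplications, ignoring softmax, LayerNorm, activations, the patch embedding, and the classifier head; `block_macs` is a hypothetical helper, and the resulting share is an estimate under these simplifications, not the paper's measurement.

```python
def block_macs(n_tokens, dim, mlp_ratio=4):
    """Approximate MACs of one ViT encoder block (matrix multiplications only)."""
    qkv = 3 * n_tokens * dim * dim                # Q, K, V projections
    proj = n_tokens * dim * dim                   # attention output projection
    mlp = 2 * n_tokens * dim * (mlp_ratio * dim)  # the two MLP linear layers
    attn = 2 * n_tokens * n_tokens * dim          # Q K^T and softmax(.) V
    return {"fc": qkv + proj + mlp, "attention": attn}


# Assumed DeiT-S-like block: 196 patches + 1 class token, width 384.
m = block_macs(n_tokens=197, dim=384)
fc_share = m["fc"] / (m["fc"] + m["attention"])
print(f"fully-connected share: {fc_share:.1%}")  # prints 92.1%

# Fully-connected cost is linear in the token count, so halving tokens halves it.
m_pooled = block_macs(n_tokens=99, dim=384)
print(f"FC MACs after pooling to 99 tokens: {m_pooled['fc'] / m['fc']:.2f}x")  # prints 0.50x
```

With these counts the share simplifies to 6d / (6d + N), so for widths much larger than the token count the fully-connected layers dominate. Because their cost scales linearly with N, reducing the number of tokens lowers the cost of every layer rather than only the attention.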

Paper Structure

This paper contains 33 sections, 13 equations, 12 figures, 5 tables, 1 algorithm.

Figures (12)

  • Figure 1: (a) We propose Token Pooling, a novel token downsampling method for vision transformers. (b) The proposed method achieves a state-of-the-art trade-off between accuracy and computation. (c) Token Pooling uses cluster analysis to aggregate information from individual tokens automatically. We show the input images and the token clusters at the 6th layer of DeiT-S.
  • Figure 2: Score-based downsampling methods [goyal2020power, Rao2021DynamicViTEV] vs. our method. In the figure, the x-axis represents the token values (in one dimension), and the y-axis represents their scores. Suppose four tokens are to be selected. (a) Score-based methods select the tokens with the highest scores. Since the scoring function is continuous, all tokens in the left lobe are selected, which makes the selection redundant and loses the information in the right lobe. (b) The proposed method first forms four clusters to approximate the set of tokens, then selects the cluster centers. Thus, the output tokens represent the original token set more accurately than those of score-based methods.
  • Figure 3: Main results. (a) shows the accuracy when we apply different downsampling methods to DeiT-S; more results are in the appendix on clustering ablations. (b) compares the proposed method with state-of-the-art downsampling methods. The results of our method and PoWER-BERT are obtained by varying K and the base architecture among DeiT-Ti, DeiT-e252, DeiT-e318, and DeiT-S.
  • Figure 4: Results of applying Token Pooling to various DeiT architectures. Token Pooling consistently improves the computation-accuracy trade-off for all evaluated architectures. Combining Token Pooling with architecture search further improves the accuracy at a given FLOPs budget. For example, at 1 GFLOP, applying Token Pooling to DeiT-e252 gives better accuracy than applying it to DeiT-S.
  • Figure 5: Ablation studies of (a) downsampling methods using significance scores and (b) the proposed Token Pooling with different clustering algorithms. The base model is DeiT-S for all methods.
  • ...and 7 more figures
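The contrast drawn in Figure 2 can be reproduced on a toy 1-D example: eight tokens in two lobes, with a hypothetical scoring function that rates the left lobe higher. Keeping the top-4 scores takes the whole left lobe, while a clustering-based selection covers both lobes and has a far lower reconstruction error (sum of squared distances from each token to its nearest kept token). For the clustering side this sketch swaps in exact K-Medoids by brute force over all subsets, which is only feasible at toy sizes; the token values, `scores`, and `recon_error` are assumptions, not the paper's cost-efficient clustering.

```python
from itertools import combinations

import numpy as np

# Toy 1-D tokens: two lobes of four tokens each.
x = np.array([0.0, 0.1, 0.2, 0.3, 5.0, 5.1, 5.2, 5.3])
# Hypothetical smooth scores: the left lobe scores about twice as high.
scores = np.exp(-(x - 0.15) ** 2) + 0.5 * np.exp(-(x - 5.15) ** 2)
k = 4


def recon_error(tokens, selected):
    """Sum of squared distances from each token to its nearest selected token."""
    d = (tokens[:, None] - selected[None, :]) ** 2
    return float(d.min(axis=1).sum())


# Score-based selection: keep the k highest-scoring tokens.
top_k = np.sort(x[np.argsort(scores)[-k:]])

# Clustering-based selection: exact K-Medoids via exhaustive search.
best = min(combinations(range(len(x)), k),
           key=lambda idx: recon_error(x, x[list(idx)]))
medoids = x[list(best)]

print("top-k:  ", top_k, recon_error(x, top_k))
print("medoids:", medoids, recon_error(x, medoids))
```

The top-k selection leaves the right lobe to be represented by the token at 0.3, so its error is dominated by the roughly 5-unit gap, whereas the medoids place two tokens in each lobe.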