ToSA: Token Selective Attention for Efficient Vision Transformers

Manish Kumar Singh, Rajeev Yasarla, Hong Cai, Mingu Lee, Fatih Porikli

TL;DR

ToSA is a novel token selective attention approach that identifies which tokens need to be attended to and which can skip a transformer layer, significantly reducing computation costs while maintaining accuracy on the ImageNet classification benchmark.

Abstract

In this paper, we propose a novel token selective attention approach, ToSA, which can identify tokens that need to be attended as well as those that can skip a transformer layer. More specifically, a token selector parses the current attention maps and predicts the attention maps for the next layer, which are then used to select the important tokens that should participate in the attention operation. The remaining tokens simply bypass the next layer and are concatenated with the attended ones to re-form a complete set of tokens. In this way, we reduce the quadratic computation and memory costs as fewer tokens participate in self-attention while maintaining the features for all the image patches throughout the network, which allows it to be used for dense prediction tasks. Our experiments show that by applying ToSA, we can significantly reduce computation costs while maintaining accuracy on the ImageNet classification benchmark. Furthermore, we evaluate on the dense prediction task of monocular depth estimation on NYU Depth V2, and show that we can achieve similar depth prediction accuracy using a considerably lighter backbone with ToSA.
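
To make the quadratic saving concrete, here is a rough back-of-the-envelope sketch. It counts only the two attention matrix products ($QK^T$ and attention times $V$) for one head and ignores projections and the MLP. The token count of 197, the head dimension of 64, and the keep ratio of 0.8 are illustrative assumptions, not values taken from the paper.

```python
def attention_matmul_macs(num_tokens: int, head_dim: int) -> int:
    """Multiply-accumulates for QK^T plus attention-times-V in a single head."""
    return 2 * num_tokens * num_tokens * head_dim


# Illustrative values (assumptions, not from the paper).
n, d = 197, 64            # tokens per image, per-head feature dimension
keep_ratio = 0.8          # hypothetical fraction of tokens that attend
k = round(keep_ratio * n)  # number of tokens that take part in attention

full = attention_matmul_macs(n, d)
reduced = attention_matmul_macs(k, d)
ratio = reduced / full    # equals (k / n) ** 2, about 0.64 for these values
print(f"attended tokens: {k}/{n}, attention cost ratio: {ratio:.2f}")
```

Because the cost grows with the square of the token count, keeping about 80% of the tokens leaves roughly 64% of the attention matrix-product cost under these assumptions.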

Paper Structure (10 sections, 5 equations, 4 figures, 2 tables)

Figures (4)

  • Figure 1: (a) Two consecutive standard transformer layers. (b) A standard transformer layer followed by our ToSA layer. Our proposed approach operates on a pair of consecutive transformer layers and modifies the second layer to be token selective. In other words, only a subset of tokens participate in the self-attention in the second layer while the rest bypass the layer; our token selector predicts importance scores of the tokens for selection. Note that we retain all the tokens throughout the layers to facilitate dense prediction tasks, e.g., segmentation, depth estimation.
  • Figure 2: Overview of our proposed ToSA approach. Based on an input $X_i$, a standard transformer layer generates an output $X_{i+1}$, as well as $QK^T$ maps (prior to Softmax), $B_i^1,B_i^2,...,B_i^H$, for the $H$ attention heads. Based on them, our token selector predicts the attention maps at the next layer and uses them to identify tokens that need to be attended, $X_{i+1}^{a,h}$, and tokens that can skip the next layer, $X_{i+1}^{p,h}$, for each head. Then, self-attention is performed on $X_{i+1}^{a,h}$, the output of which is concatenated with $X_{i+1}^{p,h}$ to form the output of this head, $X_{i+2}^h$. All $X_{i+2}^h$'s are combined to generate the final output of this layer, $X_{i+2}$. By reducing the number of tokens participating in attention, we significantly save computation and memory given the quadratic complexity of self-attention, while slightly improving accuracy.
  • Figure 3: Applying our proposed ToSA to all the pairs of transformer layers in a 12-layer DeiT.
  • Figure 4: Visualization of token selection at the 2nd, 6th, and 10th layers. The input is a cat image with the cat at the center. Dark colors indicate the patches that do not take part in self-attention.
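
The per-head flow described in the Figure 2 caption (score tokens, attend over the selected subset, let the rest bypass, then re-form the full token set) can be sketched in NumPy as below. This is a minimal illustration under stated assumptions, not the authors' implementation: the paper's learned token selector is replaced by a simple stand-in heuristic (the average attention each token receives in the previous layer's softmaxed $QK^T$ map), the function name `token_selective_head` and all weights are hypothetical, the residual connection and MLP are omitted, and outputs are written back at their original positions rather than literally concatenated.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def token_selective_head(x, scores, keep, wq, wk, wv):
    """One attention head in which only the `keep` highest-scoring tokens attend.

    x:      (N, d) input tokens for this head
    scores: (N,)   importance scores (stand-in for the learned selector)
    Returns the (N, d) output and the indices of the attended tokens.
    """
    # Pick the top-`keep` tokens and keep them in their original order.
    attend_idx = np.sort(np.argsort(-scores)[:keep])
    xa = x[attend_idx]

    # Standard scaled dot-product attention, restricted to the selected tokens.
    q, k, v = xa @ wq, xa @ wk, xa @ wv
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1)

    # Bypassed tokens pass through unchanged; attended tokens are overwritten.
    out = x.copy()
    out[attend_idx] = attn @ v
    return out, attend_idx


rng = np.random.default_rng(0)
n, d, keep = 8, 4, 5                      # toy sizes, chosen for illustration
x = rng.normal(size=(n, d))
b_prev = rng.normal(size=(n, n))          # previous layer's QK^T map (pre-softmax)

# Stand-in heuristic: average attention each token receives as a key.
scores = softmax(b_prev, axis=-1).mean(axis=0)

wq, wk, wv = (rng.normal(size=(d, d)) for _ in range(3))
out, attend_idx = token_selective_head(x, scores, keep, wq, wk, wv)
print("attended token indices:", attend_idx)
```

The attention matrix here is `keep × keep` instead of `N × N`, which is where the saving comes from, while the output still has one feature vector per input token, as dense prediction tasks require.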