The Linear Attention Resurrection in Vision Transformer
Chuanyang Zheng
TL;DR
Vision Transformers incur $O(N^2C)$ complexity (for $N$ tokens with $C$ channels) due to softmax attention, which hinders high-resolution vision tasks. The authors introduce Enhanced Linear Attention with a Local Concentration Module (LCM) and build on it L$^2$ViT, a hierarchical backbone that alternates Linear Global Attention (LGA) with Local Window Attention (LWA) to achieve both global and local modeling at a complexity of $O(NC^2)$, linear in $N$. Key contributions include the non-negative attention property, the LCM for local concentration, and the LGA/LWA architectural design. Together, these yield 84.4% Top-1 accuracy on ImageNet-1K (87.0% with ImageNet-22k pre-training and fine-tuning at 384$^2$) and favorable results on COCO object detection and ADE20K semantic segmentation. The results indicate that linear attention, when augmented with locality-focused modules, can rival or surpass softmax-based ViTs, offering scalable, versatile vision backbones for a range of tasks.
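The complexity gap quoted above comes from the order of the matrix products. The minimal NumPy sketch here assumes a ReLU feature map as the non-negative kernel (this section does not say which kernel L$^2$ViT uses) and a simple row-sum normalizer. It checks that computing $\phi(Q)(\phi(K)^\top V)$ gives the same output as $(\phi(Q)\phi(K)^\top)V$ without ever forming the $N \times N$ matrix. All function names are hypothetical, and this is not the authors' LGA implementation: it omits the LCM, multi-head splitting, and the Q/K/V projections.

```python
import numpy as np


def relu_feature_map(x):
    # Assumed non-negative kernel: ReLU keeps every attention weight >= 0.
    return np.maximum(x, 0.0)


def softmax_attention(q, k, v):
    # Standard softmax attention: builds an N x N matrix, O(N^2 C) time, O(N^2) memory.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def linear_attention_quadratic(q, k, v, eps=1e-6):
    # Kernelized attention in the naive order (phi(Q) phi(K)^T) V: still O(N^2 C).
    qf, kf = relu_feature_map(q), relu_feature_map(k)
    weights = qf @ kf.T  # (N, N), all entries >= 0
    weights /= weights.sum(axis=-1, keepdims=True) + eps
    return weights @ v


def linear_attention(q, k, v, eps=1e-6):
    # Same output via associativity, phi(Q) (phi(K)^T V): O(N C^2), no N x N matrix.
    qf, kf = relu_feature_map(q), relu_feature_map(k)
    kv = kf.T @ v  # (C, C) summary of all keys and values
    normalizer = qf @ kf.sum(axis=0)  # (N,) row sums of phi(Q) phi(K)^T
    return (qf @ kv) / (normalizer[:, None] + eps)


def matmul_costs(n, c):
    # Leading-order multiply counts (constant factors dropped): N^2 C versus N C^2.
    return n * n * c, n * c * c


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 16)) for _ in range(3))
assert np.allclose(linear_attention(q, k, v), linear_attention_quadratic(q, k, v))
print(softmax_attention(q, k, v).shape, linear_attention(q, k, v).shape)  # (64, 16) (64, 16)
print(matmul_costs(3136, 96))  # (944111616, 28901376)
```

With a hypothetical 56×56 token grid ($N = 3136$) and $C = 96$, the two orderings differ by a factor of $N/C \approx 32.7$, and the gap widens with resolution because only the first term grows with $N^2$.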
Abstract
Vision Transformers (ViTs) have recently taken computer vision by storm. However, the softmax attention underlying ViTs has quadratic time and memory complexity in the number of tokens, hindering the application of ViTs to high-resolution images. We revisit the attention design and propose a linear attention method to address this limitation. Unlike existing methods such as the local window attention of Swin, it does not sacrifice ViT's core advantage of capturing global representations. We further investigate the key difference between linear attention and softmax attention. Our empirical results suggest that linear attention lacks a fundamental property of softmax attention: concentrating the distribution of the attention matrix. Inspired by this observation, we introduce a local concentration module to enhance linear attention. By combining enhanced linear global attention with local window attention, we propose a new ViT architecture, dubbed L$^2$ViT. Notably, L$^2$ViT can effectively capture both global interactions and local representations while enjoying linear computational complexity. Extensive experiments demonstrate the strong performance of L$^2$ViT. On image classification, L$^2$ViT achieves 84.4% Top-1 accuracy on ImageNet-1K without any extra training data or labels. With additional pre-training on ImageNet-22k, it attains 87.0% when fine-tuned at resolution 384$^2$. On downstream tasks, L$^2$ViT delivers favorable performance as a backbone for both object detection and semantic segmentation.
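The abstract attributes linear attention's weakness to a missing concentration of the attention distribution and adds a local concentration module, but it does not specify the module's internals. One plausible reading, sketched here in NumPy as an assumption rather than the authors' design, is a depthwise convolution over the $h \times w$ token grid that re-injects local neighborhood structure into the linear-attention output, which is otherwise a weighted average over all tokens. The kernel size, the residual connection, and the names `local_concentration` and `depthwise_conv2d` are all hypothetical.

```python
import numpy as np


def depthwise_conv2d(x, kernels):
    # x: (H, W, C) feature map; kernels: (k, k, C), one spatial filter per channel.
    # Computed as cross-correlation with zero "same" padding, as in deep-learning convs.
    k = kernels.shape[0]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))
    h, w, _ = x.shape
    out = np.zeros_like(x)
    for dy in range(k):
        for dx in range(k):
            # Each tap shifts the map and scales every channel by its own weight.
            out += xp[dy:dy + h, dx:dx + w, :] * kernels[dy, dx, :]
    return out


def local_concentration(tokens, h, w, kernels):
    # Hypothetical LCM sketch: view the N = h * w attention outputs as an image and
    # mix each token with its spatial neighbors, adding the result residually (assumed).
    n, c = tokens.shape
    assert n == h * w, "token count must match the h x w grid"
    grid = tokens.reshape(h, w, c)
    refined = grid + depthwise_conv2d(grid, kernels)
    return refined.reshape(n, c)


rng = np.random.default_rng(0)
h, w, c = 7, 7, 32
attn_out = rng.standard_normal((h * w, c))  # stand-in for a linear-attention output
kernels = 0.1 * rng.standard_normal((3, 3, c))  # hypothetical 3x3 depthwise weights
print(local_concentration(attn_out, h, w, kernels).shape)  # (49, 32)
```

A depthwise $k \times k$ convolution costs $O(NCk^2)$, so adding a module of this kind keeps the block linear in $N$, consistent with the linear-complexity claim above.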
