
Learning Spatial Decay for Vision Transformers

Yuxin Mao, Zhen Qin, Jinxing Zhou, Bin Fan, Jing Zhang, Yiran Zhong, Yuchao Dai

TL;DR

This work tackles the lack of 2D spatial inductive bias in Vision Transformers by introducing the Spatial Decay Transformer (SDT) with Context-Aware Gating (CAG), which enables data-dependent spatial attention. It unifies fixed 2D priors with learned content representations through a spatial-content fusion framework and provides a decomposed 2D implementation for efficiency in high-resolution stages. Empirically, SDT outperforms data-independent decay on ImageNet-1K classification and achieves competitive FID scores on image generation, while ablations confirm the advantage of content-aware gating and CDSF over 1D or fixed-decay variants. The results establish data-dependent spatial decay as a new paradigm for spatial attention in vision transformers, with potential applications to spatiotemporal vision tasks.
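The summary mentions a decomposed 2D implementation without spelling it out. The sketch below assumes the axis-wise scheme used by RMT (the RMT-T baseline in Figure 3): a Manhattan decay $\gamma^{|\Delta r| + |\Delta c|}$ factorizes into a row mask times a column mask, so one width-wise pass and one height-wise pass replace full 2D attention, cutting cost from $O((HW)^2)$ to $O(HW(H+W))$. It uses a fixed scalar $\gamma$ rather than SDT's learned gate, and all function names are hypothetical; this is not the paper's implementation.

```python
import numpy as np


def decay_1d(n: int, gamma: float) -> np.ndarray:
    """1D decay mask D[a, b] = gamma ** |a - b| along one image axis."""
    idx = np.arange(n)
    return gamma ** np.abs(idx[:, None] - idx[None, :])


def manhattan_decay_2d(h: int, w: int, gamma: float) -> np.ndarray:
    """Full (h*w, h*w) mask gamma ** (|dr| + |dc|) for patches in row-major order."""
    rows, cols = np.divmod(np.arange(h * w), w)
    dist = np.abs(rows[:, None] - rows[None, :]) + np.abs(cols[:, None] - cols[None, :])
    return gamma ** dist


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def decomposed_decay_attention(q, k, v, gamma: float) -> np.ndarray:
    """Two 1D passes over (H, W, C) tensors: along width, then along height.

    Each pass multiplies the softmax weights by a 1D decay mask (RMT-style,
    assumed here), costing O(HW(H + W)) instead of O((HW)^2).
    """
    h, w, c = q.shape
    scale = 1.0 / np.sqrt(c)
    # Width pass: each patch attends to patches in its own row, weighted by gamma ** |dc|.
    a_w = softmax(np.einsum("hwc,hvc->hwv", q, k) * scale) * decay_1d(w, gamma)
    v_row = np.einsum("hwv,hvc->hwc", a_w, v)
    # Height pass: each patch attends to patches in its own column, weighted by gamma ** |dr|.
    a_h = softmax(np.einsum("hwc,gwc->whg", q, k) * scale) * decay_1d(h, gamma)
    return np.einsum("whg,gwc->hwc", a_h, v_row)


H, W, C, gamma = 3, 4, 8, 0.9
# The 2D Manhattan decay is exactly the Kronecker product of the two 1D masks.
assert np.allclose(manhattan_decay_2d(H, W, gamma),
                   np.kron(decay_1d(H, gamma), decay_1d(W, gamma)))

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((H, W, C)) for _ in range(3))
print(decomposed_decay_attention(q, k, v, gamma).shape)  # (3, 4, 8)
```

The mask factorization is exact, but the two-pass output is only an approximation of full 2D attention with the Kronecker mask, because each pass normalizes its softmax over a single axis; that trade-off is what buys the efficiency in high-resolution stages.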

Abstract

Vision Transformers (ViTs) have revolutionized computer vision, yet their self-attention mechanism lacks explicit spatial inductive biases, leading to suboptimal performance on spatially structured tasks. Existing approaches introduce data-independent spatial decay based on fixed distance metrics, applying uniform attention weighting regardless of image content and limiting adaptability to diverse visual scenarios. Inspired by recent advances in large language models, where content-aware gating mechanisms (e.g., GLA, HGRN2, FOX) significantly outperform static alternatives, we present the first successful adaptation of data-dependent spatial decay to 2D vision transformers. We introduce the Spatial Decay Transformer (SDT), featuring a novel Context-Aware Gating (CAG) mechanism that generates dynamic, data-dependent decay for patch interactions. Our approach learns to modulate spatial attention based on both content relevance and spatial proximity. We address the fundamental challenge of 1D-to-2D adaptation through a unified spatial-content fusion framework that integrates Manhattan distance-based spatial priors with learned content representations. Extensive experiments on ImageNet-1K classification and generation tasks demonstrate consistent improvements over strong baselines. Our work establishes data-dependent spatial decay as a new paradigm for enhancing spatial attention in vision transformers.
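The abstract describes the spatial-content fusion only at a high level. Below is a minimal NumPy sketch of one plausible way to combine a Manhattan-distance prior with content-dependent gates, borrowing the log-space forget-gate bias used by FOX-style models. Everything specific here is an assumption, not the paper's equations: the per-patch gate $G_i = \sigma(x_i^\top w_g + b_g)$, the symmetric fusion rule $\log M_{\text{decay}}[i,j] = d(i,j)\,(\log G_i + \log G_j)/2$, and the function names. The symbols $\mathbf{G}$ and $\mathbf{M}_{\text{decay}}$ only mirror the Figure 1 caption. With constant gates $G_i = \gamma$ the map reduces to the fixed prior $\gamma^{d(i,j)}$, which illustrates how a learned gate can generalize data-independent decay.

```python
import numpy as np


def manhattan_distance(h: int, w: int) -> np.ndarray:
    """Pairwise |dr| + |dc| between the h*w patches, flattened in row-major order."""
    rows, cols = np.divmod(np.arange(h * w), w)
    return np.abs(rows[:, None] - rows[None, :]) + np.abs(cols[:, None] - cols[None, :])


def log_gate(x: np.ndarray, w_g: np.ndarray, b_g: float) -> np.ndarray:
    """Hypothetical CAG gate: log G_i = log sigmoid(x_i . w_g + b_g), one value per patch."""
    z = x @ w_g + b_g
    return -np.logaddexp(0.0, -z)  # numerically stable log-sigmoid


def log_decay_map(log_g: np.ndarray, dist: np.ndarray) -> np.ndarray:
    """log M_decay[i, j] = dist[i, j] * (log G_i + log G_j) / 2 (assumed fusion rule)."""
    return dist * 0.5 * (log_g[:, None] + log_g[None, :])


def spatial_decay_attention(x, w_q, w_k, w_v, w_g, b_g, h, w):
    """Single-head attention over n = h*w patch tokens x of shape (n, c)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    logits = q @ k.T / np.sqrt(q.shape[-1])
    # Adding log M_decay scales each unnormalized weight exp(logit) by M_decay[i, j]
    # before the row-wise softmax normalization.
    logits = logits + log_decay_map(log_gate(x, w_g, b_g), manhattan_distance(h, w))
    logits = logits - logits.max(axis=-1, keepdims=True)
    attn = np.exp(logits)
    attn = attn / attn.sum(axis=-1, keepdims=True)
    return attn @ v


h, w, c, gamma = 3, 4, 8, 0.9
rng = np.random.default_rng(0)
x = rng.standard_normal((h * w, c))
w_q, w_k, w_v = (rng.standard_normal((c, c)) for _ in range(3))

# Constant gates G_i = gamma (zero weights, bias = logit(gamma)) recover the fixed prior gamma ** dist.
fixed = log_gate(x, np.zeros(c), np.log(gamma / (1 - gamma)))
assert np.allclose(np.exp(log_decay_map(fixed, manhattan_distance(h, w))),
                   gamma ** manhattan_distance(h, w))

out = spatial_decay_attention(x, w_q, w_k, w_v, rng.standard_normal(c), 0.0, h, w)
print(out.shape)  # (12, 8)
```

Because every gate lies in (0, 1), each log-decay entry is non-positive and grows in magnitude with distance, so the spatial prior is preserved while the content of both patches controls how fast attention fades.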

Paper Structure

This paper contains 13 sections, 13 equations, 4 figures, 5 tables.

Figures (4)

  • Figure 1: The network structure of the Spatial Decay Layer. The attention weights are modulated by a learned decay map $\mathbf{M}_{\text{decay}}$ computed from $\mathbf{G}$, enabling spatially adaptive attention.
  • Figure 2: Overall architecture of the proposed Learnable Spatial Decay-based Vision Transformer. The model consists of four stages of Spatial Decay Transformer (SDT) Blocks, and each stage contains several Spatial Decay Layers as shown in Figure 1.
  • Figure 3: Training loss comparison between our proposed SDT-H-T and RMT-T. The blue curve represents our SDT-H-T, and the orange curve represents RMT-T.
  • Figure 4: Scaling up SDT-P improves (lowers) FID at every stage of training. We present FID-50K across training iterations for four SDT-P models. Scaling up the SDT-P backbone yields better generative models across all model sizes.