Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth

Yihe Dong, Jean-Baptiste Cordonnier, Andreas Loukas

TL;DR

The paper investigates why attention-based transformers perform well by introducing a path decomposition that rewrites a self-attention network (SAN) as a sum over paths through the network. It proves that pure SANs, without skip connections or MLPs, exhibit a token-uniformity bias that drives the output to a rank-1 matrix at a doubly exponential rate with depth; skip connections and MLPs counteract this collapse to varying degrees. Empirical results on standard architectures (BERT, ALBERT, XLNet) confirm rank collapse in the absence of skips and illustrate the bias via toy visualizations and path-length analyses. The work provides both a theoretical framework and practical insights, highlighting the critical role of architectural components in maintaining expressive power and suggesting avenues for long-path utilization and width-depth tradeoffs in future models.

Abstract

Attention-based architectures have become ubiquitous in machine learning, yet our understanding of the reasons for their effectiveness remains limited. This work proposes a new way to understand self-attention networks: we show that their output can be decomposed into a sum of smaller terms, each involving the operation of a sequence of attention heads across layers. Using this decomposition, we prove that self-attention possesses a strong inductive bias towards "token uniformity". Specifically, without skip connections or multi-layer perceptrons (MLPs), the output converges doubly exponentially to a rank-1 matrix. On the other hand, skip connections and MLPs prevent the output from degenerating. Our experiments verify the identified convergence phenomena on different variants of standard transformer architectures.
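
The collapse described in the abstract can be observed in a small numerical experiment. The following is a minimal sketch, not the authors' experimental setup: it stacks single-head attention layers with random weights (no skip connections, no MLPs) and tracks how far the output is from a matrix whose rows are all identical. The function names, the weight scales, and the use of a Frobenius-norm residual are assumptions made for illustration; the norm used in the paper's analysis may differ.

```python
import numpy as np


def softmax_rows(z):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def relative_residual(x):
    """Frobenius distance from x to the nearest matrix with identical rows,
    divided by the norm of x. Subtracting the column means gives that nearest
    matrix, which has the rank-1 form 1 * mean_row^T."""
    res = x - x.mean(axis=0, keepdims=True)
    return np.linalg.norm(res) / np.linalg.norm(x)


def pure_attention_residuals(n_tokens=8, d=16, depth=8, seed=0):
    """Apply `depth` pure attention layers and record the residual after each.

    Assumption: weights are random Gaussians, not trained parameters.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n_tokens, d))
    residuals = [relative_residual(x)]
    for _ in range(depth):
        w_q = 0.5 * rng.standard_normal((d, d)) / np.sqrt(d)
        w_k = 0.5 * rng.standard_normal((d, d)) / np.sqrt(d)
        w_v = rng.standard_normal((d, d)) / np.sqrt(d)
        # Row-stochastic attention matrix computed from the current tokens.
        p = softmax_rows((x @ w_q) @ (x @ w_k).T / np.sqrt(d))
        # Pure attention update: no skip connection, no MLP.
        x = p @ x @ w_v
        residuals.append(relative_residual(x))
    return residuals


if __name__ == "__main__":
    for layer, r in enumerate(pure_attention_residuals()):
        print(f"layer {layer}: relative residual = {r:.3e}")
```

Running the script prints one residual per layer; under these assumptions the values fall toward floating-point precision within a few layers, which is the token-uniformity effect in miniature.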

Paper Structure

This paper contains 27 sections, 7 theorems, 59 equations, 8 figures.

Key Result

Theorem 2.1

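The output of a depth $L$ self-attention network with $H$ heads per layer (including biases and skip connections) is given by

$$\mathrm{SAN}(\bm{X}) = \sum_{\textit{path}} \bm{P}_\textit{path} \, \bm{X} \, \bm{W}_\textit{path} + \bm{1} \bm{b}^\top,$$

where the sum runs over all paths $(h_1, \ldots, h_L)$ through the network, $\bm{X}$ is the input, $\bm{P}_\textit{path} = \bm{P}_{h_L}^{L} \cdots \bm{P}_{h_1}^{1}$ is an input-dependent stochastic matrix, whereas $\bm{W}_{\textit{path}} = \bm{W}_{h_1}^{1} \cdots \bm{W}_{h_L}^{L}$ and $\bm{b}$ do not depend on the input.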
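
The sum-over-paths form in Theorem 2.1 can be checked numerically on a toy network. The sketch below makes simplifying assumptions that are not part of the theorem statement: biases and skip connections are omitted, each head's value and output projections are merged into a single matrix `w[l, h]`, and all weights are random. It runs an ordinary layer-by-layer forward pass, records the attention matrices that pass produces, and then rebuilds the output as a sum over all $H^L$ head sequences. The function and variable names are hypothetical.

```python
import itertools

import numpy as np


def softmax_rows(z):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def check_path_decomposition(n_tokens=4, d=5, heads=2, depth=3, seed=0):
    """Return (layer-by-layer output, sum-over-paths output) for a toy SAN."""
    rng = np.random.default_rng(seed)
    x0 = rng.standard_normal((n_tokens, d))
    # w_qk[l, h]: merged query-key weights; w[l, h]: merged value-output weights.
    w_qk = rng.standard_normal((depth, heads, d, d)) / np.sqrt(d)
    w = rng.standard_normal((depth, heads, d, d)) / np.sqrt(d)

    # Ordinary forward pass; keep every attention matrix P^l_h.
    x = x0
    attn = []
    for l in range(depth):
        p_l = [softmax_rows(x @ w_qk[l, h] @ x.T / np.sqrt(d)) for h in range(heads)]
        x = sum(p_l[h] @ x @ w[l, h] for h in range(heads))
        attn.append(p_l)

    # Rebuild the output as a sum over all head sequences (h_1, ..., h_L).
    total = np.zeros_like(x0)
    for path in itertools.product(range(heads), repeat=depth):
        p_path = np.eye(n_tokens)
        w_path = np.eye(d)
        for l, h in enumerate(path):
            p_path = attn[l][h] @ p_path  # builds P^L ... P^1 (left-multiplied)
            w_path = w_path @ w[l, h]     # builds W^1 ... W^L (right-multiplied)
        total += p_path @ x0 @ w_path
    return x, total


if __name__ == "__main__":
    out, total = check_path_decomposition()
    print("max abs difference:", np.abs(out - total).max())
```

The two results agree up to floating-point error because each layer is linear in its input once the attention matrices are held fixed, so expanding the product of per-layer sums yields one term per path.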

Figures (8)

  • Figure 1: Two paths in a deep Self-Attention Network (SAN) with $H$ heads and $L$ layers. At each layer, a path can go through one of the heads or bypass the layer. Adding an MLP block after each attention layer forms the transformer architecture.
  • Figure 3: Applying a trained single-layer transformer module recurrently, to models of increasing hidden dimension (horizontal direction) and across architectural variants (vertical direction). The two light background paths illustrate the two training trajectories, for which the starting points are $(-0.3, 0)$ and $(0.3, 0)$. Each figure contains the same number of steps. Consistent with the theory, convergence slows down or stops as the dimension increases (since $\beta \ge \|\bm{W}_{QK}^l\|_1 \|\bm{W}_{V}^{l}\|_{1,\infty}$ is generally larger), as well as when either MLP or skip connections are added.
  • Figure 4: To determine how much of the expressive power can be attributed to short vs long paths, we examine the performance of subsets of paths of different lengths (rather than of the entire SAN). Performance can be seen to consistently deteriorate with respect to path length, supporting our hypothesis that short paths are responsible for the majority of the expressive power.
  • Figure 5: Distribution of the path length for a diverse selection of transformer architectures (encoder only) with different depths and widths. The legends are sorted by the total number of heads in the architecture, $L \times H$. The architectures shown are GPT-3, T5, BERT, ViT, DistilBERT, and MobileBERT.
  • Figure : (a) Bert
  • ...and 3 more figures
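
A path-length distribution of the kind mentioned in the Figure 5 caption can be sketched with a simple counting argument. Assuming, as in the Figure 1 caption, that at each of the $L$ layers a path either goes through one of the $H$ heads or bypasses the layer, and taking the length of a path to be the number of layers where it goes through a head, there are $\binom{L}{k} H^k$ paths of length $k$ out of $(H+1)^L$ in total. This is an illustrative calculation under those assumptions, not necessarily the exact procedure behind the figure, and the function name is hypothetical.

```python
from math import comb


def path_length_distribution(depth, heads):
    """Fraction of paths of each length k = 0..depth.

    Assumes each layer offers `heads` attention choices plus one bypass
    (skip) choice, and that path length counts the layers where a head
    is used.
    """
    total = (heads + 1) ** depth
    return [comb(depth, k) * heads**k / total for k in range(depth + 1)]


if __name__ == "__main__":
    # Example: 12 layers with 12 heads each (the BERT-base configuration).
    for k, frac in enumerate(path_length_distribution(12, 12)):
        print(f"length {k:2d}: {frac:.4f}")
```

Under this counting, wider networks (larger $H$) place more of their paths at long lengths, since each attended layer multiplies the number of paths by $H$ while a bypassed layer contributes a single choice.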

Theorems & Definitions (15)

  • Theorem 2.1: Path decomposition of SAN
  • proof
  • Theorem 2.2: Simplified
  • Theorem 2.3: Simplified
  • Claim 3.1
  • Corollary 3.2: Simplified
  • Lemma A.1
  • Lemma A.2
  • proof
  • proof
  • ...and 5 more