On the Role of Attention Masks and LayerNorm in Transformers

Xinyi Wu, Amir Ajorlou, Yifei Wang, Stefanie Jegelka, Ali Jadbabaie

TL;DR

This work analyzes how attention masks and LayerNorm affect rank collapse in transformer self-attention. By treating self-attention networks (SANs) as discrete-time dynamics on the graph induced by the attention mask, it proves that, without LayerNorm, tokens collapse exponentially to a rank-one subspace under a broad class of mask patterns, while sparse or local masks slow the collapse. Adding LayerNorm gives a more nuanced picture: for certain classes of value matrices the collapse still happens exponentially, but with suitably chosen value matrices the dynamics admit equilibria of any rank between one and full, and a general class of input sequences does not collapse to rank one. This contradicts the earlier hypothesis that LayerNorm plays no role in rank collapse. Numerical experiments support the results and show that token representations in real models can remain full rank while still being anisotropic. Overall, the paper highlights the roles of mask topology and LayerNorm in transformer dynamics and points to new directions for the design and analysis of attention mechanisms.

Abstract

Self-attention is the key mechanism of transformers, which are the essential building blocks of modern foundation models. Recent studies have shown that pure self-attention suffers from an increasing degree of rank collapse as depth increases, limiting model expressivity and further utilization of model depth. The existing literature on rank collapse, however, has mostly overlooked other critical components in transformers that may alleviate the rank collapse issue. In this paper, we provide a general analysis of rank collapse under self-attention, taking into account the effects of attention masks and layer normalization (LayerNorm). In particular, we find that although pure masked attention still suffers from exponential collapse to a rank one subspace, sparse or local masked attention can provably slow down the collapse rate. In the case of self-attention with LayerNorm, we first show that for certain classes of value matrices, collapse to a rank one subspace still happens exponentially. However, through construction of nontrivial counterexamples, we then establish that with proper choice of value matrices, a general class of sequences may not converge to a rank one subspace, and the self-attention dynamics with LayerNorm can simultaneously possess a rich set of equilibria with any possible rank between one and full. Our result refutes the previous hypothesis that LayerNorm plays no role in the rank collapse of self-attention and suggests that self-attention with LayerNorm constitutes a much more expressive, versatile nonlinear dynamical system than what was originally thought.
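
The collapse of pure masked attention described in the abstract can be observed numerically. The following is a minimal NumPy sketch, not the authors' code. It assumes identity value matrices, random query and key matrices redrawn at every layer, and takes $\mu(X)$ to be the Frobenius distance between $X$ and the matrix whose rows all equal the mean token (this page uses $\mu$ without defining it, so that choice is an assumption). The helper names `masked_softmax`, `mu`, and `simulate` are hypothetical.

```python
import numpy as np

def masked_softmax(scores, mask):
    """Row-wise softmax restricted to the entries where mask is True."""
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)

def mu(X):
    """Assumed collapse measure: Frobenius norm of X minus its mean row."""
    return np.linalg.norm(X - X.mean(axis=0, keepdims=True))

def simulate(X, mask, n_layers, seed=0):
    """Apply attention-only layers X <- A(X) X and record mu after each layer."""
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    history = [mu(X)]
    for _ in range(n_layers):
        # Random query/key matrices per layer (assumption, for illustration only).
        W_q = rng.normal(size=(d, d)) / np.sqrt(d)
        W_k = rng.normal(size=(d, d)) / np.sqrt(d)
        scores = (X @ W_q) @ (X @ W_k).T / np.sqrt(d)
        X = masked_softmax(scores, mask) @ X  # identity value matrix (assumption)
        history.append(mu(X))
    return np.array(history)

N, d = 8, 4
X0 = np.random.default_rng(1).normal(size=(N, d))
idx = np.arange(N)
causal = idx[:, None] >= idx[None, :]                # token i attends to j <= i
local = causal & (idx[:, None] - idx[None, :] <= 1)  # token i attends to i-1 and i

for name, mask in [("causal", causal), ("local (window 2)", local)]:
    h = simulate(X0, mask, n_layers=100)
    print(f"{name}: mu at layer 0 = {h[0]:.3e}, mu at layer 100 = {h[-1]:.3e}")
```

Under these assumptions both runs should show $\mu$ shrinking by many orders of magnitude. Comparing intermediate entries of the two histories gives a rough feel for how the mask affects the rate, although a single random run is not evidence for the paper's bounds.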

Paper Structure (60 sections, 19 theorems, 132 equations, 9 figures)

Key Result

Theorem 1

Consider the self-attention dynamics without LayerNorm defined in eq: update_no_LN. Under A1-A3, if $\mathcal{G}$ is a quasi-strongly connected graph, then there exists $\epsilon > 0$ such that for all $t\geq 0$, [...]. As a result, rank collapse of tokens happens exponentially with respect to $\mu(\cdot)$, i.e., there exists $C > 0$ such that [...], where $r$ is the radius of $\mathcal{G}$, meaning that tokens [...]

Figures (9)

  • Figure 1: Long-term behavior of tokens in the case of $N=2, d=2$. Without LayerNorm (left), both tokens collapse to the same point in $\mathbb{R}^{2}$, whereas with LayerNorm (right), such a collapse does not necessarily happen and token representations can maintain full rank in the long term. The first token converges to either $(0,1)$ or $(0,-1)$; assuming it converges to $(0,1)$, the second token converges to $B$ if it is initially located within the red segment.
  • Figure 2: Evolution of $\mu(X^{(t)})$ (in log-log scale) as the number of layers increases. Rank collapse happens exponentially for pure attention, although different attention masks have different convergence rates. However, as soon as LayerNorm alone is added, $\mu(X^{(t)})$ no longer converges to zero in randomly initialized models; in pretrained models, LayerNorm, together with other components, helps prevent the issue and stabilizes the representations.
  • Figure 3: Evolution of $\mu(X^{(t)})$ (in log-log scale) as the number of layers increases. Smaller temperature terms slow the rate of rank collapse, and the effect is more significant with global attention than with sparser masked attention, and in shallower layers than in deeper layers.
  • Figure 4: Evolution of token geometry as the number of layers increases. We see that tokens are indeed able to maintain full rank, while at the same time the representations are anisotropic, meaning that they concentrate in a narrow region, as indicated by the average pairwise absolute cosine similarities.
  • Figure 5: Convergence analysis of the second token in $d=2$.
  • ...and 4 more figures
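
The caption of Figure 4 describes tokens that are full rank yet anisotropic, as measured by the average pairwise absolute cosine similarity. The sketch below shows how the two properties can coexist. It is illustrative only: the function names, the tolerance, and the synthetic tokens (a shared direction plus small noise) are assumptions, not the paper's experimental setup.

```python
import numpy as np

def mean_abs_cosine(X):
    """Average |cos(x_i, x_j)| over all unordered pairs of distinct rows of X."""
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
    C = np.abs(Xn @ Xn.T)
    iu = np.triu_indices(X.shape[0], k=1)  # indices of pairs with i < j
    return C[iu].mean()

def numerical_rank(X, tol=1e-8):
    """Number of singular values of X above tol (tol is an assumed threshold)."""
    return int(np.linalg.matrix_rank(X, tol=tol))

# Synthetic tokens: every row is close to the direction (1, 1, 1, 1).
rng = np.random.default_rng(0)
X = np.ones((6, 4)) + 0.05 * rng.normal(size=(6, 4))

print("rank:", numerical_rank(X))
print("mean |cos|:", mean_abs_cosine(X))
```

In this toy example the small perturbations keep the matrix full rank while the shared direction pushes the mean absolute cosine similarity close to one, which is the combination the caption calls anisotropic.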

Theorems & Definitions (39)

  • Definition 1: Reachability
  • Definition 2: Strongly Connected
  • Definition 3: Center Node
  • Definition 4: Quasi-Strongly Connected
  • Definition 5: Radius
  • Theorem 1
  • Remark 1
  • Theorem 2
  • Corollary 1
  • Remark 2
  • ...and 29 more
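
The graph notions listed above are stated here by name only. As a hedged illustration, the sketch below uses the standard graph-theoretic readings, which may differ in detail from the paper's definitions: a center node is one from which every other node is reachable, a graph is quasi-strongly connected if it has a center node, and the radius is the smallest, over center nodes, of the largest shortest-path distance to any node. It also assumes a directed edge from token $j$ to token $i$ whenever token $i$ attends to token $j$. The function names are hypothetical.

```python
from collections import deque
import numpy as np

def eccentricity(adj, src):
    """Largest BFS distance from src, or None if some node is unreachable."""
    dist = [-1] * len(adj)
    dist[src] = 0
    queue = deque([src])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if dist[v] < 0:
                dist[v] = dist[u] + 1
                queue.append(v)
    return None if min(dist) < 0 else max(dist)

def mask_radius(mask):
    """Radius of the mask graph, or None if it is not quasi-strongly connected.

    Assumed convention: mask[i, j] is True when token i attends to token j,
    which gives a directed edge j -> i. Self-loops are ignored.
    """
    n = mask.shape[0]
    adj = [[i for i in range(n) if mask[i, j] and i != j] for j in range(n)]
    center_eccs = [e for e in (eccentricity(adj, s) for s in range(n)) if e is not None]
    return min(center_eccs) if center_eccs else None

n = 5
idx = np.arange(n)
causal = idx[:, None] >= idx[None, :]                # attend to all earlier tokens
local = causal & (idx[:, None] - idx[None, :] <= 1)  # attend to previous token only
isolated = np.eye(n, dtype=bool)                     # each token attends to itself

print("causal radius:", mask_radius(causal))
print("local radius:", mask_radius(local))
print("isolated radius:", mask_radius(isolated))
```

Under these conventions the causal mask has radius 1, because the first token reaches every other token directly, while the width-2 local mask has radius `n - 1`. This matches the intuition that information from the center takes more layers to reach all tokens under a local mask.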