The Shaped Transformer: Attention Models in the Infinite Depth-and-Width Limit

Lorenzo Noci, Chuning Li, Mufan Bill Li, Bobby He, Thomas Hofmann, Chris Maddison, Daniel M. Roy

TL;DR

This work addresses the instability and rank degeneracy (rank collapse) observed in Softmax-based attention within deep Transformer architectures. It introduces the shaped attention mechanism, which combines centering the Softmax output around the identity with a width-dependent temperature, to stabilize forward and backward covariance propagation in the proportional infinite-depth-and-width limit. The authors derive neural covariance SDEs that characterize the distribution at initialization for shaped attention and shaped Transformer blocks, and prove local convergence (in the Skorohod sense) to these SDEs, with explicit drift and diffusion terms that encode the influence of attention and residual connections. Simulations show that the SDE descriptions closely track finite-size networks, and preliminary experiments suggest that shaped Transformers can train with competitive stability and performance. Overall, the paper provides a tractable, non-commutative limiting theory for Transformer-like architectures, offering design principles and hyperparameter guidance for stable, scalable deep attention models.
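As a concrete reading of the two modifications, the NumPy sketch below builds a shaped attention matrix $A = I + \mathrm{Softmax}(\tau^{-1} Y) - \frac{1}{m}\mathbf{1}\mathbf{1}^\top$ for $m$ tokens of width $n$, with logits $Y = X W^Q (X W^K)^\top$ and temperature $\tau = \tau_0\sqrt{n n_k}$ (the "$\tau^2 = nn_k$" choice named in Figure 4). The function names and the variance-$1/n$ initialization of $W^Q, W^K$ are assumptions made for illustration; this is not the authors' code.

```python
import numpy as np

def softmax_rows(z: np.ndarray) -> np.ndarray:
    """Row-wise Softmax, shifted by the row max for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def shaped_attention(X: np.ndarray, WQ: np.ndarray, WK: np.ndarray, tau0: float = 1.0) -> np.ndarray:
    """Hypothetical helper: A = I + Softmax(Y / tau) - (1/m) 1 1^T.

    X: (m, n) token representations; WQ, WK: (n, n_k) query/key weights.
    The temperature tau = tau0 * sqrt(n * n_k) grows with the width.
    """
    m, n = X.shape
    n_k = WQ.shape[1]
    tau = tau0 * np.sqrt(n * n_k)
    Y = (X @ WQ) @ (X @ WK).T                     # (m, m) attention logits
    centered = softmax_rows(Y / tau) - np.ones((m, m)) / m
    return np.eye(m) + centered                   # identity plus centered Softmax

rng = np.random.default_rng(0)
m, n, n_k = 8, 256, 256
X = rng.standard_normal((m, n))
WQ = rng.standard_normal((n, n_k)) / np.sqrt(n)   # assumed variance-1/n init
WK = rng.standard_normal((n, n_k)) / np.sqrt(n)
A = shaped_attention(X, WQ, WK)
print(A.sum(axis=1))                 # each row sums to 1 (up to rounding)
print(np.abs(A - np.eye(m)).max())   # small deviation from the identity
```

Because every row of the Softmax and of $\frac{1}{m}\mathbf{1}\mathbf{1}^\top$ sums to 1, every row of $A$ also sums to 1; with the large temperature the Softmax stays close to uniform, so $A$ stays close to the identity instead of averaging tokens together.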

Abstract

In deep learning theory, the covariance matrix of the representations serves as a proxy to examine the network's trainability. Motivated by the success of Transformers, we study the covariance matrix of a modified Softmax-based attention model with skip connections in the proportional limit of infinite-depth-and-width. We show that at initialization the limiting distribution can be described by a stochastic differential equation (SDE) indexed by the depth-to-width ratio. To achieve a well-defined stochastic limit, the Transformer's attention mechanism is modified by centering the Softmax output at identity, and scaling the Softmax logits by a width-dependent temperature parameter. We examine the stability of the network through the corresponding SDE, showing how the scale of both the drift and diffusion can be elegantly controlled with the aid of residual connections. The existence of a stable SDE implies that the covariance structure is well-behaved, even for very large depth and width, thus preventing the notorious issues of rank degeneracy in deep attention models. Finally, we show, through simulations, that the SDE provides a surprisingly good description of the corresponding finite-size model. We coin the name shaped Transformer for these architectural modifications.
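The abstract's central claim, that shaped attention avoids rank degeneracy, can be illustrated with a small finite-width simulation in the spirit of Figure 1 (left). The sketch below is a simplified stand-in, not the paper's experimental setup: it assumes attention-only residual blocks $X_{\ell+1} = \lambda X_\ell + \gamma A_\ell X_\ell W^V_\ell/\sqrt{n}$ with $\lambda^2 + \gamma^2 = 1$, Gaussian weights ($W^Q, W^K$ with variance $1/n$, $W^V$ with unit variance), $n_k = n$, and no LayerNorm or MLP sublayers. It compares shaped attention with plain Softmax attention at temperature $\sqrt{n_k}$ and reports the mean off-diagonal correlation $\rho^{\alpha\beta}$ at the last layer; all function names are hypothetical.

```python
import numpy as np

def softmax_rows(z):
    """Row-wise Softmax, shifted by the row max for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def mean_offdiag_corr(X):
    """Mean of rho^{ab} = V^{ab} / sqrt(V^{aa} V^{bb}) over a != b, with V = X X^T / n."""
    V = X @ X.T / X.shape[1]
    s = np.sqrt(np.diag(V))
    rho = V / np.outer(s, s)
    m = V.shape[0]
    return (rho.sum() - m) / (m * (m - 1))

def simulate(shaped, m=8, n=512, depth=32, rho0=0.2, gamma=1 / np.sqrt(2), seed=0):
    """Propagate m tokens through `depth` hypothetical residual attention blocks
    X <- lam * X + gamma * A X W^V / sqrt(n); return the final mean correlation."""
    rng = np.random.default_rng(seed)
    lam = np.sqrt(1 - gamma**2)            # enforces lam^2 + gamma^2 = 1
    n_k = n
    V0 = (1 - rho0) * np.eye(m) + rho0 * np.ones((m, m))
    X = np.linalg.cholesky(V0) @ rng.standard_normal((m, n))  # X X^T / n approx V0
    for _ in range(depth):
        WQ = rng.standard_normal((n, n_k)) / np.sqrt(n)
        WK = rng.standard_normal((n, n_k)) / np.sqrt(n)
        WV = rng.standard_normal((n, n))
        Y = (X @ WQ) @ (X @ WK).T
        if shaped:   # identity + centered Softmax, tau = sqrt(n * n_k)
            A = np.eye(m) + softmax_rows(Y / np.sqrt(n * n_k)) - np.ones((m, m)) / m
        else:        # standard Softmax attention, tau = sqrt(n_k)
            A = softmax_rows(Y / np.sqrt(n_k))
        X = lam * X + gamma * (A @ X @ WV) / np.sqrt(n)
    return mean_offdiag_corr(X)

print("shaped:  ", simulate(shaped=True))
print("unshaped:", simulate(shaped=False))
```

With these settings, the unshaped network's mean correlation should approach 1 within a few dozen layers, because a Softmax matrix with positive rows summing to 1 repeatedly averages the tokens. The shaped network's correlation should stay near its initial value of 0.2, up to random fluctuations whose size is governed by the depth-to-width ratio $d/n$.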

Paper Structure

This paper contains 36 sections, 18 theorems, 125 equations, 8 figures, and 2 tables.

Key Result

Theorem 3.2

Let $X_\ell$ be the hidden layers of a ResNet defined in eq:resnet with $\lambda^2 + \gamma^2 = 1$, where neither $\lambda$ nor $\gamma$ depends on $d$ or $n$. Then the feature covariance $V_\ell$ converges to the solution of the following SDE (in the sense of Definition 3.1), where $b_{\text{res}}(V) = \gamma^2 b_{\text{ReLU}}(V) = \gamma^2 [\nu(\rho^{\alpha\beta}) \sqrt{V^{\alpha\alpha} V^{\beta\beta}} \ldots$
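The drift above is built from a scalar function $\nu(\rho)$. Assuming the shaped-ReLU setting of li2022neural, with activation slopes $s_\pm = 1 + c_\pm/\sqrt{n}$ (the $c_+, c_-$ of Figure 3), $\nu$ takes the form $\nu(\rho) = \frac{(c_+ - c_-)^2}{2\pi}\left(\sqrt{1-\rho^2} - \rho \arccos\rho\right)$: it is the nonlinear part of $\mathbb{E}[\sigma_s(x)\sigma_s(y)]$ for standard Gaussians with correlation $\rho$, rescaled by $n$. The NumPy sketch below evaluates this assumed form, checks the underlying closed-form kernel against a Monte Carlo estimate, and confirms the $n$-scaling. All function names are hypothetical.

```python
import numpy as np

def nu(rho, c_plus=0.0, c_minus=-1.0):
    """Assumed shaped-ReLU drift: (c+ - c-)^2 / (2 pi) * (sqrt(1 - rho^2) - rho arccos rho)."""
    rho = np.clip(rho, -1.0, 1.0)
    return (c_plus - c_minus) ** 2 / (2 * np.pi) * (np.sqrt(1 - rho**2) - rho * np.arccos(rho))

def sigma(z, s_plus, s_minus):
    """Two-slope ReLU: slope s_plus for z > 0 and s_minus for z < 0."""
    return s_plus * np.maximum(z, 0) + s_minus * np.minimum(z, 0)

def kernel_exact(rho, s_plus, s_minus):
    """E[sigma(x) sigma(y)] = (a^2 + b^2) rho + (2 b^2 / pi)(sqrt(1 - rho^2) - rho arccos rho),
    with a = (s_plus + s_minus) / 2 and b = (s_plus - s_minus) / 2."""
    a, b = (s_plus + s_minus) / 2, (s_plus - s_minus) / 2
    return (a**2 + b**2) * rho + 2 * b**2 / np.pi * (np.sqrt(1 - rho**2) - rho * np.arccos(rho))

def kernel_mc(rho, s_plus, s_minus, samples=2_000_000, seed=0):
    """Monte Carlo estimate of the same expectation for correlated standard Gaussians."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(samples)
    y = rho * x + np.sqrt(1 - rho**2) * rng.standard_normal(samples)
    return np.mean(sigma(x, s_plus, s_minus) * sigma(y, s_plus, s_minus))

# Shaped slopes s_pm = 1 + c_pm / sqrt(n): n times the nonlinear part equals nu(rho).
n, c_plus, c_minus, rho = 100, 0.0, -1.0, 0.3
s_plus, s_minus = 1 + c_plus / np.sqrt(n), 1 + c_minus / np.sqrt(n)
a, b = (s_plus + s_minus) / 2, (s_plus - s_minus) / 2
print(n * (kernel_exact(rho, s_plus, s_minus) - (a**2 + b**2) * rho), nu(rho, c_plus, c_minus))
print(kernel_mc(rho, s_plus, s_minus), kernel_exact(rho, s_plus, s_minus))
```

Under this assumed form, $\nu$ is positive for $\rho < 1$ and vanishes at $\rho = 1$, since $\frac{d}{d\rho}\left(\sqrt{1-\rho^2} - \rho\arccos\rho\right) = -\arccos\rho \le 0$.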

Figures (8)

  • Figure 1: Our shaped Transformer prevents token representations from becoming perfectly aligned, i.e. rank collapse. Left: mean correlation $\rho^{\alpha\beta}_\ell$ of Transformers (eq:shaped_transformer) with and without shaped attention (eq:shaped-attention) and Pre-LN xiong2020layer. Right: kernel density estimate and histogram of correlations from the covariance SDE in thm:sde-attention and from the shaped attention NN. Note that correlation converging to $1$ implies a poorly conditioned covariance matrix. Simulated with $n = 200, d = 150, \gamma = 1/\sqrt{8}, \tau_0 = 1, \rho^{\alpha\beta}_0 = 0.2$, SDE step size $0.01$, and $2^{12}$ samples.
  • Figure 2: Gradient norms at initialization for different parameters as a function of depth, with and without shaped attention. The architecture is the same as in fig:rho_path_density but with autoregressive causal masking, and the task is next-token prediction on code data. Left: value weights $W^V_\ell$ for shaped attention, standard Pre-LN, and the original Post-LN block vaswani2017attention. Right: the same gradient norm plot for the query weights $W^Q_\ell$. We find that shaping the attention mechanism successfully prevents gradients from vanishing, while unshaped Transformers suffer from rapidly vanishing gradients. Interestingly, for Post-LN only the query gradients vanish, while the value gradients are stable across depth, which is consistent with the findings of noci2022signal. Shaped attention, on the other hand, has stable gradients both for parameters inside and for parameters outside the Softmax nonlinearity.
  • Figure 3: Left: kernel density estimates of the correlation $\rho^{\alpha\beta}_d$ for various values of the residual strength parameter $\gamma$. In particular, $\gamma=1$ recovers a shaped-ReLU MLP without skip connections, and $\gamma = 1/\sqrt{d}$ is the setting studied in noci2022signal and hayou2023width. Solid lines represent finite networks, and dashed lines our SDE simulations. Right: 95th percentile of the absolute value of the correlation distribution as a function of $\gamma$. Note that reducing $\gamma$ reduces the concentration around $\rho^{\alpha\beta} = 1$, and that our SDE reliably approximates finite-size networks. Simulated with $n = 300, d = 100, \rho^{\alpha\beta}_0 = 0.2, c_+ = 0, c_-=-1$, and $2^{13}$ samples.
  • Figure 4: Mean correlation (left) and covariance (right), in absolute value, under various interventions on the proposed shaped attention. In particular, we remove one or two of the three modifications from the shaped attention in eq:shaped-attention. For instance, "$\tau^2=nn_k$, center" indicates that we use the proposed temperature and center by $m^{-1}$ but do not add the identity matrix, while in "only id" we add the identity matrix but use $\tau=\sqrt{n_k}$ and do not center. In the "only id" case, the covariance remains unstable due to incorrect scaling. Because the covariance explodes, we show the cases "id, $\tau^2=nn_k$" and "only id" only in the covariance plot, not in the correlation plot. Simulated with $n = 300, d = 150, \rho^{\alpha\beta}_0 = 0.2$, $\gamma=1/\sqrt{2}$, and $2^{13}$ samples.
  • Figure 5: Left: Trajectories of the maximum eigenvalue of the covariance matrix in a shaped attention network, with adversarially large initial condition. Right: Stopping time of the shaped attention neural network, capped at 1. Stopping time is defined as $t^* = d^*/n$ with $d^*$ the maximum depth beyond which one of the eigenvalues of the covariance matrix exceeds $10^4$ or drops below $10^{-4}$. Simulated with $n=d=200$, $\tau_0=1$, and $100$ samples used for median and $10$th percentile.
  • ...and 3 more figures
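The stopping time in Figure 5 can be computed from a per-layer record of covariance eigenvalues. The sketch below uses one assumed reading of the caption: $d^*$ is the first layer at which any eigenvalue leaves $[10^{-4}, 10^{4}]$, and $t^* = \min(d^*/n, 1)$; if no eigenvalue ever leaves the band, $d^*$ is the full depth. The function name and the synthetic trajectory are hypothetical.

```python
import numpy as np

def stopping_time(eigvals, n, upper=1e4, lower=1e-4, cap=1.0):
    """Return t* = d*/n, capped at `cap`.

    eigvals: array of shape (depth + 1, m); row l holds the eigenvalues of the
    covariance matrix V_l at layer l. d* is the first layer at which any
    eigenvalue exceeds `upper` or drops below `lower` (assumed reading of the
    Figure 5 caption); if that never happens, d* is the full depth.
    """
    eigvals = np.asarray(eigvals)
    outside = (eigvals > upper).any(axis=1) | (eigvals < lower).any(axis=1)
    if outside.any():
        d_star = int(np.argmax(outside))   # argmax returns the first True index
    else:
        d_star = eigvals.shape[0] - 1      # stayed in the band for the whole depth
    return min(d_star / n, cap)

# Hypothetical trajectory: one eigenvalue grows like exp(0.1 * l), the other stays at 1.
n = 200
layers = np.arange(n + 1)
eigs = np.stack([np.exp(0.1 * layers), np.ones(n + 1)], axis=1)
print(stopping_time(eigs, n))  # first exceedance of 1e4 is at layer 93, so t* = 93/200
```

In an actual simulation, each row of `eigvals` could come from `np.linalg.eigvalsh` applied to the symmetric covariance $V_\ell$ of that layer.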

Theorems & Definitions (35)

  • Definition 3.1: Convergence of Covariance
  • Theorem 3.2
  • Definition 4.1: Local Convergence
  • Theorem 4.2
  • Corollary 4.2: Shaped Transformer Covariance SDE
  • Definition A.1
  • Proposition A.2: Convergence of Markov Chains to SDE (Proposition A.6 in li2022neural)
  • Lemma B.1
  • Proof
  • Lemma B.2
  • ...and 25 more