On the Interpolation Error of Nonlinear Attention versus Linear Regression

Zhenyu Liao, Jiaqing Liu, TianQi Hou, Difan Zou, Zenan Ling

TL;DR

It is shown that nonlinear Attention generally incurs a larger interpolation error than linear regression on random inputs, but this gap vanishes, and can even be reversed, when the input contains a structured signal, particularly if the Attention weights align with the signal direction.

Abstract

Attention has become the core building block of modern machine learning (ML) by efficiently capturing the long-range dependencies among input tokens. Its inherently parallelizable structure allows for efficient performance scaling with the rapidly increasing size of both data and model parameters. Despite its central role, the theoretical understanding of Attention, especially in the nonlinear setting, is progressing at a more modest pace. This paper provides a precise characterization of the interpolation error for a nonlinear Attention, in the high-dimensional regime where the number of input tokens $n$ and the embedding dimension $p$ are both large and comparable. Under a signal-plus-noise data model and for fixed Attention weights, we derive explicit (limiting) expressions for the mean-squared interpolation error. Leveraging recent advances in random matrix theory, we show that nonlinear Attention generally incurs a larger interpolation error than linear regression on random inputs. However, this gap vanishes, and can even be reversed, when the input contains a structured signal, particularly if the Attention weights align with the signal direction. Our theoretical insights are supported by numerical experiments.

Paper Structure

This paper contains 42 sections, 14 theorems, 133 equations, 7 figures.

Key Result

Lemma 1

Let Assumptions ass:low-rank-weights--ass:high-dim hold. Then, the Attention kernel $\mathbf{K}_\mathbf{X} = f( \mathbf{X}^\top \mathbf{W}_K^\top \mathbf{W}_Q \mathbf{X}/\sqrt p)/\sqrt p$ defined in eq:def_A_X satisfies with probability approaching one as $n,p \to \infty$. Here, $a_1$ is the first Hermite coefficient of $f$ (see ass:nonlinear), $\mathbf{K}_N \equiv f(\mathbf{Z}^\top \mathbf{Z}/\s
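The two objects named in Lemma 1, the entry-wise Attention kernel $\mathbf{K}_\mathbf{X}$ and the first Hermite coefficient $a_1$ of $f$, can be sketched numerically as follows. This is a minimal illustration rather than the authors' code. It assumes that $\mathbf{X}$ is a $p \times n$ matrix of $n$ tokens, that $\mathbf{W}_K$ and $\mathbf{W}_Q$ are generic $p \times p$ matrices (the low-rank structure required by the paper's assumptions is not reproduced here), and that $a_1 = \mathbb{E}[\xi f(\xi)]$ for $\xi \sim \mathcal{N}(0,1)$, which is the normalization suggested by the caption of Figure 3. The function names are hypothetical.

```python
import numpy as np


def first_hermite_coefficient(f, num_nodes=200):
    """Approximate a_1 = E[xi * f(xi)], xi ~ N(0, 1), by Gauss-Hermite quadrature.

    Assumption: a_1 is taken with respect to the normalized (probabilists')
    Hermite polynomials, whose first-order element is He_1(t) = t.
    """
    # hermegauss integrates against the weight exp(-t^2 / 2).
    nodes, weights = np.polynomial.hermite_e.hermegauss(num_nodes)
    weights = weights / np.sqrt(2.0 * np.pi)  # turn the weight into the N(0,1) density
    return float(np.sum(weights * nodes * f(nodes)))


def attention_kernel(X, W_K, W_Q, f):
    """Entry-wise Attention kernel K_X = f(X^T W_K^T W_Q X / sqrt(p)) / sqrt(p).

    X is (p, n); W_K and W_Q are assumed to be (p, p) for this sketch.
    """
    p = X.shape[0]
    scores = X.T @ W_K.T @ W_Q @ X / np.sqrt(p)  # (n, n) pre-activation scores
    return f(scores) / np.sqrt(p)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    p, n = 64, 32
    X = rng.standard_normal((p, n))  # noise-only tokens, i.e. no signal component
    W = np.eye(p)                    # identity key/query weights (illustrative choice)
    K = attention_kernel(X, W, W, np.tanh)
    print(K.shape, first_hermite_coefficient(np.tanh))
```

With identity weights and noise-only inputs, the pre-activation reduces to $\mathbf{X}^\top \mathbf{X}/\sqrt p$, so the kernel is symmetric; general $\mathbf{W}_K \neq \mathbf{W}_Q$ need not give a symmetric kernel.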

Figures (7)

  • Figure 1: Empirical interpolation error $E$ (red) of nonlinear Attention versus its high-dimensional equivalent $\bar{E}$ (blue) from \ref{theo:high_interpolation}, and the theoretical interpolation error of linear regression (green) from \ref{prop:high_interpolation_error_RR}, with $f(t) = \tanh(t)$. \ref{subfig:E_gamma}: As a function of regularization strength $\gamma$, under the null model with $\boldsymbol{\mu} = \mathbf{w}_K = \mathbf{w}_Q = \mathbf{0}$, $p = 4\,096$, and $n = 1\,024$. \ref{subfig:E_p}: As a function of embedding dimension $p$, under the null model with $n = 4\,096$, $\gamma = 10^{-2}$. \ref{subfig:E_SNR}: As a function of SNR $\| \boldsymbol{\mu} \|^2$, with $p = 512, n = 2\,048, \gamma = 10^{-2}$, and $\mathbf{w}_K = \mathbf{w}_Q = \boldsymbol{\mu}$.
  • Figure 2: Interpolation behavior under alternative nonlinearities and pretrained Attention weights. \ref{subfig:E4_1}: Theoretical interpolation errors for $f(t) = \tanh(t)$ (blue) versus $f(t) = \max(-5,\min(5,t))$ (purple) and that of linear regression (green) in the over-determined regime, as a function of SNR, with $p = 512, n = 2\,048, \gamma = 1$, and $\mathbf{w}_K = \mathbf{w}_Q = \boldsymbol{\mu}$. \ref{subfig:E4_2}: Interpolation error of Softmax (cyan), entry-wise $\tanh$ (blue), and truncated exponential ($f(t) = \min(5, \exp(t))$, red) Attention. Solid lines show theoretical predictions under \ref{ass:low-rank-weights}; dotted lines show predictions with key/query weights extracted from a pretrained GPT-2 model (the detailed setup is provided in \ref{sec:SM_additional_nums}). Both are plotted as a function of $\gamma$, for $p = 2\,048, n = 512$, and $\mathbf{w}_K = \mathbf{w}_Q = \boldsymbol{\mu} = \mathbf{0}$.
  • Figure 3: Effect of the linear component in Attention interpolation. \ref{subfig:E_a1}: Empirical (red) and theoretical (cyan) interpolation error for $f(t) = \max(-5,\min(5,rt + \sqrt{1 - r^2} (t^3 - 3t)/\sqrt6))$ as a function of the Hermite coefficient $a_1 \approx r$ for $p = n = 4\,096, \gamma = 1$, and $\| \boldsymbol{\mu} \|^2 = 1$. \ref{subfig:E_cos_p}: Empirical (red) and theoretical (cyan) interpolation error for $f(t) = \cos(t)$, versus the theoretical error of $f(t) = \tanh(t)$ (blue) and the theoretical error of $f(t) = \max(-5,\min(5,t))$ (purple), as a function of the embedding dimension $p$, for sample size $n = 4\,096, \gamma = 1$, and $\| \boldsymbol{\mu} \|^2 = 1$. \ref{subfig:E_cos_SNR}: Empirical (red) and theoretical (cyan) interpolation error for $f(t) = \cos(t)$, versus the theoretical error of $f(t) = \tanh(t)$ (blue) and the theoretical error of $f(t) = \max(-5,\min(5,t))$ (purple), as a function of the SNR $\| \boldsymbol{\mu} \|^2$, for $p = 512, n = 2\,048, \gamma = 10^{-2}$, and $\mathbf{w}_K = \mathbf{w}_Q = \boldsymbol{\mu}$.
  • Figure 4: Theoretical interpolation errors of $\tanh$ (blue) and truncated linear (with $f(t) = \max(-5,\min(5,t))$, in purple) Transformer, for key/query weights aligned with the signal direction in solid lines: $\mathbf{w}_K = \mathbf{w}_Q = \boldsymbol{\mu}_{\rm base} \sim {\mathcal{N}}(\mathbf{0}, \mathbf{I}_p/p)$ and $\boldsymbol{\mu} \propto \boldsymbol{\mu}_{\rm base}$; versus the case where both weights are orthogonal to the signal in dotted lines: $\mathbf{w}_K \perp \boldsymbol{\mu}_{\rm base}, \mathbf{w}_Q \perp \boldsymbol{\mu}_{\rm base}$, $\mathbf{w}_K \perp \mathbf{w}_Q$ and $\boldsymbol{\mu} \propto \boldsymbol{\mu}_{\rm base}$; for regularization strength $\gamma = 1$.
  • Figure 5: Theoretical interpolation errors for $f(t) = \tanh(t)$ (blue) from \ref{theo:high_interpolation} versus that of linear regression (green) from \ref{prop:high_interpolation_error_RR}, as a function of SNR, for different dimension ratios $p/n$, with synthetic data drawn from the Gaussian signal-plus-noise model as in \ref{def:signal_plus_noise} with $\mathbf{w}_K = \mathbf{w}_Q = \boldsymbol{\mu}_{\rm base} \sim \mathcal{N}(\mathbf{0},\mathbf{I}_p/p)$, $\boldsymbol{\mu} \propto \boldsymbol{\mu}_{\rm base}$, and $\gamma = 1$.
  • ...and 2 more figures

Theorems & Definitions (34)

  • Definition 1: Entry-wise Attention
  • Remark 1: Softmax Attention
  • Definition 2: Signal-plus-noise model
  • Remark 2: Beyond the signal-plus-noise model in \ref{def:signal_plus_noise}
  • Definition 3: Interpolation error of Attention
  • Definition 4: Deterministic Equivalent, couillet2022RMT4ML
  • Lemma 1: High-dimensional linearization of Attention kernel
  • Proposition 1: Deterministic Equivalent for noise-only nonlinear Attention
  • Theorem 1: High-dimensional characterization of Attention interpolation error
  • Definition 5: Linear regression and its interpolation error
  • ...and 24 more