Implicit Bias and Fast Convergence Rates for Self-attention

Bhavya Vasudeva, Puneesh Deora, Christos Thrampoulidis

TL;DR

This work investigates the optimization behavior of a single-layer self-attention model with a linear decoder in binary classification, focusing on the implicit bias of gradient-based methods. It establishes data-dependent conditions under which gradient descent with adaptive step sizes, namely normalized gradient descent (NGD) and the Polyak step-size (PS), converges globally in direction to the max-margin solution $W_{mm}$. It also derives finite-time convergence rates and quantifies how fast the attention map sparsifies. The analysis extends to joint optimization of the attention weights and the decoder. In that setting it shows global convergence for $W$, margin-based convergence for the decoder, and an $\exp(-t^{1/3})$ decay of the loss under Gaussian-like data. Empirically, NGD and PS consistently outperform standard gradient descent on synthetic and language datasets. The results also connect the implicit bias of self-attention to that of linear logistic regression, despite the non-convexity introduced by the softmax. Overall, the paper advances a rigorous understanding of optimization and implicit bias in transformer components and demonstrates practical benefits of adaptive step sizes in self-attention training.

Abstract

We study the fundamental optimization principles of self-attention, the defining mechanism of transformers, by analyzing the implicit bias of gradient-based optimizers in training a self-attention layer with a linear decoder for binary classification. Building on prior studies of linear logistic regression, recent findings show that the key-query matrix $W_t$ obtained by gradient descent (GD) converges in direction towards $W_{mm}$, which maximizes the margin between optimal and non-optimal tokens across sequences. However, this convergence has several limitations. It is local and depends on initial conditions. It holds only asymptotically as the number of iterations grows. It also leaves open whether adaptive step-size rules offer any benefit. To bridge this gap, we first establish scenarios in which convergence is provably \emph{global}. We then analyze two adaptive step-size strategies, normalized GD and the Polyak step-size, demonstrating \emph{finite-time} convergence rates for $W_t$ to $W_{mm}$ and quantifying the sparsification rate of the attention map. These findings show that these strategies can accelerate parameter convergence over standard GD in a non-convex setting. They also deepen the understanding of the implicit bias in self-attention, linking it more closely to the phenomena observed in linear logistic regression despite its intricate non-convex nature.
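
For reference, the objects named in the abstract can be written out explicitly. This is a sketch in notation assumed here, following the convention common in this line of work rather than quoting the paper. The notation is as follows:

* $X_i\in\mathbb{R}^{T\times d}$ is sequence $i$, with rows (tokens) $x_{i,j}$.
* $Y_i\in\{\pm1\}$ is its label and $z_i$ its query token.
* $u$ is a fixed linear decoder and $\mathbb{S}(\cdot)$ is the softmax.
* $\ell(s)=\log(1+e^{-s})$ is the logistic loss.
* $\mathrm{opt}_i$ indexes the optimal token of sequence $i$.
* $\mathcal{L}^\star$ in the Polyak rule is the infimum of the loss (or a lower bound on it). Its exact choice in the paper is not stated in this summary.

```latex
\begin{aligned}
f(X_i;W) &= u^\top X_i^\top\,\mathbb{S}(X_i W z_i),
\qquad
\mathcal{L}(W) = \frac{1}{n}\sum_{i=1}^{n}\ell\big(Y_i\,f(X_i;W)\big),\\
W_{mm} &= \arg\min_{W}\ \|W\|_F
\quad\text{s.t.}\quad
(x_{i,\mathrm{opt}_i}-x_{i,j})^\top W z_i \ge 1
\quad \forall\, i\in[n],\ \forall\, j\ne \mathrm{opt}_i,\\
\text{NGD:}\quad W_{t+1} &= W_t-\eta\,\frac{\nabla\mathcal{L}(W_t)}{\|\nabla\mathcal{L}(W_t)\|_F},\\
\text{PS:}\quad W_{t+1} &= W_t-\frac{\mathcal{L}(W_t)-\mathcal{L}^\star}{\|\nabla\mathcal{L}(W_t)\|_F^{2}}\,\nabla\mathcal{L}(W_t).
\end{aligned}
```

NGD moves a fixed distance $\eta$ per step, however small the gradient becomes. This offers one intuition for why it can outpace standard GD once the softmax saturates and gradients shrink.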

Paper Structure (49 sections, 23 theorems, 181 equations, 10 figures, 1 table)

Key Result

Theorem 1

Under a small initialization scale (Lemma lem:init in the Appendix) and Assumptions ass:data-relaxed through ass:data-score, using the NGD updates in Eq. (eq:ngd-update) with $\eta=\widetilde{\mathcal{O}}(B^{-2})$, it holds for any $t\geq t_0=\operatorname{poly}(\eta,B,\Lambda,\mathrm{T},n\Upsilon)$ that … , where $C:=C(\eta,B,\Lambda,t_0)=\operatorname{poly}(\eta,B,\Lambda,t_0)$. In parti…
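
The NGD and Polyak step-size updates behind Theorem 1 can be illustrated in NumPy for the attention-only setting, with the decoder $u$ fixed and only $W$ trained. This is a minimal sketch under assumptions made here, not the authors' code:

* the query of each sequence is taken to be its first token;
* the loss is the logistic loss;
* the Polyak rule uses the lower bound $\mathcal{L}^\star=\frac1n\sum_i\log(1+e^{-\gamma_i})$, where $\gamma_i$ is the best token score of sequence $i$;
* the toy data and the names `loss_and_grad`, `ngd_step`, and `polyak_step` are hypothetical.

```python
import numpy as np


def softmax(a):
    # Numerically stable softmax over the last axis.
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)


def loss_and_grad(W, X, Y, u):
    """Logistic loss of the attention-only model and its gradient w.r.t. W.

    X: (n, T, d) token sequences, Y: (n,) labels in {-1, +1}, u: (d,) fixed decoder.
    Hypothetical choice for this sketch: the query of sequence i is its first token.
    """
    Z = X[:, 0, :]                                   # (n, d) query tokens z_i
    A = np.einsum("ntd,de,ne->nt", X, W, Z)          # logits x_{i,t}^T W z_i
    S = softmax(A)                                   # (n, T) attention scores
    G = X @ u                                        # (n, T) values u^T x_{i,t}
    f = (S * G).sum(axis=1)                          # model outputs f(X_i)
    margins = Y * f
    loss = np.mean(np.logaddexp(0.0, -margins))      # mean of log(1 + exp(-Y f))
    dl_df = -Y * 0.5 * (1.0 - np.tanh(margins / 2))  # -Y * sigmoid(-Y f), overflow-free
    df_dA = S * (G - f[:, None])                     # (diag(s) - s s^T) G
    grad = np.einsum("n,nt,ntd,ne->de", dl_df, df_dA, X, Z) / len(Y)
    return loss, grad


def ngd_step(W, X, Y, u, eta):
    # Normalized GD: move a fixed distance eta along -grad / ||grad||_F.
    loss, grad = loss_and_grad(W, X, Y, u)
    norm = np.linalg.norm(grad)
    return (W if norm == 0 else W - eta * grad / norm), loss


def polyak_step(W, X, Y, u, loss_star):
    # Polyak step-size: eta_t = (L(W_t) - L*) / ||grad||_F^2.
    loss, grad = loss_and_grad(W, X, Y, u)
    sq = np.sum(grad ** 2)
    return (W if sq == 0 else W - (loss - loss_star) / sq * grad), loss


# Toy data: two sequences of two tokens in R^3, decoder u = e1.
e1, e2, e3 = np.eye(3)
X = np.stack([np.stack([e2, e1]), np.stack([e3, -e1])])  # second token is optimal
Y = np.array([1.0, -1.0])
u = e1.copy()

# Lower bound on the loss: each Y_i f(X_i) is at most the best token score
# gamma_i = max_t Y_i u^T x_{i,t}, so L(W) >= mean log(1 + exp(-gamma_i)).
gamma_opt = (Y[:, None] * (X @ u)).max(axis=1)
loss_star = np.mean(np.logaddexp(0.0, -gamma_opt))

W_ngd, W_ps = np.zeros((3, 3)), np.zeros((3, 3))
ngd_losses, ps_losses = [], []  # loss recorded before each step
for _ in range(20):
    W_ngd, l1 = ngd_step(W_ngd, X, Y, u, eta=0.5)
    W_ps, l2 = polyak_step(W_ps, X, Y, u, loss_star)
    ngd_losses.append(l1)
    ps_losses.append(l2)

# Max-margin solution for this toy data, worked out by hand from the
# constraints (x_opt - x_other)^T W z >= 1 of each sequence.
W_mm = np.zeros((3, 3))
W_mm[0, 1], W_mm[1, 1], W_mm[0, 2], W_mm[2, 2] = 0.5, -0.5, -0.5, -0.5
cos = np.sum(W_ngd * W_mm) / (np.linalg.norm(W_ngd) * np.linalg.norm(W_mm))
print(f"NGD loss {ngd_losses[-1]:.4f}, PS loss {ps_losses[-1]:.4f}, cos(W, W_mm) {cos:.4f}")
```

On this toy data the NGD iterates move along a fixed direction parallel to $W_{mm}$, so the cosine alignment equals 1 from the first step. The step size $\eta=0.5$ suits the toy problem only and is unrelated to the $\widetilde{\mathcal{O}}(B^{-2})$ choice in the theorem.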

Figures (10)

  • Figure 1: Comparison of train and test dynamics of four optimizers, namely SGD, stochastic normalized GD (SNGD), stochastic Polyak step (SPS), and Adam, while fine-tuning a pre-trained BERT model on the MNLI dataset; see App. app:expt-settings for details. SNGD and SPS, which employ adaptive step-size rules, train significantly faster, closely resembling the performance of Adam. Motivated in part by this observation, our work establishes fast convergence rates for NGD and PS for single-layer self-attention.
  • Figure 2: Training dynamics of a single-head self-attention model (Eq. eq:model) when optimizing only $\bm{W}$ on synthetic data with nearly orthogonal tokens (Example ex:orth with $\sigma=0$). The observed softmax score saturation, norm growth of $\bm{W}$, and directional alignment with $\bm{W}_{\text{mm}}$ closely match our theoretical results (Lemma lem:softmax-sat-rate, Eq. eq:w-norm-org, and Theorem th:conv, respectively).
  • Figure 3: Training dynamics of a self-attention model (Eq. eq:model) with data generated using the data model data_model.
  • Figure 4: Comparison of train and test dynamics of SGD, SNGD, SPS, and Adam while fine-tuning a pre-trained BERT model on the CivilComments dataset.
  • Figure 5: Training dynamics when optimizing only $\bm{W}$ on synthetic data with antipodal $\texttt{opt}$ tokens.
  • ...and 5 more figures
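
Figure 2 tracks three quantities: softmax score saturation, the norm of $\bm{W}$, and alignment with $\bm{W}_{\text{mm}}$. The NumPy sketch below shows how they could be computed. Several parts are assumptions made here: the helper name `attention_diagnostics`, the use of the first token as the query, reporting the smallest optimal-token score across sequences, and the toy data. In the toy check, a fixed max-margin direction is simply scaled up rather than trained.

```python
import numpy as np


def attention_diagnostics(W, X, opt_idx, W_mm):
    """Quantities plotted in a Figure-2-style analysis (hypothetical helper).

    W, W_mm: (d, d) matrices; X: (n, T, d) sequences whose first token is
    assumed to be the query; opt_idx: (n,) index of each sequence's optimal token.
    Returns (smallest softmax score on an optimal token, ||W||_F, cos(W, W_mm)).
    """
    Z = X[:, 0, :]
    A = np.einsum("ntd,de,ne->nt", X, W, Z)           # attention logits
    E = np.exp(A - A.max(axis=1, keepdims=True))       # stable softmax numerator
    S = E / E.sum(axis=1, keepdims=True)
    min_opt_score = S[np.arange(len(X)), opt_idx].min()
    w_norm = np.linalg.norm(W)                         # Frobenius norm
    cos = np.sum(W * W_mm) / (w_norm * np.linalg.norm(W_mm))
    return min_opt_score, w_norm, cos


# Toy check: scale a fixed max-margin direction and watch the score saturate.
e1, e2, e3 = np.eye(3)
X = np.stack([np.stack([e2, e1]), np.stack([e3, -e1])])  # token 1 is optimal in both
opt_idx = np.array([1, 1])
W_mm = np.zeros((3, 3))
W_mm[0, 1], W_mm[1, 1], W_mm[0, 2], W_mm[2, 2] = 0.5, -0.5, -0.5, -0.5
results = [attention_diagnostics(c * W_mm, X, opt_idx, W_mm) for c in (1.0, 5.0, 10.0)]
for score, norm, cos in results:
    print(f"opt score {score:.5f}  ||W||_F {norm:.1f}  cos {cos:.3f}")
```

Here the optimal-token score equals the logistic sigmoid of the scale factor. It approaches 1 as $\|\bm{W}\|_F$ grows, while the alignment stays at 1 by construction.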

Theorems & Definitions (46)

  • Definition 1: Token scores and Optimality
  • Example 1
  • Theorem 1: IB rate
  • Lemma 1: Softmax score rate
  • Theorem 2: Train loss convergence
  • Theorem 3: IB rate of $\bm{W}$
  • Theorem 4: IB rate of $\bm{u}$
  • Lemma 2
  • Proof
  • Lemma 3
  • ...and 36 more
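
Definition 1 (token scores and optimality) is listed above without its statement. Suppose it follows the convention common in this line of work: the score of token $t$ in sequence $i$ is $\gamma_{i,t}=Y_i\,u^\top x_{i,t}$, and the optimal token is the one with the highest score. Under that assumption, both quantities can be computed with the NumPy sketch below. The helper names and toy data are hypothetical.

```python
import numpy as np


def token_scores(X, Y, u):
    """gamma[i, t] = Y_i * u^T x_{i,t} (assumed form of the token score).

    X: (n, T, d) sequences, Y: (n,) labels in {-1, +1}, u: (d,) linear decoder.
    """
    return Y[:, None] * (X @ u)


def optimal_tokens(X, Y, u):
    """Index of the highest-scoring token per sequence.

    np.argmax breaks ties by taking the first maximizer; the usual definition
    instead assumes each sequence has a unique optimal token.
    """
    return token_scores(X, Y, u).argmax(axis=1)


X = np.array([[[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]],
              [[0.5, 0.5], [-2.0, 0.0], [1.0, 1.0]]])
Y = np.array([1.0, -1.0])
u = np.array([1.0, 0.0])
print(token_scores(X, Y, u))    # per-token scores gamma[i, t]
print(optimal_tokens(X, Y, u))  # → [0 1]
```

In the second sequence the label $Y_2=-1$ flips the sign of $u^\top x_{2,t}$. As a result, the token with the most negative decoder output ($-2$) becomes the optimal one.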