Implicit Bias and Fast Convergence Rates for Self-attention
Bhavya Vasudeva, Puneesh Deora, Christos Thrampoulidis
TL;DR
This work investigates the optimization behavior of a single-layer self-attention model with a linear decoder for binary classification, focusing on the implicit bias of gradient-based methods. It establishes data-dependent conditions under which gradient descent with adaptive step sizes, namely normalized gradient descent (NGD) and the Polyak step-size (PS), converges globally in direction to the max-margin solution $W_{mm}$, and it derives finite-time convergence rates together with the sparsification rate of the attention map. The analysis extends to joint optimization of the attention weights and the decoder, showing global convergence for $W$, margin-based convergence for the decoder, and an $\exp(-t^{1/3})$ decay of the loss under Gaussian-like data. Empirically, NGD and PS consistently outperform standard gradient descent on synthetic and language datasets. The results connect implicit-bias phenomena in self-attention to those in linear logistic regression, despite the non-convex softmax structure. Overall, the paper advances a rigorous understanding of optimization and implicit bias in transformer components and demonstrates practical benefits of adaptive step sizes for training self-attention.
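For reference, the objects named above can be written out explicitly. This is a hedged sketch rather than the paper's exact formulation: the model form (a per-sequence query vector $q_i$ and a fixed decoder $u$), the logistic loss, and the symbols $q_i$, $u$, $x_{i,t}$, $\mathrm{opt}_i$, and $\mathcal{L}_\star$ are our assumptions, following the usual max-margin token-selection setup; the paper's normalization and constants may differ. The NGD and Polyak updates are their standard textbook forms.

```latex
% Assumed model: X_i in R^{T x d} with rows x_{i,t}, label y_i in {-1,+1},
% query vector q_i, decoder u held fixed
f(X_i) = u^\top X_i^\top \operatorname{softmax}(X_i W q_i), \qquad
\mathcal{L}(W) = \frac{1}{n}\sum_{i=1}^{n} \log\bigl(1 + e^{-y_i f(X_i)}\bigr)

% Normalized GD (NGD): every step has Frobenius length \eta
W_{t+1} = W_t - \eta \, \frac{\nabla \mathcal{L}(W_t)}{\|\nabla \mathcal{L}(W_t)\|_F}

% Polyak step-size (PS), with \mathcal{L}_\star = \inf_W \mathcal{L}(W)
W_{t+1} = W_t - \frac{\mathcal{L}(W_t) - \mathcal{L}_\star}{\|\nabla \mathcal{L}(W_t)\|_F^2} \, \nabla \mathcal{L}(W_t)

% Max-margin solution; opt_i indexes the optimal token of sequence i
W_{mm} = \arg\min_{W} \|W\|_F \quad \text{s.t.} \quad
(x_{i,\mathrm{opt}_i} - x_{i,t})^\top W q_i \ge 1
\quad \forall i \in [n],\ \forall t \ne \mathrm{opt}_i

% Convergence in direction
\frac{\langle W_t, W_{mm} \rangle}{\|W_t\|_F \, \|W_{mm}\|_F} \to 1 \quad (t \to \infty)
```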
Abstract
We study the fundamental optimization principles of self-attention, the defining mechanism of transformers, by analyzing the implicit bias of gradient-based optimizers when training a self-attention layer with a linear decoder for binary classification. Building on prior work on linear logistic regression, recent findings show that the key-query matrix $W_t$ obtained by gradient descent (GD) converges in direction to $W_{mm}$, which maximizes the margin between optimal and non-optimal tokens across sequences. However, this result is local and depends on the initialization, holds only asymptotically in the number of iterations, and leaves open the potential benefits of adaptive step-size rules. To bridge this gap, we first establish scenarios in which convergence is provably \emph{global}. We then analyze two adaptive step-size strategies, normalized GD and the Polyak step-size, establishing \emph{finite-time} convergence rates of $W_t$ to $W_{mm}$ and quantifying the sparsification rate of the attention map. These findings show that such strategies can accelerate parameter convergence relative to standard GD in a non-convex setting, and they deepen the understanding of implicit bias in self-attention, linking it more closely to the phenomena observed in linear logistic regression despite its intricate non-convex nature.
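To make the comparison between GD and adaptive step sizes concrete, here is a minimal pure-Python sketch on a hypothetical two-sequence toy problem. It assumes a single softmax-attention layer $f(X) = u^\top X^\top \mathrm{softmax}(X W q)$ with a fixed decoder $u$, a per-sequence query $q$, and logistic loss. The data, the query choice, `W_MM` (solved by hand for this toy data), `L_STAR`, and all function names (`loss_and_grad`, `train`, and so on) are our own illustrative assumptions, not the paper's code or experiments. It trains $W$ with plain GD, normalized GD, and the Polyak step-size for the same number of iterations and reports the loss, the cosine alignment with $W_{mm}$, and the attention weight on the optimal token.

```python
import math

# Hypothetical toy data (our assumption, not the paper's): each sequence has T = 2
# tokens in R^2, a per-sequence query vector q, and a label y in {+1, -1}.
# The linear decoder u is held fixed; only the key-query matrix W is trained.
U = (1.0, 0.0)
DATA = [
    # (tokens, query, label); token 0 is "optimal": it maximizes y * u^T x_t.
    (((1.0, 0.0), (0.0, 1.0)), (1.0, 0.0), +1.0),
    (((-1.0, 0.0), (0.0, 1.0)), (0.0, 1.0), -1.0),
]
# Max-margin matrix for this toy data, solved by hand from
#   min ||W||_F  s.t.  (x_opt - x_t)^T W q >= 1  for every sequence.
W_MM = [[0.5, -0.5], [-0.5, -0.5]]
# Infimum of the loss: all attention on each optimal token gives y * f = 1.
# Used as the target value L_* of the Polyak step-size.
L_STAR = math.log1p(math.exp(-1.0))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(s):
    m = max(s)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in s]
    total = sum(e)
    return [v / total for v in e]


def forward(W, X, q):
    """Return attention weights, token scores u^T x_t, and f(X) = u^T X^T softmax(X W q)."""
    Wq = [dot(row, q) for row in W]
    a = softmax([dot(x, Wq) for x in X])
    gamma = [dot(U, x) for x in X]
    return a, gamma, dot(a, gamma)


def loss_and_grad(W):
    """Average logistic loss log(1 + exp(-y f)) and its gradient with respect to W."""
    n = len(DATA)
    loss, grad = 0.0, [[0.0, 0.0], [0.0, 0.0]]
    for X, q, y in DATA:
        a, gamma, f = forward(W, X, q)
        loss += math.log1p(math.exp(-y * f)) / n
        dl_df = -y / (1.0 + math.exp(y * f))
        for t, x in enumerate(X):
            # df/ds_t = a_t * sum_k a_k (gamma_t - gamma_k), equal to a_t (gamma_t - f)
            # but free of cancellation once the attention map saturates.
            df_ds = a[t] * sum(a_k * (gamma[t] - g_k) for a_k, g_k in zip(a, gamma))
            for r in range(2):
                for c in range(2):
                    grad[r][c] += dl_df * df_ds * x[r] * q[c] / n  # ds_t/dW = x_t q^T
    return loss, grad


def frob(A):
    return math.sqrt(sum(v * v for row in A for v in row))


def alignment(W):
    """Cosine similarity <W, W_mm> / (||W||_F ||W_mm||_F)."""
    inner = sum(W[r][c] * W_MM[r][c] for r in range(2) for c in range(2))
    return inner / (frob(W) * frob(W_MM))


def train(rule, steps=100, eta=1.0):
    W = [[0.5, 0.0], [0.5, 0.0]]  # init orthogonal to W_MM, so alignment starts at 0
    for _ in range(steps):
        loss, g = loss_and_grad(W)
        gnorm = frob(g)
        if gnorm == 0.0:
            break
        if rule == "gd":
            lr = eta
        elif rule == "ngd":
            lr = eta / gnorm  # normalized GD: each step has Frobenius length eta
        elif rule == "polyak":
            lr = max(loss - L_STAR, 0.0) / gnorm**2  # Polyak step-size
        else:
            raise ValueError(f"unknown rule: {rule}")
        W = [[W[r][c] - lr * g[r][c] for c in range(2)] for r in range(2)]
    attn_opt = min(forward(W, X, q)[0][0] for X, q, _ in DATA)
    return loss_and_grad(W)[0], alignment(W), attn_opt


if __name__ == "__main__":
    for rule in ("gd", "ngd", "polyak"):
        loss, align, attn = train(rule)
        print(f"{rule:>6}: loss={loss:.4f}  cos(W, W_mm)={align:.5f}  min attn on optimal={attn:.6f}")
```

In this symmetric toy problem the gradient always points along $W_{mm}$, so every rule moves toward it; what differs is speed. NGD moves a fixed distance $\eta$ per step, so $\|W_t\|_F$ grows linearly and the attention on the optimal tokens approaches 1 far sooner than under GD, whose steps shrink with the gradient. Near the infimum the Polyak step length approaches roughly 1, so it tracks NGD until the gap $\mathcal{L}(W_t) - \mathcal{L}_\star$ drops below floating-point resolution and the updates stall. This illustrates qualitative behavior only and does not reproduce the paper's rates.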
