Training Dynamics of Softmax Self-Attention: Fast Global Convergence via Preconditioning
Gautam Goel, Mahdi Soltanolkotabi, Peter Bartlett
TL;DR
Studies the training dynamics of gradient descent in a softmax self-attention layer trained to perform linear regression, showing that a simple first-order optimization algorithm can converge to the globally optimal self-attention parameters at a geometric rate.
Abstract
We study the training dynamics of gradient descent in a softmax self-attention layer trained to perform linear regression and show that a simple first-order optimization algorithm can converge to the globally optimal self-attention parameters at a geometric rate. Our analysis proceeds in two steps. First, we show that in the infinite-data limit the regression problem solved by the self-attention layer is equivalent to a nonconvex matrix factorization problem. Second, we exploit this connection to design a novel "structure-aware" variant of gradient descent that efficiently optimizes the original finite-data regression objective. Our optimization algorithm features several innovations over standard gradient descent, including a preconditioner and a regularizer that help avoid spurious stationary points, and a data-dependent spectral initialization that, with high probability, places the parameters near the manifold of global minima.
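The abstract does not spell out the algorithm itself, so the following is only a generic illustration of two ingredients it names, a preconditioner and a spectral initialization, applied to a plain low-rank factorization objective 0.5 * ||X Y^T - M||_F^2 rather than to the self-attention regression objective. The preconditioner shown is the ScaledGD rule (each factor's gradient is right-multiplied by the inverse Gram matrix of the other factor), swapped in as a well-known example; it is an assumption, not the authors' preconditioner, and it omits their regularizer. The helper names, matrix sizes, singular values, noise level, and step sizes are all hypothetical, and the noisy matrix `M_est` merely stands in for a finite-data estimate of the target.

```python
import numpy as np


def spectral_init(M_est, r):
    """Balanced rank-r factors from the top-r SVD of an estimate of the target."""
    U, s, Vt = np.linalg.svd(M_est, full_matrices=False)
    root = np.sqrt(s[:r])
    return U[:, :r] * root, Vt[:r, :].T * root


def relative_error(X, Y, M):
    """Frobenius-norm error of the product X Y^T relative to ||M||_F."""
    return np.linalg.norm(X @ Y.T - M) / np.linalg.norm(M)


def scaled_gd(M, X, Y, step=0.5, iters=50):
    """Preconditioned GD on f(X, Y) = 0.5 * ||X Y^T - M||_F^2 (ScaledGD rule).

    Right-multiplying each gradient by the inverse Gram matrix of the other
    factor makes the local convergence rate insensitive to the condition
    number of M.
    """
    for _ in range(iters):
        R = X @ Y.T - M                    # residual
        grad_X, grad_Y = R @ Y, R.T @ X    # plain gradients w.r.t. X and Y
        # Simultaneous update; solve(G, g.T).T equals g @ inv(G) for symmetric G.
        X, Y = (
            X - step * np.linalg.solve(Y.T @ Y, grad_X.T).T,
            Y - step * np.linalg.solve(X.T @ X, grad_Y.T).T,
        )
    return X, Y


def plain_gd(M, X, Y, step, iters=50):
    """Unpreconditioned GD on the same objective, for comparison."""
    for _ in range(iters):
        R = X @ Y.T - M
        X, Y = X - step * (R @ Y), Y - step * (R.T @ X)
    return X, Y


rng = np.random.default_rng(0)
n, m, r = 30, 20, 3
U, _ = np.linalg.qr(rng.standard_normal((n, r)))
V, _ = np.linalg.qr(rng.standard_normal((m, r)))
M = U @ np.diag([10.0, 3.0, 1.0]) @ V.T           # rank 3, condition number 10
M_est = M + 1e-3 * rng.standard_normal((n, m))     # hypothetical noisy estimate

X0, Y0 = spectral_init(M_est, r)
Xs, Ys = scaled_gd(M, X0, Y0)
Xp, Yp = plain_gd(M, X0, Y0, step=0.5 / np.linalg.norm(M_est, 2))

err_init = relative_error(X0, Y0, M)
err_scaled = relative_error(Xs, Ys, M)
err_plain = relative_error(Xp, Yp, M)
print(f"spectral init: {err_init:.2e}")
print(f"ScaledGD:      {err_scaled:.2e}")
print(f"plain GD:      {err_plain:.2e}")
```

Starting from the spectral initialization, the preconditioned iteration should contract the error geometrically at a rate that does not depend on the spread of singular values, whereas plain gradient descent, whose step must be tied to the largest singular value, converges more slowly along the directions associated with the smallest one.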
