On the Optimization and Generalization of Multi-head Attention
Puneesh Deora, Rouzbeh Ghaderi, Hossein Taheri, Christos Thrampoulidis
TL;DR
This work analyzes the finite-time optimization and generalization of gradient-descent training for multi-head self-attention (MHA) in a binary classification setting. It derives gradient and Hessian bounds for softmax attention, establishes self-bounded weak convexity of the empirical risk, and proves training and generalization guarantees under a realizability condition, with performance scaling in the number of heads and the quality of the initialization. The authors instantiate the theory on a tokenized-mixture data model, showing NTK separability after a single random initialization step and deriving margins that govern convergence and generalization, while also discussing margins that may be unattainable under certain initialization regimes. Overall, the paper connects attention mechanisms to the theory of overparameterized neural networks, providing finite-time bounds and insights into the role of overparameterization in the optimization and generalization of transformers.
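As background for the softmax gradient and Hessian bounds mentioned above, the following is a standard fact about the softmax Jacobian. The notation ($\mathbb{S}$ for softmax, $a \in \mathbb{R}^T$ for attention scores) is ours and is not taken from the paper; it only illustrates why softmax attention has well-behaved first derivatives.

```latex
s = \mathbb{S}(a), \qquad s_t = \frac{e^{a_t}}{\sum_{\tau=1}^{T} e^{a_\tau}}, \qquad
\nabla \mathbb{S}(a) = \operatorname{diag}(s) - s s^\top .

% Row t: diagonal entry s_t(1 - s_t); off-diagonal absolute sum
% s_t \sum_{\tau \neq t} s_\tau = s_t(1 - s_t).
% The matrix is symmetric PSD (the covariance of a one-hot categorical vector),
% so Gershgorin's theorem bounds its spectral norm:
\|\nabla \mathbb{S}(a)\|_2 \;\le\; \max_{t} \, 2\, s_t (1 - s_t) \;\le\; \tfrac{1}{2}.
```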
Abstract
The training and generalization dynamics of the Transformer's core mechanism, namely the Attention mechanism, remain under-explored. Moreover, existing analyses primarily focus on single-head attention. Inspired by the demonstrated benefits of overparameterization when training fully-connected networks, we investigate the potential optimization and generalization advantages of using multiple attention heads. Towards this goal, we derive convergence and generalization guarantees for gradient-descent training of a single-layer multi-head self-attention model, under a suitable realizability condition on the data. We then establish primitive conditions on the initialization that ensure realizability holds. Finally, we demonstrate that these conditions are satisfied for a simple tokenized-mixture model. We expect that the analysis can be extended to other data models and architectural variants.
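The abstract describes gradient-descent training of a single-layer multi-head self-attention model for binary classification. The NumPy sketch below illustrates that setting under our own simplifying assumptions, not the paper's exact parameterization: each head `h` has a merged key-query matrix `W[h]` and a value/readout vector `U[h]`, the first token acts as the query, head outputs are averaged with a `1/H` factor, and the loss is logistic. `make_data` is a hypothetical toy stand-in for a tokenized-mixture model in which one token per sequence carries the label signal. Gradients are derived by hand using the softmax Jacobian `diag(s) - s s^T`.

```python
import numpy as np


def softmax(a):
    """Numerically stable softmax over a 1-D score vector."""
    e = np.exp(a - a.max())
    return e / e.sum()


def forward(W, U, X):
    """Model output for one sequence (hypothetical parameterization).

    W: (H, d, d) merged key-query matrices; U: (H, d) value/readout vectors;
    X: (T, d) tokens. The first token is used as the query (an assumption).
    """
    H = W.shape[0]
    q = X[0]
    out = 0.0
    for h in range(H):
        s = softmax(X @ W[h] @ q)  # attention weights over the T tokens
        out += s @ (X @ U[h])      # attention-weighted scalar values
    return out / H


def risk_and_grad(W, U, Xs, ys):
    """Empirical logistic risk and its gradients with respect to W and U."""
    H, n = W.shape[0], len(ys)
    gW, gU = np.zeros_like(W), np.zeros_like(U)
    risk = 0.0
    for X, y in zip(Xs, ys):
        q = X[0]
        cache, f = [], 0.0
        for h in range(H):
            s = softmax(X @ W[h] @ q)
            v = X @ U[h]
            f += s @ v / H
            cache.append((s, v))
        z = y * f
        risk += np.logaddexp(0.0, -z) / n               # log(1 + e^{-y f})
        dl_df = -y * np.exp(-np.logaddexp(0.0, z)) / n  # -y / (1 + e^{y f}) / n
        for h, (s, v) in enumerate(cache):
            g = s * (v - s @ v)                         # (diag(s) - s s^T) v
            gW[h] += dl_df / H * np.outer(X.T @ g, q)
            gU[h] += dl_df / H * (X.T @ s)
    return risk, gW, gU


def make_data(n, T, d, rng):
    """Hypothetical toy mixture: one token per sequence carries y * mu."""
    mu = np.zeros(d)
    mu[0] = 2.0  # signal direction and strength (assumption)
    Xs, ys = [], []
    for _ in range(n):
        y = rng.choice([-1, 1])
        X = rng.normal(scale=0.5, size=(T, d))  # noise tokens
        X[rng.integers(T)] += y * mu            # plant the label signal
        Xs.append(X)
        ys.append(y)
    return np.array(Xs), np.array(ys)


def train(H=4, T=4, d=6, n=32, lr=0.3, steps=120, seed=0):
    """Full-batch gradient descent on all heads; returns risk history and accuracy."""
    rng = np.random.default_rng(seed)
    Xs, ys = make_data(n, T, d, rng)
    W = rng.normal(scale=0.1, size=(H, d, d))
    U = rng.normal(scale=0.1, size=(H, d))
    history = []
    for _ in range(steps):
        risk, gW, gU = risk_and_grad(W, U, Xs, ys)
        history.append(risk)
        W -= lr * gW
        U -= lr * gU
    preds = np.sign([forward(W, U, X) for X in Xs])
    acc = float(np.mean(preds == ys))
    return history, acc


history, acc = train()
print(f"risk: {history[0]:.3f} -> {history[-1]:.3f}, train accuracy: {acc:.2f}")
```

Calling `train()` returns the per-step empirical risk and the final training accuracy. Varying `H` is one informal way to probe how the number of heads affects optimization in this toy; it does not reproduce the paper's guarantees.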
