On the Optimization and Generalization of Multi-head Attention

Puneesh Deora, Rouzbeh Ghaderi, Hossein Taheri, Christos Thrampoulidis

TL;DR

This work analyzes the finite-time optimization and generalization of gradient-descent training for multi-head self-attention (MHA) in a binary classification setting. It derives gradient and Hessian bounds for softmax attention, establishes self-bounded weak convexity of the empirical risk, and proves training and generalization guarantees under realizability, with performance scaling in the number of heads and initialization quality. The authors instantiate the theory on a tokenized-mixture data model, showing NTK separability after a single random initialization step and deriving margins that govern convergence and generalization, while discussing margins that may be unattainable under certain initialization regimes. Overall, the paper connects attention mechanisms to overparameterized NN theory, providing finite-time bounds and insights into the role of overparameterization in optimization and generalization for transformers.

Abstract

The training and generalization dynamics of the Transformer's core mechanism, namely the Attention mechanism, remain under-explored. Besides, existing analyses primarily focus on single-head attention. Inspired by the demonstrated benefits of overparameterization when training fully-connected networks, we investigate the potential optimization and generalization advantages of using multiple attention heads. Towards this goal, we derive convergence and generalization guarantees for gradient-descent training of a single-layer multi-head self-attention model, under a suitable realizability condition on the data. We then establish primitive conditions on the initialization that ensure realizability holds. Finally, we demonstrate that these conditions are satisfied for a simple tokenized-mixture model. We expect the analysis can be extended to various data-model and architecture variations.
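As a rough illustration of the kind of model the abstract refers to, here is a minimal pure-Python sketch of the forward pass of a single-layer multi-head softmax self-attention classifier. The parameterization is an assumption and is not stated in this summary: each head `h` has a hypothetical attention matrix `W_h` (d×d) and output matrix `U_h` (T×d), the scalar output sums ⟨U_h, softmax(X W_h Xᵀ) X⟩ over heads, and the sum is scaled by 1/√H. All function names are illustrative.

```python
import math

def softmax(row):
    # Numerically stable softmax over one row of attention logits.
    m = max(row)
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]

def matmul(A, B):
    # Plain list-of-lists matrix product.
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def mha_output(X, Ws, Us):
    """Scalar model output for one input X (T x d).

    Ws: list of H matrices (d x d); Us: list of H matrices (T x d).
    The 1/sqrt(H) output scaling is an assumption of this sketch.
    """
    H = len(Ws)
    total = 0.0
    for W, U in zip(Ws, Us):
        scores = matmul(matmul(X, W), transpose(X))   # (T x T) logits
        A = [softmax(r) for r in scores]              # row-wise attention
        AX = matmul(A, X)                             # attended tokens (T x d)
        total += sum(u * v for ur, vr in zip(U, AX) for u, v in zip(ur, vr))
    return total / math.sqrt(H)

def logistic_loss(y, out):
    # Binary classification loss for a label y in {-1, +1}.
    return math.log1p(math.exp(-y * out))
```

For a labelled example `(X, y)`, `logistic_loss(y, mha_output(X, Ws, Us))` gives the per-sample loss that gradient descent would minimize in this sketch; adding heads only lengthens the lists `Ws` and `Us`.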


Paper Structure

This paper contains 46 sections, 34 theorems, 234 equations, 6 figures.

Key Result

Lemma 1

For all $\bm{a}\in\mathbb{R}^T$ and $\bm{b},\bm{c}\in\mathbb{R}^d$, the model's gradients satisfy: ...

Figures (6)

  • Figure 1: Training proof schema.
  • Figure 2: Data model \ref{model1}. Effect of the number of heads $H$ on convergence rates when training with GD with a constant step-size $\eta = \mathcal{O}\left(1\right)$. The average $\lVert\cdot\rVert$ shows the $1/H$ and $1/\sqrt{H}$ averages across heads for $\bm{W}$ and $\bm{U}$, respectively. Attn-score denotes the softmax scores for the relevant tokens, averaged across all train samples and heads. The average $\lVert\bm{W}\rVert$ indicates the saturation of the softmax scores and, consequently, the token selection (attn-score), while the average $\lVert\bm{U}\rVert$ controls the loss behaviour. The results demonstrate that overparameterization slows down GD with a constant step-size. The circled area shows an $\mathcal{O}(1/t)$ trend similar to what our training and generalization bounds predict.
  • Figure 3: Data model \ref{model1}. Effect of the number of heads $H$ on convergence rates when (left) training with GD with the step-size scaled as $\eta = \mathcal{O}\left(\sqrt{H}\right)$; (right) training with Adam with a constant step-size $\eta = \mathcal{O}\left(1\right)$. The quantities plotted are the same as in Figure \ref{fig:context-GD}. The results demonstrate that overparameterization speeds up train and test loss convergence in both scenarios.
  • Figure 4: Data model \ref{model planted}. Effect of the number of heads $H$ on convergence rates when training with GD with the step-size scaled as $\eta = \mathcal{O}\left(\sqrt{H}\right)$. See the caption of Figure \ref{fig:context-GD} for more context on the average $\lVert\cdot\rVert$. The alignment of $\bm{W}$ with the planted head $\widetilde{\bm{W}}^\star$ at any iteration $k$ is given by $\frac{\langle\widetilde{\bm{W}}_k,\widetilde{\bm{W}}^\star\rangle}{\lVert\widetilde{\bm{W}}_k\rVert\lVert\widetilde{\bm{W}}^\star\rVert}$, where $\widetilde{\bm{W}}^\star:=\operatorname{concat}(\{\bm{W}^\star\}_{h\in[H]})$ contains $\bm{W}^\star$ repeated $H$ times. The alignment between $\widetilde{\bm{U}}$ and $\widetilde{\bm{U}}^\star$ is computed similarly.
  • Figure 5: Data model \ref{model planted}. Effect of the number of heads $H$ on convergence rates when training with (left) GD + momentum, where the step-size scales as $\eta = \mathcal{O}\left(\sqrt{H}\right)$; (right) Adam with a constant step-size $\eta = \mathcal{O}\left(1\right)$. The quantities plotted are the same as in Figure \ref{fig:planted-GD}. The results demonstrate that overparameterization speeds up convergence in both scenarios.
  • ...and 1 more figures
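
The alignment metric described in the Figure 4 caption is a cosine similarity between the concatenated per-head parameters and the planted head repeated $H$ times. A minimal pure-Python sketch follows; the function names and the list-of-lists matrix representation are illustrative assumptions, not taken from the paper.

```python
import math

def flatten(heads):
    # Concatenate a list of matrices (lists of rows) into one flat vector.
    return [v for M in heads for row in M for v in row]

def alignment(heads_k, head_star):
    """Cosine similarity between concat(heads_k) and head_star repeated H times."""
    H = len(heads_k)
    a = flatten(heads_k)            # current per-head parameters at iteration k
    b = flatten([head_star] * H)    # planted head tiled across all H heads
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)
```

The value is 1 when every head is a positive multiple of the planted head with a common scale, and it is insensitive to the overall norm of the parameters. Under the caption's "computed similarly" wording, the same function would apply to the $\bm{U}$ matrices.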

Theorems & Definitions (59)

  • Lemma 1: Gradient/Hessian formulas
  • Proposition 1: Model Gradient/Hessian bounds
  • Corollary 1: Loss properties
  • Theorem 1: Training loss
  • Theorem 2: Generalization loss
  • Definition 1: Good initialization
  • Corollary 2: General bounds under good initialization
  • Remark 1
  • Lemma 2: First phase
  • Proposition 2
  • ...and 49 more