Global Convergence in Training Large-Scale Transformers

Cheng Gao, Yuan Cao, Zihao Li, Yihan He, Mengdi Wang, Han Liu, Jason Matthew Klusowski, Jianqing Fan

TL;DR

This paper rigorously analyzes the convergence properties of gradient flow in training Transformers with weight decay regularization, and constructs the mean-field limit of large-scale Transformers, demonstrating that the gradient flow reaches a global minimum consistent with the PDE solution when the weight decay regularization parameter is sufficiently small.

Abstract

Despite the widespread success of Transformers across various domains, their optimization guarantees in large-scale model settings are not well-understood. This paper rigorously analyzes the convergence properties of gradient flow in training Transformers with weight decay regularization. First, we construct the mean-field limit of large-scale Transformers, showing that as the model width and depth go to infinity, gradient flow converges to the Wasserstein gradient flow, which is represented by a partial differential equation. Then, we demonstrate that the gradient flow reaches a global minimum consistent with the PDE solution when the weight decay regularization parameter is sufficiently small. Our analysis is based on a series of novel mean-field techniques that adapt to Transformers. Compared with existing tools for deep networks (Lu et al., 2020) that demand homogeneity and global Lipschitz smoothness, we utilize a refined analysis assuming only $\textit{partial homogeneity}$ and $\textit{local Lipschitz smoothness}$. These new techniques may be of independent interest.
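
The role of a small weight decay parameter can be illustrated on a toy problem. The sketch below is not the paper's mean-field construction: it is a forward-Euler discretization of the finite-dimensional gradient flow $\dot\theta = -(\nabla L(\theta) + \lambda\theta)$ on a hypothetical two-parameter least-squares loss. The function name `gradient_flow_with_weight_decay`, the matrix `A`, the vector `b`, the step size, and the values of `lam` are all assumptions chosen for illustration.

```python
import numpy as np


def gradient_flow_with_weight_decay(grad_loss, theta0, lam, step=1e-2, n_steps=5000):
    """Forward-Euler discretization of d(theta)/dt = -(grad L(theta) + lam * theta)."""
    theta = np.asarray(theta0, dtype=float).copy()
    for _ in range(n_steps):
        # Weight decay adds lam * theta to the gradient of the unregularized loss.
        theta -= step * (grad_loss(theta) + lam * theta)
    return theta


# Hypothetical toy loss L(theta) = 0.5 * ||A @ theta - b||^2.
A = np.array([[2.0, 0.0], [0.0, 1.0]])
b = np.array([2.0, 3.0])


def grad_loss(theta):
    return A.T @ (A @ theta - b)


# Minimizer of the unregularized loss (A is invertible here).
theta_star = np.linalg.solve(A, b)

# Distance between the regularized limit point and the unregularized minimizer.
gaps = {}
for lam in (1.0, 0.1, 0.01):
    theta_lam = gradient_flow_with_weight_decay(grad_loss, np.zeros(2), lam)
    gaps[lam] = float(np.linalg.norm(theta_lam - theta_star))

for lam, gap in gaps.items():
    print(f"lambda={lam:<5} distance to unregularized minimizer: {gap:.4f}")
```

In this convex toy case the distance to the unregularized minimizer shrinks as `lam` decreases, which mirrors the "sufficiently small weight decay" condition in spirit only. The paper's result concerns the non-convex, infinite-width and infinite-depth setting, which this example does not capture.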

Paper Structure

This paper contains 59 sections, 24 theorems, 282 equations, 2 figures.

Key Result

Proposition 3.1

Under Assumptions ass:dataBound and ass:growth, for any pair $\rho,\nu\in\mathcal{P}^2$ that have bounded supports, we have where $\frac{\delta Q}{\delta \rho}$ is defined in eq:gradient2rho, and $\left\langle \frac{\delta Q}{\delta \rho}, \nu - \rho \right\rangle=\int_0^1\int_{(\theta,w)} \frac{\delta Q}{\delta \rho}\cdot(\nu-\rho)\, d(\theta,w)\,dt\in\mathbb{R}$.

Figures (2)

  • Figure 1: Training loss and training accuracy of Vision Transformers with different numbers of heads. (a) gives the curves of training loss, while (b) gives the curves of training accuracy.
  • Figure 2: Training loss and training accuracy of Vision Transformers with different depths. (a) gives the curves of training loss, while (b) gives the curves of training accuracy.

Theorems & Definitions (49)

  • Proposition 3.1: Functional derivative to $\rho$
  • Proposition 3.2: Existence and uniqueness of Wasserstein gradient flow
  • Proposition 3.3: Global minimum approximation of discretization
  • Theorem 3.1: Gradient flow approximation of discretization
  • Theorem 4.1: Global convergence up to $\lambda$
  • Corollary 4.1
  • Remark 1
  • Proposition C.1: Existence and uniqueness of Transformer ODE
  • Proof of Proposition C.1
  • Lemma C.1: Continuous Transformer output bound
  • ...and 39 more