
Local to Global: Learning Dynamics and Effect of Initialization for Transformers

Ashok Vardhan Makkuva, Marco Bondaschi, Chanakya Ekbote, Adway Girish, Alliot Nagle, Hyeji Kim, Michael Gastpar

TL;DR

The paper proves that transformer parameters trained on the next-token prediction loss can converge to either global or local minima, depending on the initialization and the properties of the Markovian data; this is the first result of its kind to highlight the role of initialization.
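
The claim that the end point of training depends on its starting point can be seen in miniature on a toy problem. The sketch below is not the paper's transformer loss. It runs gradient descent, a forward-Euler discretization of gradient flow, on a hypothetical one-dimensional tilted double-well loss with a global minimum near x ≈ -1.04 and a worse local minimum near x ≈ 0.96. The functions `loss`, `grad` and `gradient_descent` and the parameter `tilt` are illustrative assumptions.

```python
def loss(x: float, tilt: float = 0.3) -> float:
    """Toy tilted double well: global minimum near x = -1, local minimum near x = +1."""
    return (x * x - 1.0) ** 2 + tilt * x


def grad(x: float, tilt: float = 0.3) -> float:
    """Derivative of the toy loss with respect to x."""
    return 4.0 * x * (x * x - 1.0) + tilt


def gradient_descent(x0: float, lr: float = 0.01, steps: int = 5000) -> float:
    """Forward-Euler discretization of the gradient flow dx/dt = -grad(x)."""
    x = x0
    for _ in range(steps):
        x -= lr * grad(x)
    return x


if __name__ == "__main__":
    # Same dynamics, different initializations, different end points.
    for x0 in (-1.5, -0.5, 0.5, 1.5):
        x_end = gradient_descent(x0)
        print(f"x0={x0:+.1f} -> x={x_end:+.4f}, loss={loss(x_end):+.4f}")
```

In this toy, initializations to the right of the interior critical point near x ≈ 0.075 end in the local minimum, and those to the left reach the global one.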

Abstract

In recent years, transformer-based models have revolutionized deep learning, particularly in sequence modeling. To better understand this phenomenon, there is growing interest in using Markov input processes to study transformers. However, our current understanding in this regard remains limited, with many fundamental questions about how transformers learn Markov chains still unanswered. In this paper, we address this by focusing on first-order Markov chains and single-layer transformers, providing a comprehensive characterization of the learning dynamics in this context. Specifically, we prove that transformer parameters trained on next-token prediction loss can converge to either global or local minima, contingent on the initialization and the Markovian data properties, and we characterize the precise conditions under which this occurs. To the best of our knowledge, this is the first result of its kind highlighting the role of initialization. We further demonstrate that our theoretical findings are corroborated by empirical evidence. Based on these insights, we provide guidelines for the initialization of transformer parameters and demonstrate their effectiveness. Finally, we outline several open problems in this arena. Code is available at: https://github.com/Bond1995/Markov.
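
The data model is a binary first-order Markov chain described by two switching probabilities $(p,q)$. A minimal sketch of sampling such a chain and recovering $(p,q)$ from it follows, assuming the common convention $\boldsymbol{P} = \begin{pmatrix} 1-p & p \\ q & 1-q \end{pmatrix}$ (so $p$ is the $0 \to 1$ and $q$ the $1 \to 0$ switching probability) and a stationary start $x_1 \sim \boldsymbol{\pi}$. The helper names `stationary`, `sample_chain` and `estimate_switch_probs` are hypothetical and not taken from the authors' repository.

```python
import random


def stationary(p: float, q: float) -> tuple[float, float]:
    """Stationary distribution (pi_0, pi_1) of P = [[1 - p, p], [q, 1 - q]] (assumed convention)."""
    return q / (p + q), p / (p + q)


def sample_chain(p: float, q: float, n: int, seed: int = 0) -> list[int]:
    """Draw x_1, ..., x_n from the binary chain, starting from x_1 ~ pi."""
    rng = random.Random(seed)
    _, pi1 = stationary(p, q)
    x = [1 if rng.random() < pi1 else 0]
    for _ in range(n - 1):
        switch = p if x[-1] == 0 else q  # probability of leaving the current state
        x.append(1 - x[-1] if rng.random() < switch else x[-1])
    return x


def estimate_switch_probs(x: list[int]) -> tuple[float, float]:
    """Empirical P(0 -> 1) and P(1 -> 0) from consecutive pairs (x_n, x_{n+1})."""
    from0 = [b for a, b in zip(x, x[1:]) if a == 0]
    from1 = [b for a, b in zip(x, x[1:]) if a == 1]
    return sum(from0) / len(from0), from1.count(0) / len(from1)
```

For example, `estimate_switch_probs(sample_chain(0.5, 0.8, 100_000))` should land close to `(0.5, 0.8)`, and the fraction of ones should be close to $\pi_1 = p/(p+q)$.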

Paper Structure (52 sections, 27 theorems, 196 equations, 6 figures, 2 tables)

Key Result

Theorem 1

Let the input sequence be $\{x_n\}_{n=1}^N \sim (\boldsymbol{\pi}, \boldsymbol{P})$, the transformer parameters be $\boldsymbol{\theta} = (e,w) \in \mathbb{R}^2$, and the next-token prediction loss $L(\cdot)$ be as in eq:loss_cannoli. Then for any $(p,q) \in (0,1)^2$ with $p+q \neq 1$ and $N \in \mathbb{N}$, … Thus the set of all critical points is … In addition, for any $\boldsymbol{\theta}_\star \in$ …
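
To make the gap between the two kinds of minima concrete: Figure 4 identifies the global minimum with the bigram model and the local minimum with the unigram model. Assuming a stationary chain and a cross-entropy loss measured in nats, the true bigram predictor $P(x_{n+1} \mid x_n)$ attains the entropy rate $\pi_0 h(p) + \pi_1 h(q)$, while the best context-free predictor outputs $\boldsymbol{\pi}$ and attains $h(\pi_1)$, where $h$ is the binary entropy. The sketch below computes both values. It uses the same assumed convention $\pi_1 = p/(p+q)$, and the function names are illustrative, not the paper's.

```python
import math


def binary_entropy(x: float) -> float:
    """h(x) = -x log x - (1 - x) log(1 - x), in nats."""
    if x <= 0.0 or x >= 1.0:
        return 0.0
    return -x * math.log(x) - (1.0 - x) * math.log(1.0 - x)


def bigram_loss(p: float, q: float) -> float:
    """Expected cross-entropy of the true bigram predictor, i.e. the entropy rate."""
    pi0, pi1 = q / (p + q), p / (p + q)
    return pi0 * binary_entropy(p) + pi1 * binary_entropy(q)


def unigram_loss(p: float, q: float) -> float:
    """Expected cross-entropy of the best predictor that ignores context (outputs pi)."""
    return binary_entropy(p / (p + q))
```

For the Figure 4 setting $p=0.5$, $q=0.8$, this gives roughly 0.619 nats for the bigram predictor versus 0.666 nats for the unigram one. When $p+q=1$, the two rows of $\boldsymbol{P}$ coincide, the chain is i.i.d., and the two values are equal. This degenerate case is plausibly why the theorem assumes $p+q \neq 1$.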

Figures (6)

  • Figure 1: Gradient flow dynamics and initialization effect for single-layer transformers. $(p,q)$ are Markov switching probabilities, and $(e,w)$ are the embedding and weight parameters (sec:prob). (a), (c): The flow is aligned along energy contour lines, converging to local or global optima. (b), (d): $\mathcal{I}_\star$ is the basin of convergence for global minima, $\mathcal{I}_{\mathrm{min}}$ for the local minima, and yellow asymptotes for the saddle point. Notice the contrasting behavior for Gaussian initialization around the origin for $p+q \lessgtr 1$.
  • Figure 2: Gradient flow dynamics for the canonical parameters $\boldsymbol{\theta}=(e,w,a) \in \mathbb{R}^3$ with the attention scalar $a$. Notice the contrasting behavior for Gaussian initialization around the origin for $p+q$ smaller and greater than one. For an enhanced view of the flow near the origin, see Figure 5.
  • Figure 3: Evolution of parameters $\boldsymbol{W}_1$ and $\boldsymbol{W}_V$ across iterations, starting from a standard Gaussian initialization. At convergence, all the parameter matrices are approximately rank-one.
  • Figure 4: Comparison between the average loss curves for the standard Gaussian initialization around $0$ and our initialization, for $p=0.5$ and $q=0.8$. Starting from the standard initialization, the model converges to a local minimum corresponding to the unigram model. With our initialization, it converges to the global minimum corresponding to the bigram model.
  • Figure 5: Gradient flow dynamics in $\mathbb{R}^3$, near the origin, for the transformer parameters with attention scalar $a$ (sec:learn_dyna_attn). The local minima are repellers for $p+q<1$ and attractors for $p+q>1$.
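
Figure 3 reports that the trained matrices are approximately rank-one. One hypothetical way to quantify this is the ratio $\sigma_1^2 / \lVert \boldsymbol{W} \rVert_F^2$, which equals 1 exactly for a rank-one matrix and is smaller otherwise. The sketch below estimates $\sigma_1$ by power iteration on $\boldsymbol{W}^\top \boldsymbol{W}$ in pure Python. The helper names and this choice of metric are illustrative assumptions, not the measurement used for the figure.

```python
import math
import random


def matvec(A: list[list[float]], v: list[float]) -> list[float]:
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def transpose(A: list[list[float]]) -> list[list[float]]:
    return [list(col) for col in zip(*A)]


def top_singular_value(A: list[list[float]], iters: int = 200, seed: int = 0) -> float:
    """Largest singular value of A via power iteration on A^T A."""
    rng = random.Random(seed)
    At = transpose(A)
    v = [rng.gauss(0.0, 1.0) for _ in range(len(A[0]))]
    for _ in range(iters):
        w = matvec(At, matvec(A, v))
        norm = math.sqrt(sum(x * x for x in w))
        if norm == 0.0:
            return 0.0
        v = [x / norm for x in w]  # unit vector converging to the top right singular vector
    return math.sqrt(sum(x * x for x in matvec(A, v)))


def rank_one_ratio(A: list[list[float]]) -> float:
    """sigma_1^2 / ||A||_F^2: equals 1 exactly when A has rank one."""
    fro2 = sum(x * x for row in A for x in row)
    return top_singular_value(A) ** 2 / fro2 if fro2 else 0.0
```

A ratio close to 1 for a trained matrix such as $\boldsymbol{W}_V$ would indicate that almost all of its energy lies in a single singular direction; the identity matrix in $\mathbb{R}^{3 \times 3}$, by contrast, gives $1/3$.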

Theorems & Definitions (59)

  • Theorem 1: All critical points
  • proof
  • Lemma 1: Constant energy along the flow
  • Theorem 2: GF dynamics for $p+q>1$
  • proof: Proof sketch
  • Theorem 3: GF dynamics with attention
  • proof
  • Theorem 4: Global minimum
  • Theorem 5: Bad local minimum
  • Theorem 6: All critical points