
On the Optimization and Generalization of Two-layer Transformers with Sign Gradient Descent

Bingrui Li, Wei Huang, Andi Han, Zhanpeng Zhou, Taiji Suzuki, Jun Zhu, Jianfei Chen

TL;DR

This work analyzes Sign Gradient Descent (SignGD) as a proxy for Adam to understand how a two-layer transformer with a softmax attention layer and trainable query-key parameterization is optimized on a binary, linearly separable dataset containing a signal $\boldsymbol{\mu}$ and sparse noise. It identifies four distinct training stages and proves that the training loss converges quickly while the test loss stays at a constant level, because the attention mechanism memorizes noise; it also shows that Adam reproduces similar patterns. The study combines a feature-learning framework with a sparse data model to derive precise dynamics and convergence results, supported by experiments on synthetic data and real-world MNIST-like tasks. The results suggest that both SignGD and Adam can require higher data quality than plain gradient descent to generalize well, highlighting limits of these optimizers in noisy transformer settings and providing theoretical insight into Adam’s practical behavior.
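To make the setting concrete, here is a minimal Python sketch of one input sample and a simplified forward pass. It rests on illustrative assumptions, not on the paper's Definition 2.1: each sample has two tokens, a signal token $y\boldsymbol{\mu}$ with $\boldsymbol{\mu}$ supported on the first coordinate and an $s$-sparse noise token $\boldsymbol{\xi}$ whose nonzero entries are Gaussian with standard deviation $\sigma_p$ on the remaining coordinates; the model is a single softmax attention head with query, key, and value matrices followed by a linear readout `a`. The names `make_sample` and `forward` are hypothetical.

```python
import math
import random


def make_sample(mu_norm, d, s, sigma_p, rng):
    """Hypothetical data model: X = [y * mu, xi] with an s-sparse Gaussian noise token."""
    y = rng.choice([-1, 1])
    signal = [0.0] * d
    signal[0] = y * mu_norm                    # assumed: mu = ||mu|| * e_1
    xi = [0.0] * d
    for j in rng.sample(range(1, d), s):       # s noise coordinates, disjoint from mu
        xi[j] = rng.gauss(0.0, sigma_p)
    return [signal, xi], y


def matvec(W, x):
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]


def softmax(z):
    z_max = max(z)                             # subtract max for numerical stability
    e = [math.exp(v - z_max) for v in z]
    total = sum(e)
    return [v / total for v in e]


def forward(X, W_Q, W_K, W_V, a):
    """Softmax attention with trainable W_Q, W_K (and W_V), then a linear readout."""
    Q = [matvec(W_Q, x) for x in X]            # one m_k-dim query per token
    K = [matvec(W_K, x) for x in X]            # one m_k-dim key per token
    V = [matvec(W_V, x) for x in X]            # one m_v-dim value per token
    out = 0.0
    for q in Q:
        scores = softmax([sum(qi * ki for qi, ki in zip(q, k)) for k in K])
        attended = [sum(p * v[r] for p, v in zip(scores, V)) for r in range(len(V[0]))]
        out += sum(ar * hr for ar, hr in zip(a, attended))
    return out / len(X)                        # average over query positions


rng = random.Random(0)
X, y = make_sample(mu_norm=2.0, d=20, s=3, sigma_p=1.0, rng=rng)
m_k, m_v = 4, 4
W_Q = [[rng.gauss(0.0, 0.1) for _ in range(20)] for _ in range(m_k)]
W_K = [[rng.gauss(0.0, 0.1) for _ in range(20)] for _ in range(m_k)]
W_V = [[rng.gauss(0.0, 0.1) for _ in range(20)] for _ in range(m_v)]
a = [1.0 / m_v] * m_v
print(y, forward(X, W_Q, W_K, W_V, a))
```

Keeping the noise on coordinates disjoint from $\boldsymbol{\mu}$ is one simple way to make signal and noise separable by construction; the paper's exact choices may differ.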

Abstract

The Adam optimizer is widely used for transformer optimization in practice, which makes understanding its underlying optimization mechanisms an important problem. However, due to Adam's complexity, theoretical analysis of how it optimizes transformers remains challenging. Fortunately, Sign Gradient Descent (SignGD) serves as an effective surrogate for Adam. Despite its simplicity, theoretical understanding of how SignGD optimizes transformers still lags behind. In this work, we study how SignGD optimizes a two-layer transformer -- consisting of a softmax attention layer with trainable query-key parameterization followed by a linear layer -- on a linearly separable noisy dataset. We identify four stages in the training dynamics, each exhibiting intriguing behaviors. Based on the training dynamics, we prove the fast convergence but poor generalization of the learned transformer on the noisy dataset. We also show that Adam behaves similarly to SignGD in terms of both optimization and generalization in this setting. Additionally, we find that the poor generalization of SignGD is not solely due to data noise, suggesting that both SignGD and Adam require high-quality data for real-world tasks. Finally, experiments on synthetic and real-world datasets empirically support our theoretical results.
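The surrogate relationship is visible directly in the update rules. Below is a minimal pure-Python sketch (the function names are hypothetical): a SignGD step moves every coordinate by exactly $\eta$ against the sign of its gradient, and a standard bias-corrected Adam step with $\beta_1 = \beta_2 = 0$ and $\epsilon = 0$ reduces to the same update whenever every gradient coordinate is nonzero. For the usual $\beta_1, \beta_2$ the two updates differ, and this sketch says nothing about how closely they track each other in training.

```python
import math


def sign(x):
    return (x > 0) - (x < 0)


def signgd_step(w, g, eta):
    """SignGD: w <- w - eta * sign(g), coordinate-wise."""
    return [wi - eta * sign(gi) for wi, gi in zip(w, g)]


def adam_step(w, g, m, v, t, eta, beta1=0.9, beta2=0.999, eps=1e-8):
    """One bias-corrected Adam step; t is the 1-based step count."""
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    w = [wi - eta * mh / (math.sqrt(vh) + eps) for wi, mh, vh in zip(w, m_hat, v_hat)]
    return w, m, v


w = [1.0, -2.0, 0.5]
g = [0.3, -4.0, 1e-3]            # gradients of very different magnitudes
w_sign = signgd_step(w, g, eta=0.1)
w_adam, _, _ = adam_step(w, g, [0.0] * 3, [0.0] * 3, t=1, eta=0.1,
                         beta1=0.0, beta2=0.0, eps=0.0)
print(w_sign)
print(w_adam)
```

Because each coordinate moves by exactly $\eta$ regardless of the gradient's magnitude, the tiny third gradient component produces the same step as the large second one; this magnitude independence is what separates SignGD from plain gradient descent.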

Paper Structure

This paper contains 55 sections, 53 theorems, 258 equations, 19 figures, and 11 tables.

Key Result

Theorem 3.2

For any $\epsilon > 0$, under Condition \ref{cond:main_condition}, with probability at least $1-n^{-1/3}$, there exist $T = O(\log(\epsilon^{-1})\eta^{-1}\sigma_p^{-1}s^{-1})$ and $T_{\mathrm{attn}} = \tilde{O}(\eta^{-1}m_k^{-1/2}\sigma_p^{-1/2}s^{-1/2}\lVert\boldsymbol{\mu}\rVert^{\cdots})$ …
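As a rough reading of the first bound, the sketch below evaluates $\log(\epsilon^{-1})\,\eta^{-1}\sigma_p^{-1}s^{-1}$ with the constant hidden by $O(\cdot)$ replaced by a hypothetical placeholder `C = 1`, and it assumes $\epsilon$ is the target training loss. Only ratios are meaningful: tightening $\epsilon$ from $10^{-2}$ to $10^{-4}$ doubles the bound, the logarithmic dependence expected when the training loss decays exponentially, and doubling $s$ halves it.

```python
import math


def step_bound(eps, eta, sigma_p, s, C=1.0):
    """T ~ C * log(1/eps) / (eta * sigma_p * s); C is a hypothetical constant."""
    return C * math.log(1.0 / eps) / (eta * sigma_p * s)


base = step_bound(eps=1e-2, eta=1e-3, sigma_p=1.0, s=10)
tighter = step_bound(eps=1e-4, eta=1e-3, sigma_p=1.0, s=10)
sparser = step_bound(eps=1e-2, eta=1e-3, sigma_p=1.0, s=20)
print(tighter / base, sparser / base)
```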

Figures (19)

  • Figure 1: The training dynamics of two-layer transformers with SignGD. (a) Dynamics of mean value noise and mean value signals in Stages I and II. (b) Dynamics of key noise in Stages I and II, with different groups of key noise marked in different colors. ①: $\mathbf{w}_{K,s}^{(t)} \in S_{K+, Q+}^{(0)} := \{\mathbf{w}_{K,s}^{(t)}: \langle \mathbf{w}_{K,s}^{(0)}, y_i\boldsymbol{\xi}_i\rangle > 0, \langle \mathbf{w}_{Q,s}^{(0)}, y_i\boldsymbol{\xi}_i\rangle > 0\}$. ②: $\mathbf{w}_{K,s}^{(t)} \in S_{K-, Q-}^{(0)} := \{\mathbf{w}_{K,s}^{(t)}: \langle \mathbf{w}_{K,s}^{(0)}, y_i\boldsymbol{\xi}_i\rangle < 0, \langle \mathbf{w}_{Q,s}^{(0)}, y_i\boldsymbol{\xi}_i\rangle < 0\}$. ③: $\mathbf{w}_{K,s}^{(t)} \in (S_{K+, Q+}^{(0)} \cup S_{K-, Q-}^{(0)})^c$. (c) Dynamics of query noise, key noise, query signals, and key signals in Stages II and III. The dotted lines represent (query and key) noise that is positive at $t = 40$, and the solid lines represent noise that is negative at that time. (d) Dynamics of query noise, key noise, query signals, and key signals in Stages III and IV; dotted and solid lines have the same meaning as in (c). (e) Dynamics of softmax outputs in all four stages. The dynamics over the whole time horizon are shown in Fig. \ref{fig:full_dynamics_fig1}, and an illustration explaining the behavior of all quantities in all stages is given in Fig. \ref{fig:full_dynamics_illustration}.
  • Figure 2: Comparison of SignGD with Adam and GD on synthetic and real-world datasets.(a) Dynamics of query noise, key noise, query signals, and key signals with SignGD on the synthetic dataset. (b) Dynamics of the same quantities with Adam($\beta_1=0.9$). (c) Training loss curve (log scale) on the synthetic data for different optimizers. The training loss with SignGD decays exponentially. Note that the training losses for Adam($\beta_1=0.9$), Adam($\beta_1=0.5$), Adam($\beta_1=0.0$) overlap. (d) Test loss on the noisy MNIST dataset across varying noise levels. A larger scaled SNR indicates less noise in the dataset.
  • Figure 3: Data setting (a) in Tab. \ref{tab:experimental_settings} with $d=2000$. (a) Key noise dynamics over $t=0$ to $t=2$. (b) Mean value noise dynamics over $t=0$ to $t=2$. Mean value noise quickly settles into linear growth in $t$, while key noise remains close to initialization. (c) Softmax output dynamics over $t=0$ to $t=900$. The softmax output $s_{i,21}^{(t)}$ decays exponentially and approaches zero by $t=150$, while $s_{i,11}^{(t)}$ remains close to $1/2$. (d) Dynamics of query noise, key noise, and query signals over $t=0$ to $t=900$. The dotted lines represent query and key noise that is positive at $t=100$, and the solid lines represent noise that is negative at that time. By Stage III, because most of the noise is positive, majority voting makes the query signal positive. In Stage IV, sign alignment of key noise starts at about $t=150$, coinciding with $s_{i,21}^{(t)}$ approaching zero, while the delayed sign alignment of query noise begins around $t=300$, about twice as late as that of key noise.
  • Figure 4: Data setting (b) in Tab. \ref{tab:experimental_settings} with $d=2000$.
  • Figure 5: Data setting (c) in Tab. \ref{tab:experimental_settings} with $d=2000$.
  • ...and 14 more figures
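The grouping in Fig. 1(b) depends only on the signs of two inner products at initialization. A minimal sketch with hypothetical names: `classify_neuron` takes $\langle \mathbf{w}_{K,s}^{(0)}, y_i\boldsymbol{\xi}_i\rangle$ and $\langle \mathbf{w}_{Q,s}^{(0)}, y_i\boldsymbol{\xi}_i\rangle$ for one neuron $s$ and sample $i$ and returns set ①, ②, or ③, while `majority_sign` shows the kind of sign-counting vote mentioned in the Figure 3 caption (it is not the paper's update rule). Under an assumed initialization with independent zero-mean Gaussian query and key weights, the two inner products are independent and symmetric given $\boldsymbol{\xi}_i$, so roughly a quarter of neurons fall into each of ① and ②, and roughly half into ③.

```python
import random


def classify_neuron(k0, q0):
    """Fig. 1(b) set for neuron s, given k0 = <w_K,s^(0), y_i xi_i>, q0 = <w_Q,s^(0), y_i xi_i>."""
    if k0 > 0 and q0 > 0:
        return 1      # set (1): S_{K+,Q+}
    if k0 < 0 and q0 < 0:
        return 2      # set (2): S_{K-,Q-}
    return 3          # set (3): complement of the union (mixed signs or a zero)


def majority_sign(values):
    """+1 if positive entries outnumber negative ones, -1 if the reverse, 0 on a tie."""
    pos = sum(1 for v in values if v > 0)
    neg = sum(1 for v in values if v < 0)
    return (pos > neg) - (pos < neg)


rng = random.Random(0)
n_neurons = 10_000
# Independent standard Gaussians stand in for the two initial inner products.
labels = [classify_neuron(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(n_neurons)]
print({k: labels.count(k) / n_neurons for k in (1, 2, 3)})
```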

Theorems & Definitions (125)

  • Definition 2.1
  • Theorem 3.2
  • Lemma 4.1: Stage I
  • Lemma 4.2: Sign Alignment Between Query and Key Noise
  • Lemma 4.3: End of Stage II
  • Lemma 4.4: Stage III
  • Lemma 4.5: Exponentially Fast Decay of Noise-Signal Softmax Outputs
  • Lemma 4.6: Sign Alignment of Key Noise
  • Lemma 4.7: Delayed Sign Alignment of Query Noise
  • Lemma 4.8: End of Stage IV
  • ...and 115 more
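Lemma 4.5 and Fig. 3(c) describe a softmax output that decays exponentially. The sketch below only illustrates the generic mechanism, not the paper's proof: for two tokens whose attention logits differ by a gap $\Delta$, the weight on the lower-scoring token is $1/(1+e^{\Delta}) \le e^{-\Delta}$, so if $\Delta$ grows at least linearly in $t$ (here $\Delta = c\,t$ with a hypothetical rate `c`), that weight decays at least exponentially while the two weights still sum to one.

```python
import math


def two_token_softmax(z_a, z_b):
    """Softmax over two logits; returns (weight_a, weight_b)."""
    m = max(z_a, z_b)                  # shift for numerical stability
    e_a, e_b = math.exp(z_a - m), math.exp(z_b - m)
    return e_a / (e_a + e_b), e_b / (e_a + e_b)


c = 0.05                               # hypothetical growth rate of the logit gap
for t in (0, 50, 100, 150):
    gap = c * t
    _, w_low = two_token_softmax(gap, 0.0)
    print(t, w_low, math.exp(-gap))    # weight vs. its exponential upper bound
```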