
Disentangling Feature Structure: A Mathematically Provable Two-Stage Training Dynamics in Transformers

Zixuan Gong, Shijia Li, Yong Liu, Jiaye Teng

TL;DR

The paper provides a rigorous, feature-level explanation for the two-stage training dynamics observed in transformers under in-context learning (ICL) prompts, modeling the data with disentangled elementary ($\mathcal{P}$) and specialized ($\mathcal{Q}$) features. It analyzes a simplified one-layer transformer with a normalized ReLU self-attention mechanism and a block-diagonal weight structure, and derives finite-time convergence results showing that the model first learns $\mathcal{P}$ and then $\mathcal{Q}$ while preserving previously learned knowledge. The main theoretical contributions are explicit bounds on weight growth and loss for both stages, together with a spectral interpretation linking the eigenvalues of the attention weights to the learned features. Experiments on synthetic data, Counterfact, and HotpotQA support the theory and reveal a consistent spectral pattern: small eigenvalues encode elementary knowledge and large eigenvalues encode specialized knowledge. This pattern aligns with the proposed framework and offers insight into transformer optimization dynamics.
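The TL;DR names a normalized ReLU self-attention layer. As a minimal sketch, assuming one common parameterization (not confirmed by this summary) in which softmax is replaced by ReLU and the weighted sum is divided by the number of tokens $n$, a single head can be written as below. The names `w_kq` (merged key-query matrix) and `w_v` (value matrix) are hypothetical; the paper's exact form, including any residual connection or block-diagonal constraint, may differ.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(a: Matrix, x: Vector) -> Vector:
    """Matrix-vector product a @ x for row-major nested lists."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in a]


def dot(u: Vector, v: Vector) -> float:
    """Inner product of two vectors."""
    return sum(u_i * v_i for u_i, v_i in zip(u, v))


def normalized_relu_attention(tokens: List[Vector], w_kq: Matrix, w_v: Matrix) -> List[Vector]:
    """Hypothetical one-head normalized ReLU self-attention.

    For every query token x_j:
        y_j = (1 / n) * sum_i ReLU(x_i^T W_kq x_j) * (W_v x_i)
    ReLU replaces softmax, and dividing by the token count n replaces
    the softmax normalizer.
    """
    n = len(tokens)
    values = [matvec(w_v, x) for x in tokens]  # W_v x_i for every token i
    dim = len(values[0])
    outputs = []
    for x_q in tokens:
        query = matvec(w_kq, x_q)  # W_kq x_j
        scores = [max(dot(x_k, query), 0.0) for x_k in tokens]  # ReLU(x_i^T W_kq x_j)
        outputs.append([sum(scores[i] * values[i][r] for i in range(n)) / n for r in range(dim)])
    return outputs
```

With identity weights and tokens $(1,0)$ and $(-1,0)$, each token attends only to itself, because the cross score $-1$ is zeroed by the ReLU; the outputs are therefore the tokens scaled by $1/n = 1/2$.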

Abstract

Transformers may exhibit two-stage training dynamics during real-world training. For instance, when GPT-2 is trained on the Counterfact dataset, its answers progress from syntactically incorrect, to syntactically correct, to semantically correct. However, existing theoretical analyses can hardly account for this feature-level two-stage phenomenon, which arises from two types of disentangled features, such as syntax and semantics. In this paper, we theoretically demonstrate how two-stage training dynamics may occur in transformers. Specifically, we analyze the feature learning dynamics induced by this disentangled two-type feature structure, grounding our analysis in a simplified yet illustrative setting that comprises a normalized ReLU self-attention layer and structured data. Such disentangled feature structure is common in practice: natural languages contain syntax and semantics, and proteins contain primary and secondary structures. To the best of our knowledge, this is the first rigorous result on a feature-level two-stage optimization process in transformers. Additionally, a corollary indicates that this two-stage process is closely related to the spectral properties of the attention weights, which accords well with our empirical findings.
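The corollary ties the two stages to the spectrum of the attention weights. One elementary fact makes such an attribution possible under a block-diagonal weight structure: the eigenvalues of $\mathrm{diag}(W_{\mathcal{P}}, W_{\mathcal{Q}})$ are exactly the union of the eigenvalues of the two blocks, and each block eigenvector, zero-padded, remains an eigenvector of the full matrix. The sketch below checks this for symmetric $2\times 2$ blocks, an assumption made only to keep the eigensolver closed-form. The block values and function names are hypothetical; the blocks are chosen so that $\mathcal{P}$ is small and $\mathcal{Q}$ is large, mirroring the reported empirical pattern rather than deriving it.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def eig_sym_2x2(block: Matrix) -> List[Tuple[float, List[float]]]:
    """Closed-form eigenpairs of a symmetric 2x2 matrix [[a, b], [b, c]]."""
    (a, b), (_, c) = block
    if abs(b) < 1e-12:  # already diagonal
        return [(a, [1.0, 0.0]), (c, [0.0, 1.0])]
    mean, radius = (a + c) / 2, math.hypot((a - c) / 2, b)
    pairs = []
    for lam in (mean - radius, mean + radius):
        v = [b, lam - a]  # solves (A - lam*I) v = 0 since lam is a root of the char. polynomial
        norm = math.hypot(v[0], v[1])
        pairs.append((lam, [v[0] / norm, v[1] / norm]))
    return pairs


def block_diag(p_block: Matrix, q_block: Matrix) -> Matrix:
    """Assemble W = diag(W_P, W_Q) with zeros off the diagonal blocks."""
    dp, dq = len(p_block), len(q_block)
    out = [[0.0] * (dp + dq) for _ in range(dp + dq)]
    for i in range(dp):
        out[i][:dp] = p_block[i]
    for i in range(dq):
        out[dp + i][dp:] = q_block[i]
    return out


def labelled_spectrum(p_block: Matrix, q_block: Matrix) -> List[Tuple[float, str]]:
    """Eigenvalues of diag(W_P, W_Q) (2x2 blocks only), tagged with their source block."""
    full = block_diag(p_block, q_block)
    size = len(full)
    labelled = []
    for name, block, offset in (("P", p_block, 0), ("Q", q_block, len(p_block))):
        for lam, v in eig_sym_2x2(block):
            padded = [0.0] * size
            padded[offset:offset + 2] = v  # zero-pad the block eigenvector
            wv = [sum(row[k] * padded[k] for k in range(size)) for row in full]
            if any(abs(wv[k] - lam * padded[k]) > 1e-9 for k in range(size)):
                raise ValueError("padded vector is not an eigenvector of the full matrix")
            labelled.append((lam, name))
    return sorted(labelled)
```

For example, `labelled_spectrum([[0.2, 0.05], [0.05, 0.1]], [[3.0, 0.5], [0.5, 2.0]])` returns four eigenvalues whose sorted labels are `P, P, Q, Q`. In a trained model the blocks need not separate this cleanly, which is why the paper treats the spectral pattern as an empirical finding.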

Paper Structure

This paper contains 27 sections, 30 theorems, 265 equations, 9 figures, and 2 tables.

Key Result

Theorem 1

In the elementary stage, let $\eta_1 = \Theta(1)$ and $t_1 = \frac{1}{4\eta_1\lambda}$, where $\lambda$ denotes the regularization coefficient. Under the hyperparameter-choice assumption (ass:choice-hyperparam), with $N = \Theta\left(\mathrm{Poly}(d)\right)$ training prompts and initial weights $V_0 \to \mathbf{0}_{d\ldots}$ [...] (a.2) With random and small noise weights, the training loss of the nonlinear separable component [...]
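Theorem 1 fixes the length of the elementary stage as $t_1 = \frac{1}{4\eta_1\lambda}$. A quick arithmetic check follows, assuming (this summary does not state it) that the regularizer is $\frac{\lambda}{2}\|W\|_F^2$, so that each gradient step multiplies the weights by $(1-\eta_1\lambda)$ before the loss gradient is applied. Under that assumption the regularization-only shrinkage over stage 1 is $(1-\eta_1\lambda)^{t_1} \approx e^{-1/4}$ whenever $\eta_1\lambda$ is small, regardless of the particular $\eta_1$ and $\lambda$. The function names are hypothetical.

```python
import math


def elementary_stage_length(eta1: float, lam: float) -> float:
    """Stage-1 horizon t_1 = 1 / (4 * eta_1 * lambda), as stated in Theorem 1."""
    return 1.0 / (4.0 * eta1 * lam)


def l2_shrinkage(eta1: float, lam: float, steps: float) -> float:
    """Factor (1 - eta*lam)^steps by which L2 regularization alone scales the weights.

    ASSUMPTION: regularizer (lam / 2) * ||W||^2, so one GD step is
    W <- (1 - eta*lam) * W - eta * grad_loss(W).
    """
    return (1.0 - eta1 * lam) ** steps


def stage1_shrinkage(eta1: float, lam: float) -> float:
    """Regularization-only shrinkage accumulated over the whole elementary stage."""
    return l2_shrinkage(eta1, lam, elementary_stage_length(eta1, lam))


if __name__ == "__main__":
    for eta1, lam in [(1.0, 0.01), (0.5, 1e-4)]:  # hypothetical values
        print(eta1, lam, elementary_stage_length(eta1, lam), stage1_shrinkage(eta1, lam))
    print("e^(-1/4) =", math.exp(-0.25))
```

For example, $\eta_1 = 1$ and $\lambda = 0.01$ give $t_1 = 25$ steps and a shrinkage factor of about $0.778$, close to $e^{-1/4} \approx 0.779$. If the regularizer is instead $\lambda\|W\|_F^2$, the per-step factor becomes $(1 - 2\eta_1\lambda)$ and the limit is $e^{-1/2}$.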

Figures (9)

  • Figure 1: Two-stage learning of syntactic and semantic information on Counterfact dataset.
  • Figure 2: Overview of disentangled feature structure.
  • Figure 3: Composite nonlinear classification.
  • Figure 4: Summary of Two-stage Learning.
  • Figure 5: Two-stage Learning of Components $\mathcal{P}$ and $\mathcal{Q}$ on Theoretical Synthetic Data.
  • ...and 4 more figures

Theorems & Definitions (36)

  • Theorem 1
  • Theorem 2
  • Theorem 3
  • Theorem 4
  • Corollary 4
  • Lemma 5: Hoeffding's Inequality for General Bounded Random Variables (cf. High-Dimensional Probability (HDP), p. 16)
  • Lemma 6: Bernstein's Inequality for Bounded Random Variables (cf. concentration.pdf, Lemma 7.37)
  • Lemma 7: Norm of Matrix with Gaussian Entries (cf. HDP, p. 85)
  • Lemma 8: Standard Gaussian Concentration Inequality
  • Lemma 9: Chernoff Bound for Gaussian Variables
  • ...and 26 more
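The lemmas above are standard concentration tools. As a hedged sanity check, not the paper's own proofs, the sketch below evaluates the textbook two-sided Hoeffding bound for independent $X_i \in [m_i, M_i]$, $\Pr\left(\left|\sum_i (X_i - \mathbb{E}X_i)\right| \ge t\right) \le 2\exp\left(-2t^2 / \sum_i (M_i - m_i)^2\right)$, against a Monte Carlo estimate for fair coin flips. It also compares the Gaussian Chernoff bound $\Pr(Z \ge t) \le e^{-t^2/2}$ with the exact tail computed via `math.erfc`. The exact constants in Lemmas 5 and 9 of the paper may be stated differently; all function names are hypothetical.

```python
import math
import random
from typing import Sequence, Tuple


def hoeffding_two_sided(t: float, ranges: Sequence[Tuple[float, float]]) -> float:
    """Hoeffding bound for independent X_i in [m_i, M_i]:
    P(|sum(X_i - E X_i)| >= t) <= 2 exp(-2 t^2 / sum (M_i - m_i)^2), capped at 1."""
    width_sq = sum((hi - lo) ** 2 for lo, hi in ranges)
    return min(1.0, 2.0 * math.exp(-2.0 * t * t / width_sq))


def gaussian_chernoff(t: float) -> float:
    """Chernoff bound for a standard Gaussian Z and t >= 0: P(Z >= t) <= exp(-t^2 / 2)."""
    return math.exp(-t * t / 2.0)


def gaussian_tail(t: float) -> float:
    """Exact upper tail P(Z >= t) of a standard Gaussian via the complementary error function."""
    return 0.5 * math.erfc(t / math.sqrt(2.0))


def empirical_deviation_freq(n: int, t: float, trials: int, seed: int = 0) -> float:
    """Monte Carlo estimate of P(|S_n - n/2| >= t), S_n a sum of n fair coin flips."""
    rng = random.Random(seed)  # fixed seed for reproducibility
    hits = 0
    for _ in range(trials):
        s = sum(rng.random() < 0.5 for _ in range(n))  # booleans sum as 0/1
        if abs(s - n / 2) >= t:
            hits += 1
    return hits / trials
```

With $n = 100$ flips and $t = 10$, the Hoeffding bound is $2e^{-2} \approx 0.27$, while the true deviation probability is on the order of $0.06$: the bound is valid but loose, which is typical when it is used inside high-probability arguments.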