Disentangling Feature Structure: A Mathematically Provable Two-Stage Training Dynamics in Transformers
Zixuan Gong, Shijia Li, Yong Liu, Jiaye Teng
TL;DR
The paper provides a rigorous, feature-level explanation for the two-stage training dynamics observed in transformers under ICL prompts. It models the data with disentangled elementary ($\mathcal{P}$) and specialized ($\mathcal{Q}$) features. The analysis uses a simplified one-layer transformer with a normalized ReLU self-attention mechanism and a block-diagonal weight structure. It derives finite-time convergence results showing that the model first learns $\mathcal{P}$ and then learns $\mathcal{Q}$ while preserving the knowledge acquired in the first stage. The main theoretical contributions are explicit weight-growth and loss bounds for both stages, together with a spectral interpretation that links the eigenvalues of the attention weights to the learned features. Experiments on synthetic data, Counterfact, and HotpotQA support the theory and reveal a consistent spectral pattern: small eigenvalues encode elementary knowledge and large eigenvalues encode specialized knowledge. This pattern aligns with the proposed framework and offers insight into transformer optimization dynamics.
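To make the architecture and the spectral claim concrete, the following is a minimal sketch, not the paper's implementation. It assumes (i) that "normalized ReLU self-attention" means replacing softmax with ReLU scores divided by the sequence length $n$, (ii) that the value matrix is the identity, and (iii) that the first two coordinates carry $\mathcal{P}$-features and the last two carry $\mathcal{Q}$-features. All function names and all weight values are hypothetical and chosen only to illustrate the reported pattern (small eigenvalues in the $\mathcal{P}$ block, large ones in the $\mathcal{Q}$ block); no training is simulated.

```python
import math
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def relu(z: float) -> float:
    return max(0.0, z)


def bilinear(x: Vector, w: Matrix, y: Vector) -> float:
    # Computes the attention score x^T W y.
    return sum(x[i] * w[i][j] * y[j] for i in range(len(x)) for j in range(len(y)))


def normalized_relu_attention(tokens: List[Vector], w: Matrix) -> List[Vector]:
    # Assumed form (hypothetical): out_i = (1/n) * sum_j ReLU(x_i^T W x_j) * x_j,
    # i.e. ReLU scores normalized by sequence length n instead of softmax,
    # with an identity value matrix.
    n = len(tokens)
    dim = len(tokens[0])
    outputs = []
    for x_i in tokens:
        out = [0.0] * dim
        for x_j in tokens:
            score = relu(bilinear(x_i, w, x_j)) / n
            for k in range(dim):
                out[k] += score * x_j[k]
        outputs.append(out)
    return outputs


def block_diag(a: Matrix, b: Matrix) -> Matrix:
    # Places block a (P-features) and block b (Q-features) on the diagonal.
    size = len(a) + len(b)
    out = [[0.0] * size for _ in range(size)]
    for i, row in enumerate(a):
        for j, v in enumerate(row):
            out[i][j] = v
    off = len(a)
    for i, row in enumerate(b):
        for j, v in enumerate(row):
            out[off + i][off + j] = v
    return out


def sym2x2_eigenvalues(m: Matrix) -> List[float]:
    # Closed-form eigenvalues of a symmetric 2x2 block [[a, c], [c, d]],
    # returned in ascending order.
    a, c, d = m[0][0], m[0][1], m[1][1]
    mid = (a + d) / 2
    rad = math.sqrt(((a - d) / 2) ** 2 + c ** 2)
    return [mid - rad, mid + rad]


# Illustrative weights: small-norm P block, large-norm Q block.
W_P = [[0.2, 0.1], [0.1, 0.2]]
W_Q = [[3.0, 1.0], [1.0, 3.0]]
W = block_diag(W_P, W_Q)

# The spectrum of a block-diagonal matrix is the union of the block spectra,
# so each eigenvalue can be attributed to either the P block or the Q block.
eig_P = sym2x2_eigenvalues(W_P)
eig_Q = sym2x2_eigenvalues(W_Q)
print("P-block eigenvalues:", eig_P)
print("Q-block eigenvalues:", eig_Q)

tokens = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]]
out = normalized_relu_attention(tokens, W)
print("attention output for first token:", out[0])
```

Because the weight is block-diagonal, the $\mathcal{P}$ and $\mathcal{Q}$ coordinates never mix inside a score, which is what lets each eigenvalue be read off as belonging to one feature type.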
Abstract
Transformers may exhibit two-stage training dynamics during real-world training. For instance, when GPT-2 is trained on the Counterfact dataset, its answers progress from syntactically incorrect, to syntactically correct, to semantically correct. However, existing theoretical analyses rarely account for this feature-level two-stage phenomenon, which arises from two disentangled types of features, such as syntax and semantics. In this paper, we theoretically demonstrate how two-stage training dynamics can arise in transformers. Specifically, we analyze the feature learning dynamics induced by this disentangled two-type feature structure, grounding our analysis in a simplified yet illustrative setting that comprises a normalized ReLU self-attention layer and structured data. Such disentangled feature structure is common in practice: natural languages contain syntax and semantics, and proteins contain primary and secondary structures. To the best of our knowledge, this is the first rigorous result on a feature-level two-stage optimization process in transformers. Additionally, a corollary shows that this two-stage process is closely related to the spectral properties of the attention weights, which agrees well with our empirical findings.
