How Do Transformers Learn Topic Structure: Towards a Mechanistic Understanding

Yuchen Li, Yuanzhi Li, Andrej Risteski

TL;DR

The paper studies, mechanistically, how transformers learn topic structure. It analyzes a one-layer transformer trained on data drawn from a Latent Dirichlet Allocation (LDA) topic model and proves that topic structure can be encoded in either the token embeddings or the self-attention layer. It also identifies a two-stage learning dynamic: while attention is still close to uniform, the embedding and value-matrix patterns emerge first, and the attention weights align with topics afterwards. Experiments on synthetic data and Wikipedia support the theory, showing block-structured embeddings and W^V as well as topic-biased attention, and the patterns hold across the loss functions (cross-entropy, squared loss) and optimizers (SGD, Adam) tested. The work gives a principled account of how topic structure arises in transformer representations, complementing empirical probing studies of interpretability and training dynamics.

Abstract

While the successes of transformers across many domains are indisputable, accurate understanding of the learning mechanics is still largely lacking. Their capabilities have been probed on benchmarks which include a variety of structured and reasoning tasks -- but mathematical understanding is lagging substantially behind. Recent lines of work have begun studying representational aspects of this question: that is, the size/depth/complexity of attention-based networks to perform certain tasks. However, there is no guarantee the learning dynamics will converge to the constructions proposed. In our paper, we provide fine-grained mechanistic understanding of how transformers learn "semantic structure", understood as capturing co-occurrence structure of words. Precisely, we show, through a combination of mathematical analysis and experiments on Wikipedia data and synthetic data modeled by Latent Dirichlet Allocation (LDA), that the embedding layer and the self-attention layer encode the topical structure. In the former case, this manifests as higher average inner product of embeddings between same-topic words. In the latter, it manifests as higher average pairwise attention between same-topic words. The mathematical results involve several assumptions to make the analysis tractable, which we verify on data, and might be of independent interest as well.
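
To make the synthetic setting concrete, here is a minimal sketch of sampling one document from an LDA-style topic model. It is an illustration under assumptions that are not taken from the paper: each topic owns a disjoint block of word ids, words are uniform within a topic, and the names `sample_dirichlet` and `sample_document` are hypothetical.

```python
import random

def sample_dirichlet(alpha, k, rng):
    """Draw a point from a symmetric Dirichlet(alpha) over k topics."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]

def sample_document(n_topics, words_per_topic, doc_len, alpha, rng):
    """Sample one document as a list of (word_id, topic_id) pairs.

    Assumption for illustration: topic t owns the word ids
    [t * words_per_topic, (t + 1) * words_per_topic), drawn uniformly.
    """
    topic_mix = sample_dirichlet(alpha, n_topics, rng)
    doc = []
    for _ in range(doc_len):
        # Pick a topic from the document's topic mixture, then a word from that topic.
        topic = rng.choices(range(n_topics), weights=topic_mix)[0]
        word = topic * words_per_topic + rng.randrange(words_per_topic)
        doc.append((word, topic))
    return doc

rng = random.Random(0)
doc = sample_document(n_topics=10, words_per_topic=5, doc_len=20, alpha=0.5, rng=rng)
```

A small `alpha` concentrates each document on a few topics, which is what makes same-topic words co-occur more often than different-topic words.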

Paper Structure

This paper contains 58 sections, 15 theorems, 92 equations, 10 figures, 5 tables.

Key Result

Theorem 1

Suppose the training data follows a topic model data distribution, and the transformer has trainable embedding layer, frozen (uniform) attention scores, and all other components set to identity. Then, the optimal embedding layer of a single layer transformer is such that the inner product of the emb
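
The abstract describes the embedding-layer effect as a higher average inner product between embeddings of same-topic words. The sketch below shows one way such a statistic could be computed; the function name `topic_inner_product_gap` and the toy embeddings are hypothetical and are not taken from the paper.

```python
def topic_inner_product_gap(embeddings, word_topic):
    """Return (mean same-topic, mean cross-topic) inner product over distinct word pairs."""
    same, cross = [], []
    n = len(embeddings)
    for i in range(n):
        for j in range(i + 1, n):
            dot = sum(a * b for a, b in zip(embeddings[i], embeddings[j]))
            if word_topic[i] == word_topic[j]:
                same.append(dot)
            else:
                cross.append(dot)
    return sum(same) / len(same), sum(cross) / len(cross)

# Toy example (an assumption for illustration): two topics with two words each.
# Each word points mostly along its topic's axis, with a small perturbation.
toy_embeddings = [
    [1.0, 0.1],
    [1.0, -0.1],
    [0.1, 1.0],
    [-0.1, 1.0],
]
toy_topics = [0, 0, 1, 1]
same_mean, cross_mean = topic_inner_product_gap(toy_embeddings, toy_topics)
```

On a trained model, the same function could be applied to the rows of the embedding matrix together with a word-to-topic assignment; a positive gap between the two means would correspond to the block-wise pattern described in the Figure 1 caption.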

Figures (10)

  • Figure 1: Embedding weight dot product of models trained on synthetic topic modeling data (Section \ref{sec:experiments:setup}). The four plots correspond to different combinations of loss function and optimizer: (left to right) cross-entropy with SGD, cross-entropy with Adam, squared loss with SGD, squared loss with Adam, all using learning rate 0.01. The block-wise pattern verifies our theory in Section \ref{sec:embedding}. The 10 blocks correspond to the 10 topics in the data distribution in Section \ref{sec:setup:topic_modeling}. In particular, a diagonal pattern is a special case of the block-wise optima that we prove (see Theorem \ref{thm:optimal_embedding}).
  • Figure 2: Convergence point of trained ${\bm{W}}^V$ (with $L_2$-regularization) when freezing uniform attention weights and one-hot word embedding. The four plots correspond to different combinations of loss function and optimizer: (left to right) cross-entropy with SGD, cross-entropy with Adam, squared loss with SGD, squared loss with Adam, all using learning rate 0.01. The block-wise pattern verifies our theory in Section \ref{sec:attention:value}. The 10 blocks correspond to the 10 topics in the data distribution. Results are qualitatively similar without $L_2$-regularization, or if we train ${\bm{W}}^K$ and ${\bm{W}}^Q$ instead of freezing them (see Appendix \ref{sec:appendix:experiments:Wv}).
  • Figure 3: For a BERT model pre-trained on Wikipedia corpus, the cosine similarity of the word embeddings encodes topical structures, i.e. it is larger if the two words belong to the same topic, and smaller if they belong to different topics. This phenomenon is more pronounced for words that are very likely only under a few topics. In this figure, the nine words fall into three topics: {frog, toad, lizard} are animals, {mozart, beethoven, schubert} are musicians, and {algebra, arithmetic, calculus} are mathematical concepts.
  • Figure 4: Two-stage learning dynamics of a single-layer transformer trained on LDA data distribution. All weight matrices are initialized to random matrices near zero, and simultaneously trained. The learning dynamics naturally exhibit a two-stage phenomenon: in Stage 1 (steps 0-400), the norms of the key matrix ($W^K$, top) and the query matrix ($W^Q$, middle) stay close to 0, while the norm of the value matrix ($W^V$, bottom) increases significantly. In Stage 2 (steps 400-1000), the norms of $W^K$ and $W^Q$ start increasing significantly, while the norm of $W^V$ stays relatively flat. Different curves in the figure correspond to different settings of the hyperparameters as well as different runs in each setting. (See Section \ref{sec:discussions} for more details.)
  • Figure 5: Two-stage learning dynamics of a 4-layer, 4-head-per-layer transformer trained on Wikipedia data. All weight matrices (key ${\bm{W}}^K$, query ${\bm{W}}^Q$, value ${\bm{W}}^V$ in each layer) are initialized to random matrices near zero, and simultaneously trained. Each column corresponds to one layer. The top 3 rows plot the trajectories of the Frobenius norms of ${\bm{W}}^K$, ${\bm{W}}^Q$, and ${\bm{W}}^V$ (weights from all heads in the same layer are concatenated together) after each gradient step. The bottom row measures the rotation of ${\bm{W}}^V$, i.e. the cosine distance between ${\bm{W}}^V$ in step $t$ and ${\bm{W}}^V$ in step $(t-10)$. Cosine distance is defined as $\frac{1-cs}{2} \in [0, 1]$, in which $cs$ is the classic cosine similarity. The initial 400 steps of the learning dynamics naturally exhibit an approximately two-stage phenomenon: in Stage 1 (roughly steps 0-100), for all 4 layers, the norms of ${\bm{W}}^K$ and ${\bm{W}}^Q$ stay close to 0, while the norm of ${\bm{W}}^V$ increases significantly and the orientation of ${\bm{W}}^V$ changes rapidly. In Stage 2 (roughly steps 100-400), the norms of ${\bm{W}}^K$'s and ${\bm{W}}^Q$'s start increasing significantly, much later than ${\bm{W}}^V$ matrices do. Different curves in the figure correspond to different settings of the hyperparameters as well as different runs in each setting.
  • ...and 5 more figures
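
The Figure 5 caption tracks two quantities per weight matrix: its Frobenius norm and a cosine distance defined as $\frac{1-cs}{2}$. A minimal sketch of both follows. Treating each matrix as a flattened vector when computing the cosine similarity $cs$ is an assumption here (the caption does not spell this out), and the function names are hypothetical.

```python
import math

def frobenius_norm(matrix):
    """Square root of the sum of squared entries of a matrix (list of rows)."""
    return math.sqrt(sum(x * x for row in matrix for x in row))

def cosine_distance(a, b):
    """Cosine distance (1 - cs) / 2 in [0, 1] between two same-shaped matrices.

    Assumption: cs is the cosine similarity of the flattened matrices.
    Both matrices must be nonzero.
    """
    dot = sum(x * y for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))
    cs = dot / (frobenius_norm(a) * frobenius_norm(b))
    return (1.0 - cs) / 2.0

# Toy matrices standing in for W^V at step t - 10 and step t.
w_prev = [[1.0, 0.0], [0.0, 1.0]]
w_curr = [[0.0, 1.0], [-1.0, 0.0]]
rotation = cosine_distance(w_prev, w_curr)
```

Logging `frobenius_norm` for the key, query, and value matrices after every gradient step, and `cosine_distance` between value-matrix snapshots 10 steps apart, would produce curves of the kind the caption describes.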

Theorems & Definitions (36)

  • Theorem: Optimal word embedding, informal
  • Theorem: Optimal ${\bm{W}}^V$, informal
  • Theorem: Optimal attention weights, informal
  • Definition 1: Topic-word indicator
  • Remark 1
  • Theorem 1: Optimal token embedding
  • Remark 2
  • Remark 3
  • Theorem 2: Optimal ${\bm{W}}^V$ with mild $L_2$-regularization when freezing uniform attention
  • Theorem 3: Optimal attention weights
  • ...and 26 more