How Transformers Learn Causal Structure with Gradient Descent

Eshaan Nichani, Alex Damian, Jason D. Lee

TL;DR

This work investigates how gradient-based training enables transformers to learn latent causal structure in sequential data. The authors analyze a simplified two-layer disentangled transformer on a newly proposed task, random sequences with causal structure, and show that the first attention layer converges to the adjacency matrix of the latent causal graph: the gradient with respect to the attention matrix encodes the mutual information between tokens, which guides edge recovery. When the sequences are generated from in-context Markov chains, the transformer learns an induction head as a special case, and a multi-head extension accommodates graphs in which tokens have multiple parents. The paper gives a training algorithm together with a main theorem guaranteeing near-optimal loss and a second theorem on out-of-distribution (OOD) generalization, and supports these with experiments showing that trained transformers recover a variety of causal structures. Overall, it offers a mechanistic, information-theoretic account of how gradient descent shapes causal representations in transformers on in-context learning tasks.

Abstract

The incredible success of transformers on sequence modeling tasks can be largely attributed to the self-attention mechanism, which allows information to be transferred between different parts of a sequence. Self-attention allows transformers to encode causal structure, which makes them particularly suitable for sequence modeling. However, the process by which transformers learn such causal structure via gradient-based training algorithms remains poorly understood. To better understand this process, we introduce an in-context learning task that requires learning latent causal structure. We prove that gradient descent on a simplified two-layer transformer learns to solve this task by encoding the latent causal graph in the first attention layer. The key insight of our proof is that the gradient of the attention matrix encodes the mutual information between tokens. As a consequence of the data processing inequality, the largest entries of this gradient correspond to edges in the latent causal graph. As a special case, when the sequences are generated from in-context Markov chains, we prove that transformers learn an induction head (Olsson et al., 2022). We confirm our theoretical findings by showing that transformers trained on our in-context learning task are able to recover a wide variety of causal structures.
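
The data processing inequality step in the abstract can be checked numerically on a toy example. The sketch below is not the paper's proof or code: it fixes one random transition matrix `pi` and an initial distribution `mu` (both illustrative assumptions, as are all function names), builds the chain $s_1 \to s_2 \to s_3$, and compares the mutual information between $s_3$ and its parent $s_2$ with that between $s_3$ and the non-parent $s_1$.

```python
import math
import random


def random_row(rng, size):
    """Random probability vector (normalized exponentials, i.e. a flat Dirichlet draw)."""
    weights = [rng.expovariate(1.0) for _ in range(size)]
    total = sum(weights)
    return [w / total for w in weights]


def mutual_information(joint):
    """Mutual information (in nats) of a joint distribution given as a nested list."""
    px = [sum(row) for row in joint]
    py = [sum(col) for col in zip(*joint)]
    mi = 0.0
    for x, row in enumerate(joint):
        for y, p in enumerate(row):
            if p > 0:
                mi += p * math.log(p / (px[x] * py[y]))
    return mi


rng = random.Random(0)
S = 5                                         # vocabulary size (illustrative)
pi = [random_row(rng, S) for _ in range(S)]   # pi[x] is the distribution pi(.|x)
mu = random_row(rng, S)                       # distribution of s_1 (assumption)

# Marginal of s_2 and the two-step transition matrix for the chain s_1 -> s_2 -> s_3.
p2 = [sum(mu[x] * pi[x][y] for x in range(S)) for y in range(S)]
pi_two_step = [[sum(pi[x][k] * pi[k][z] for k in range(S)) for z in range(S)]
               for x in range(S)]

joint_parent = [[p2[y] * pi[y][z] for z in range(S)] for y in range(S)]              # (s_2, s_3)
joint_nonparent = [[mu[x] * pi_two_step[x][z] for z in range(S)] for x in range(S)]  # (s_1, s_3)

mi_parent = mutual_information(joint_parent)
mi_nonparent = mutual_information(joint_nonparent)
print(f"I(s3; s2) = {mi_parent:.4f} nats, I(s3; s1) = {mi_nonparent:.4f} nats")
```

Because $s_1 \to s_2 \to s_3$ is a Markov chain, the first printed value is never smaller than the second, whatever seed is used. This is the ordering the abstract relies on when it says the largest gradient entries correspond to edges of the causal graph.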

Paper Structure

This paper contains 61 sections, 39 theorems, 312 equations, 9 figures, 2 algorithms.

Key Result

Lemma 1

Let $\theta = (A^{(1)}, A^{(2)})$, and let $\widetilde{\theta} = (\widetilde{A}^{(1)},\widetilde{A}^{(2)}, \widetilde{W}_O)$ be defined in eq:sparsity_pattern. Let $f_\theta = \widetilde{\mathrm{TF}}_{\widetilde{\theta}}$ be a two-layer disentangled transformer parameterized by $\theta$. Then if $\o

Figures (9)

  • Figure 1: Random Sequence with Causal Structure: The causal structure is defined by the graph $\mathcal{G}$, denoted by the arrows. In this figure, $p(1) = \emptyset$, $p(2) = \{1\}$, $p(3) = \{1\}$, $p(4) = \{2\}$ and $p(5) = \{3\}$. Sequences are generated by sampling $\pi \sim P_\pi$, $s_1 \sim \mu_\pi$, $s_2 \sim \pi(\cdot|s_1)$, $s_3 \sim \pi(\cdot|s_1)$, $s_4 \sim \pi(\cdot|s_2)$, $s_5 \sim \pi(\cdot|s_3)$, and finally $s_6 \sim \mathop{\mathrm{Unif}}\nolimits([S])$. The target $y$ for this sequence is drawn from $\pi(\cdot|s_6)$.
  • Figure 2: The Weights of a Trained Transformer: We plot the weights of a two-layer disentangled transformer trained on \ref{task:single_parent} with $S=10$ and $T=20$ when the causal graph is the in-context learning graph where $p(2i) = 2i-1$ for all $i > 0$. All entries of $A^{(1)}, A^{(2)}, W_O$ remain small except the three blocks highlighted in red. The highlighted block in $A^{(1)}$ converges to the adjacency matrix of the causal graph $\mathcal{G}$, and the highlighted blocks in $A^{(2)},W_O$ converge to the identity matrix $I_S$.
  • Figure 3: Understanding the Forward Pass: The solid arrows represent the causal graph $\mathcal{G}$ defined in \ref{fig:graph_example} and $h^{(0)}$ denotes the unmodified input sequence. The first attention layer reverses this causal pattern, as every token attends to its parent (solid arrows). It then appends this parent token to the residual stream (dashed arrows). In the second attention layer, each token $i$ attends to all previous tokens $j$ whose parent token $p(j)$ has the same value, i.e. $s_i = s_{p(j)}$ (solid arrows), and appends the average of these tokens into the residual stream (dashed arrows). Finally, the transformer returns the third entry in the last column (red box), which is the average of all of the tokens whose parent token has the same value as the last token.
  • Figure 4: By the data processing inequality, $A^{(1)}_{i, p(i)}$ grows faster than $A^{(1)}_{i, j}$ for $j \neq p(i)$.
  • Figure 5: Multiple Parents: We show three examples of trained transformers on \ref{task:multi_parent} with $k=2,2,3$ respectively. The left column shows the adjacency matrix of the causal graphs $\mathcal{G}$. To their right, we plot the attention patterns $\mathcal{S}(A^{(1)}_i)$ for each head $i$ where $A^{(1)}_i$ is the position-position component of $\widetilde{A}^{(1)}_i$. We see that each attention head learns a single set of parents in the causal graph $\mathcal{G}$, which agrees with \ref{thm:multi_parent_construction}. See \ref{fig:appendix_multi_parent} for plots of the full matrices $\widetilde{A}^{(1)}_i$.
  • ...and 4 more figures
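
The sampling procedure in the Figure 1 caption and the two-layer forward pass in the Figure 3 caption can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the prior $P_\pi$ is taken to be a flat Dirichlet over each row, $\mu_\pi$ is taken to be uniform, the uniform fallback when no position matches the query is an added convention, and all function names are hypothetical.

```python
import random

# Parent map of the Figure 1 graph (1-indexed positions; None means no parent).
PARENT = {1: None, 2: 1, 3: 1, 4: 2, 5: 3}


def random_row(rng, size):
    """Random probability vector (flat Dirichlet draw; the prior is an assumption)."""
    weights = [rng.expovariate(1.0) for _ in range(size)]
    total = sum(weights)
    return [w / total for w in weights]


def sample_sequence(rng, S, parent):
    """Sample (tokens, target) following the generative process of the Figure 1 caption."""
    pi = [random_row(rng, S) for _ in range(S)]   # pi ~ P_pi (assumed prior)
    mu = [1.0 / S] * S                            # mu_pi, assumed uniform here
    T = len(parent) + 1                           # the last position holds the query token
    s = {}
    for i in range(1, T):
        if parent[i] is None:
            s[i] = rng.choices(range(S), weights=mu)[0]
        else:
            s[i] = rng.choices(range(S), weights=pi[s[parent[i]]])[0]
    s[T] = rng.randrange(S)                       # s_T ~ Unif([S])
    y = rng.choices(range(S), weights=pi[s[T]])[0]  # target y ~ pi(.|s_T)
    return s, y


def idealized_prediction(s, parent, S):
    """Idealized forward pass from the Figure 3 caption.

    Layer 1: every position looks up its parent token.
    Layer 2: the last token averages the tokens whose parent token equals it.
    """
    T = len(s)
    query = s[T]
    counts = [0] * S
    for j in range(1, T):
        p = parent[j]
        if p is not None and s[p] == query:
            counts[s[j]] += 1
    matches = sum(counts)
    if matches == 0:
        return [1.0 / S] * S                      # fallback convention (assumption)
    return [c / matches for c in counts]


rng = random.Random(0)
tokens, target = sample_sequence(rng, S=3, parent=PARENT)
print(tokens, target, idealized_prediction(tokens, PARENT, S=3))

# Hand-built example: the query token is 0, and positions 2 and 3 both have
# parent position 1, whose token is 0, so the prediction averages tokens 1 and 2.
example = {1: 0, 2: 1, 3: 2, 4: 1, 5: 0, 6: 0}
example_prediction = idealized_prediction(example, PARENT, S=3)
```

In the hand-built example the prediction places weight 0.5 on each of tokens 1 and 2. With only six positions such matches are sparse, which is why the estimate of $\pi(\cdot|s_T)$ becomes meaningful only for longer sequences such as the $T=20$ setting of Figure 2.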

Theorems & Definitions (85)

  • Definition 1: Causal self-attention head
  • Definition 2: Decoder-based transformer
  • Definition 3: Disentangled Transformer
  • proof
  • Lemma 1
  • Definition 4: Effective Sequence Length
  • Theorem 1: Guarantee for \ref{alg:training_alg}
  • Theorem 2: OOD Generalization
  • Definition 5
  • Lemma 2: Data Processing Inequality
  • ...and 75 more