How Transformers Learn Causal Structure with Gradient Descent
Eshaan Nichani, Alex Damian, Jason D. Lee
TL;DR
This work investigates how gradient-based training enables transformers to learn latent causal structure in sequential data. By analyzing a simplified two-layer, disentangled transformer on a newly proposed task, random sequences with causal structure, it shows that the first attention layer converges to the adjacency structure of the latent graph, with the gradient encoding the mutual information between tokens and thereby guiding edge recovery. In cases where the causal graph is a tree, induction-head-like behavior emerges as a special case of learning the latent transitions, and multi-head extensions accommodate graphs in which nodes have multiple parents. The authors provide a training algorithm and a main theorem guaranteeing close-to-optimal loss and out-of-distribution (OOD) generalization, supported by experiments demonstrating that a variety of causal structures can be recovered. Overall, the paper offers a mechanistic, information-theoretic account of how gradient descent shapes causal representations in transformers for in-context learning tasks.
Abstract
The incredible success of transformers on sequence modeling tasks can be largely attributed to the self-attention mechanism, which allows information to be transferred between different parts of a sequence. Self-attention allows transformers to encode causal structure, which makes them particularly suitable for sequence modeling. However, the process by which transformers learn such causal structure via gradient-based training algorithms remains poorly understood. To better understand this process, we introduce an in-context learning task that requires learning latent causal structure. We prove that gradient descent on a simplified two-layer transformer learns to solve this task by encoding the latent causal graph in the first attention layer. The key insight of our proof is that the gradient of the attention matrix encodes the mutual information between tokens. As a consequence of the data processing inequality, the largest entries of this gradient correspond to edges in the latent causal graph. As a special case, when the sequences are generated from in-context Markov chains, we prove that transformers learn an induction head (Olsson et al., 2022). We confirm our theoretical findings by showing that transformers trained on our in-context learning task are able to recover a wide variety of causal structures.
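
The data processing inequality argument in the abstract can be illustrated with a small numerical sketch. It does not train a transformer and is not the authors' construction. It only samples tokens along a fixed parent graph and checks that, for each position, the earlier position with the largest empirical mutual information is its parent. The graph `parents`, the matrix `transition`, the uniform root distribution, the use of a single fixed transition matrix for all sequences, and the function names are all assumptions made for illustration.

```python
import numpy as np


def sample_sequences(parents, transition, n_samples, rng):
    """Sample sequences in which position i depends only on position parents[i].

    Assumes parents[i] < i for every non-root position, so each parent is
    sampled before its children. A parent of -1 marks a root position.
    """
    vocab = transition.shape[0]
    seqs = np.zeros((n_samples, len(parents)), dtype=np.int64)
    for i, p in enumerate(parents):
        if p < 0:
            # Root position: draw tokens uniformly (an illustrative choice).
            seqs[:, i] = rng.integers(vocab, size=n_samples)
        else:
            # Child position: inverse-CDF sampling from the transition row
            # selected by the parent's token.
            cdf = np.cumsum(transition[seqs[:, p]], axis=1)
            u = rng.random(n_samples)[:, None]
            seqs[:, i] = np.minimum((u > cdf).sum(axis=1), vocab - 1)
    return seqs


def mutual_information(x, y, vocab):
    """Plug-in estimate of the Shannon mutual information I(x; y) in nats."""
    joint = np.zeros((vocab, vocab))
    np.add.at(joint, (x, y), 1.0)
    joint /= joint.sum()
    px = joint.sum(axis=1, keepdims=True)  # shape (vocab, 1)
    py = joint.sum(axis=0, keepdims=True)  # shape (1, vocab)
    mask = joint > 0
    return float((joint[mask] * np.log(joint[mask] / (px @ py)[mask])).sum())


def recover_parents(seqs, vocab):
    """For each position i >= 1, pick the earlier position with the largest MI."""
    guesses = []
    for i in range(1, seqs.shape[1]):
        scores = [mutual_information(seqs[:, i], seqs[:, j], vocab) for j in range(i)]
        guesses.append(int(np.argmax(scores)))
    return guesses


# Hypothetical latent tree over 6 positions; -1 marks the root.
parents = [-1, 0, 1, 1, 0, 4]
# Hypothetical noisy-copy transition matrix over a 3-token vocabulary.
transition = np.array([[0.8, 0.1, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.1, 0.1, 0.8]])

rng = np.random.default_rng(0)
seqs = sample_sequences(parents, transition, 20000, rng)
recovered = recover_parents(seqs, vocab=3)
print("true parents:     ", parents[1:])
print("recovered parents:", recovered)
```

In this toy setup any non-parent earlier position is at least two transitions away from position i in the tree, so its mutual information with position i is strictly smaller than that of the true parent, and the argmax should coincide with `parents[1:]`. Shannon mutual information is used here for simplicity; the exact information measure and the per-sequence sampling of transitions in the paper's task may differ.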
