Understanding the Staged Dynamics of Transformers in Learning Latent Structure

Rohan Saha, Farzane Aminmansour, Alona Fyshe

TL;DR

This work investigates how transformers acquire latent structure in a controlled setting by training a small decoder-only transformer on the Alchemy benchmark. It analyzes latent-structure learning in three formulations (latent-structure discovery under partial support, composition, and decomposition) by factorizing accuracy into interpretable events and formalizing a multiplicative decomposition $\mathbb{P}[C] = \mathbb{P}[A]\,\mathbb{P}[B \mid A]\,\mathbb{P}[C \mid A \cap B]$ to track sub-skills, where $A$ is an in-support prediction, $B$ a prediction in the correct half, and $C$ an exact match. The key findings are staged, coarse-to-fine learning with plateaus and jumps; an adjacency/bias effect that can momentarily misdirect learning; and a fundamental asymmetry: composition remains robust as task complexity increases, whereas decomposition hits a bottleneck as complexity grows. These results give a granular, mechanistic view of how transformers learn latent structure, with implications for training strategies and benchmark design; the authors also release their code for broader reuse.
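
Assuming an exact match (C) can only occur when the prediction is in-support (A) and lies in the correct half (B), C is contained in A ∩ B, so $\mathbb{P}[C] = \mathbb{P}[A \cap B \cap C]$ and the chain rule yields the product above. The minimal sketch below assumes each validation sample has already been scored as three booleans; the function and its inputs are hypothetical illustrations, not the authors' released code. It checks that the three factors multiply back to overall accuracy.

```python
import numpy as np


def factorized_accuracy(in_support, correct_half, exact):
    """Split overall accuracy into P[A], P[B|A] and P[C|A∩B].

    in_support, correct_half, exact: one boolean per validation sample
    (hypothetical stand-ins for events A, B and C).
    """
    a = np.asarray(in_support, dtype=bool)
    b = np.asarray(correct_half, dtype=bool)
    c = np.asarray(exact, dtype=bool)
    # Assumption: an exact match implies the coarser events (C ⊆ A ∩ B);
    # this is what makes P[C] equal to the product of the three factors.
    assert np.all(~c | (a & b)), "exact match must imply A and B"
    p_a = a.mean()
    p_b_given_a = b[a].mean() if a.any() else 0.0
    ab = a & b
    p_c_given_ab = c[ab].mean() if ab.any() else 0.0
    return p_a, p_b_given_a, p_c_given_ab


# Toy outcomes for eight samples (illustrative only, not results from the paper).
a = [1, 1, 1, 1, 1, 1, 0, 0]
b = [1, 1, 1, 1, 0, 0, 0, 0]
c = [1, 1, 1, 0, 0, 0, 0, 0]
p_a, p_b, p_c = factorized_accuracy(a, b, c)
print(p_a, p_b, p_c, p_a * p_b * p_c, np.mean(c))
```

Tracking the factors separately, rather than only their product, shows which sub-skill is still improving during a plateau in overall accuracy.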

Abstract

While transformers can discover latent structure from context, the dynamics of how they acquire different components of that structure remain poorly understood. In this work, we use the Alchemy benchmark to investigate the dynamics of latent structure learning. We train a small decoder-only transformer on three task variants: 1) inferring missing rules from partial contextual information, 2) composing simple rules to solve multi-step sequences, and 3) decomposing complex multi-step examples to infer intermediate steps. By factorizing each task into interpretable events, we show that the model acquires capabilities in discrete stages, first learning the coarse-grained rules and only then the complete latent structure. We also identify a crucial asymmetry: the model composes fundamental rules robustly but struggles to decompose complex examples to discover those rules. These findings offer new insights into how a transformer learns latent structures, providing a granular view of how these capabilities evolve during training.
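
To make the three task variants concrete, here is a minimal sketch of an Alchemy-like chemistry under simplifying assumptions that are ours, not the paper's: every stone is a 3-bit latent state (the eight vertices of a cube), there are three potion pairs, and each pair moves a stone along one feature axis in opposite directions, so a potion with no edge at the current stone leaves it unchanged. Only the YELLOW/ORANGE pair is named in the paper; the other potion names and all helper functions are hypothetical.

```python
import itertools
import random

# Hypothetical potion pairs; only YELLOW/ORANGE is named in the paper.
POTION_PAIRS = [("YELLOW", "ORANGE"), ("RED", "GREEN"), ("BLUE", "PINK")]
STONES = list(itertools.product((0, 1), repeat=3))  # eight 3-bit latent states


def sample_chemistry(rng):
    """Map each potion to (axis, value): one potion of a pair sets the bit, its partner clears it."""
    axes = rng.sample(range(3), 3)
    effects = {}
    for (up, down), axis in zip(POTION_PAIRS, axes):
        effects[up] = (axis, 1)
        effects[down] = (axis, 0)
    return effects


def apply(stone, potion, effects):
    """Apply one potion; if the bit already has the target value, the stone is unchanged."""
    axis, value = effects[potion]
    new = list(stone)
    new[axis] = value
    return tuple(new)


def k_hop_transitions(effects, k, withheld=()):
    """All (start, potions, end) walks of length k in which every potion changes the stone."""
    potions = [p for p in effects if p not in withheld]
    out = []
    for start in STONES:
        for seq in itertools.product(potions, repeat=k):
            stone, valid = start, True
            for p in seq:
                nxt = apply(stone, p, effects)
                if nxt == stone:  # no edge for this potion here, so not a real hop
                    valid = False
                    break
                stone = nxt
            if valid:
                out.append((start, seq, stone))
    return out


rng = random.Random(0)
chem = sample_chemistry(rng)

# Partial support: 1-hop transitions with the YELLOW/ORANGE pair withheld;
# queries then probe only the withheld potions.
support = k_hop_transitions(chem, 1, withheld=("YELLOW", "ORANGE"))
queries = [t for t in k_hop_transitions(chem, 1) if t[1][0] in ("YELLOW", "ORANGE")]

# Composition: all 1-hop transitions as support, then a multi-hop (here 3-hop) query.
composition_support = k_hop_transitions(chem, 1)
composition_query = rng.choice(k_hop_transitions(chem, 3))

# Decomposition: all 3-hop transitions as support, then a 1-hop query.
decomposition_support = k_hop_transitions(chem, 3)
decomposition_query = rng.choice(composition_support)

print(len(support), len(queries), len(composition_support), len(decomposition_support))
```

In this idealized chemistry every stone has exactly one changing potion per axis, so a full $k$-hop support contains $8 \cdot 3^k$ walks; the support a decomposition model must parse therefore grows geometrically with hop length, while a composition support stays at 24 transitions.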

Paper Structure

This paper contains 23 sections, 13 equations, 11 figures, 1 table.

Figures (11)

  • Figure 1: Overview of the Alchemy chemistry structure and experimental tasks. Middle: Example chemistry whose vertices represent stone states, connected by bidirectional edges (potions). A chemistry consists of eight stones, and applying a potion changes a stone's features. Left: Experiment investigating the staged dynamics of latent structure learning: given a chemistry, all samples for a randomly selected potion pair (e.g., YELLOW/ORANGE) are withheld, and the model must correctly predict the effect of the withheld potions when applied to a query stone. Right: Tasks examining complexity. For composition (top right), each sample concatenates all 1-hop (single-step) support transitions for the chemistry, followed by a multi-hop query. For decomposition (bottom right), the model receives the concatenation of all multi-hop transitions of a given hop length as support, followed by a 1-hop query. In all experiments, the full enumeration of stone features is provided in both support and query. Figure adapted from the original Alchemy benchmark (wang2021alchemy).
  • Figure 2: (a) Validation accuracy for latent structure learning (withheld potion pair with $hl_{support} = hl_{query} = 1$). The model reaches perfect performance, demonstrating that it can learn and generalize the latent structure, but it gets there through distinct plateaus and jumps, indicating that the structure is acquired in stages. (b) Validation accuracy after factorizing the task into events. Blue: $\mathbb{P}[A]$ (in-support). Orange: $\mathbb{P}[B \mid A]$ (correct half given in-support). Red: complementary incorrect-half accuracy. Green: $\mathbb{P}[C \mid A \cap B]$ (exact match given correct half). The product of the blue, orange, and green curves recreates the overall accuracy in panel (a). Only the first 500 epochs are shown because the model had converged by then. Error bars denote the standard error of the mean.
  • Figure 3: Model performance when composing 1-hop transitions to solve multi-hop queries. Performance shows no noticeable difference as $hl_{query} \in \{2,3,4,5\}$ increases. The x-axis denotes epochs, and the y-axis denotes validation accuracy. Performance is averaged over three seeds, and error bars denote the standard error of the mean. Only the first 500 epochs are shown because all hop lengths reached near-perfect accuracy.
  • Figure 4: Staged learning dynamics for the composition task (first 500 epochs) with $hl_{query} \in \{2,3,4,5\}$. (a) Staged dynamics for 2-hop composition. (b) Staged dynamics for 3-hop composition. (c) Staged dynamics for 4-hop composition. (d) Staged dynamics for 5-hop composition. In each subfigure, orange plots $\mathbb{P}[A]$ (in-support), purple plots $\mathbb{P}[R \mid A]$ (prediction within the reachable stones given in-support), and the colored curve for each $k = hl_{query}$ plots $\mathbb{P}[C \mid A \cap R]$ (exact match given reachable stones); their product recreates the overall accuracy in Figure 3. The x-axis denotes epochs and the y-axis denotes validation accuracy. Error bars show the standard error of the mean.
  • Figure 5: Decomposition results for $hl_{support} \in \{2,3,4,5\}$. Convergence is delayed as task complexity ($hl_{support}$) increases: the 2-hop setting converges earliest and the 5-hop setting latest. The x-axis denotes epochs, and the y-axis denotes validation accuracy. Performance is averaged over three seeds, and error bars show the standard error of the mean.
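
Figure 4 factorizes composition accuracy as $\mathbb{P}[A]\,\mathbb{P}[R \mid A]\,\mathbb{P}[C \mid A \cap R]$. The captions do not define the reachable set precisely, so the sketch below adopts one reading as an explicit assumption: on an idealized chemistry whose eight stones form a full 3-bit cube, where every hop flips one latent feature, R holds when the predicted stone can be reached from the query stone by a walk of exactly $hl_{query}$ hops. A is simplified to "the prediction is one of the eight valid stones". Under this reading the true target is always reachable, so $C \subseteq R$ and the product recovers overall accuracy. All function names are hypothetical.

```python
import itertools

STONES = list(itertools.product((0, 1), repeat=3))  # idealized full-cube chemistry


def neighbours(stone):
    """Stones one hop away: flip exactly one latent feature (assumes every edge exists)."""
    return [stone[:i] + (1 - stone[i],) + stone[i + 1:] for i in range(len(stone))]


def reachable_in_k_hops(start, k):
    """Set of stones reachable from `start` by walks of exactly k hops."""
    frontier = {start}
    for _ in range(k):
        frontier = {n for s in frontier for n in neighbours(s)}
    return frontier


def composition_events(start, k, predicted, target, stone_set=frozenset(STONES)):
    """Per-query booleans for A (valid stone), R (reachable, and A) and C (exact match)."""
    a = predicted in stone_set
    r = a and predicted in reachable_in_k_hops(start, k)
    c = predicted == target
    return a, r, c


for k in range(1, 6):
    print(k, sorted(reachable_in_k_hops((0, 0, 0), k)))
```

On the full cube, for $k \ge 2$ the exactly-$k$-hop reachable set always contains four stones (those whose Hamming distance from the start has the same parity as $k$), so under this reading R rules out half of the candidate stones; a jump in $\mathbb{P}[R \mid A]$ would then correspond to the model learning this coarse constraint before the exact target.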