
Towards an Understanding of Stepwise Inference in Transformers: A Synthetic Graph Navigation Model

Mikail Khona, Maya Okawa, Jan Hula, Rahul Ramesh, Kento Nishi, Robert Dick, Ekdeep Singh Lubana, Hidenori Tanaka

TL;DR

This work reframes stepwise inference in transformers as navigation on directed acyclic graphs (DAGs), using Bernoulli and hierarchical graphs to study how intermediate planning affects problem solving. It shows that stepwise inference outperforms direct inference, that this stepwise inference gap arises when training paths are shorter than the paths needed at test time, and that sampling temperature controls a diversity-accuracy trade-off. A mechanistic analysis reveals a distance-based planning strategy and a simple one-layer attention surrogate that captures the core behavior, while in-context exemplars enable controllability and compositional generalization but also make the model sensitive to context (e.g., a primacy bias). The synthetic framework provides mechanistic hypotheses and a platform for testing whether these insights generalize to larger models and more complex tasks, with potential implications for the interpretability and reliability of stepwise reasoning in AI systems.
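
The distance-based planning strategy mentioned above can be illustrated with a toy, hand-coded planner. This is a minimal sketch under stated assumptions, not the mechanism the trained Transformer implements: it assumes "distance" means the shortest-path hop count from a node to the goal, computed by breadth-first search over reversed edges, and the names distances_to_goal and greedy_plan are hypothetical.

```python
from collections import deque


def distances_to_goal(adj, goal):
    """Hop count of a shortest v -> goal path for every v that can reach goal.

    Runs BFS from the goal over reversed edges; nodes that cannot reach the
    goal are absent from the returned dict.
    """
    reverse = {v: [] for v in adj}
    for u, children in adj.items():
        for v in children:
            reverse[v].append(u)
    dist = {goal: 0}
    queue = deque([goal])
    while queue:
        v = queue.popleft()
        for u in reverse[v]:
            if u not in dist:
                dist[u] = dist[v] + 1
                queue.append(u)
    return dist


def greedy_plan(adj, start, goal):
    """Hypothetical distance-based planner: repeatedly step to the child that is
    closest to the goal. Returns None when the goal is unreachable from start."""
    dist = distances_to_goal(adj, goal)
    if start not in dist:
        return None
    path = [start]
    node = start
    while node != goal:
        # Some child is exactly one hop closer to the goal, so this min is never empty.
        node = min((c for c in adj[node] if c in dist), key=dist.__getitem__)
        path.append(node)
    return path


if __name__ == "__main__":
    # Small hand-made DAG given as adjacency lists of child nodes.
    adj = {0: [1, 2], 1: [3], 2: [4], 3: [4], 4: []}
    print(greedy_plan(adj, 0, 4))  # -> [0, 2, 4]
```

Because this toy rule always moves to a child strictly closer to the goal, it returns a shortest path; that is a property of the sketch, not a measured behavior of the trained model.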

Abstract

Stepwise inference protocols, such as scratchpads and chain-of-thought, help language models solve complex problems by decomposing them into a sequence of simpler subproblems. Despite the significant gain in performance achieved via these protocols, the underlying mechanisms of stepwise inference have remained elusive. To address this, we propose to study autoregressive Transformer models on a synthetic task that embodies the multi-step nature of problems where stepwise inference is generally most useful. Specifically, we define a graph navigation problem wherein a model is tasked with traversing a path from a start to a goal node on the graph. Despite its simplicity, we find we can empirically reproduce and analyze several phenomena observed at scale: (i) the stepwise inference reasoning gap, the cause of which we find in the structure of the training data; (ii) a diversity-accuracy tradeoff in model generations as sampling temperature varies; (iii) a simplicity bias in the model's output; and (iv) compositional generalization and a primacy bias with in-context exemplars. Overall, our work introduces a grounded, synthetic framework for studying stepwise inference and offers mechanistic hypotheses that can lay the foundation for a deeper understanding of this phenomenon.
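
A minimal sketch of how training sequences for the graph navigation task could be generated, following steps 1 to 4 of the Figure 2(a) description. It assumes a Bernoulli DAG is sampled by keeping each forward edge i -> j (i < j) independently with probability p, and that one path is drawn uniformly from all start-to-goal paths. The token format (s<start>, g<goal>, ":", then node ids) and all function names are hypothetical, not the paper's exact encoding.

```python
import random


def bernoulli_dag(n_nodes, p, seed=0):
    """Assumed Bernoulli DAG: each forward edge i -> j (i < j) is kept with
    probability p. Pointing edges from lower to higher index keeps it acyclic."""
    rng = random.Random(seed)
    adj = {i: [] for i in range(n_nodes)}
    for i in range(n_nodes):
        for j in range(i + 1, n_nodes):
            if rng.random() < p:
                adj[i].append(j)
    return adj


def all_paths(adj, start, goal):
    """Enumerate every directed start -> goal path by depth-first search.
    The graph is acyclic, so the search always terminates."""
    paths = []
    stack = [(start, [start])]
    while stack:
        node, path = stack.pop()
        if node == goal:
            paths.append(path)
            continue
        for child in adj[node]:
            stack.append((child, path + [child]))
    return paths


def stepwise_example(adj, start, goal, rng):
    """Hypothetical token sequence: a start/goal prompt, then one randomly chosen
    path written out node by node. Returns None if the goal is unreachable."""
    paths = all_paths(adj, start, goal)
    if not paths:
        return None
    path = rng.choice(paths)
    return [f"s{start}", f"g{goal}", ":"] + [str(v) for v in path]


if __name__ == "__main__":
    adj = bernoulli_dag(n_nodes=10, p=0.3, seed=1)
    print(stepwise_example(adj, 0, 9, random.Random(0)))
```

Enumerating every path grows exponentially with graph size, so this brute-force version is only practical for small graphs; it is meant to make the data format concrete, not to scale.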

Paper Structure (35 sections, 5 equations, 17 figures, 1 table, 3 algorithms)

Figures (17)

  • Figure 1: Examples of stepwise inference protocols and how they can be cast as a graph navigation problem. (a) Zero-shot chain-of-thought (Kojima et al., 2022) involves asking a model to produce intermediate outputs to perform complex multi-step computations, such as solving the Tower of Hanoi problem. Casting the configurations of the rods in Tower of Hanoi as nodes of a graph, we can see that the problem is essentially a traversal over states describing different configurations of the setup to reach the desired configuration (the goal state). (b) Scratchpad (Nye et al., 2021) improves LLMs' ability to perform complex multi-step computations, such as arithmetic, by having them write intermediate computation steps to a buffer called a scratchpad.
  • Figure 2: Data generating process. (a) In the absence of exemplars. This panel illustrates the step-by-step process of generating a training dataset from a single underlying graph. 1) A directed acyclic graph (DAG) is generated, which can be either hierarchically structured or Bernoulli. 2) A start node and a goal node are selected. 3) All possible paths connecting the start and goal nodes are enumerated, and one path is selected at random. 4) The chosen path is then represented in a task-specific format. (b) In the presence of exemplars. Generating a training dataset by combining multiple subgraphs (motifs) involves the following steps. (1.) Build a set of Bernoulli DAGs. (2.) Pick $K$ of these DAGs $\{g_{i_1}, g_{i_2}, \dots, g_{i_K}\}$ and connect them using "ghost edges" to create a chain of motifs $g_{i_1} \mapsto g_{i_2} \mapsto \dots \mapsto g_{i_K}$. (3.) Sample exemplars from every pair of motifs connected by a ghost edge to construct the context. (4.) Choose a start node $X_s \in g_{i_1}$ and a goal node $X_g \in g_{i_K}$ and construct a sequence that passes through the whole chain of motifs.
  • Figure 3: Advantage of stepwise inference in graph navigation tasks and stitching: (a) In the Bernoulli DAG, stepwise inference demonstrates an advantage over direct inference in predicting whether given node pairs are connected. (b) This advantage is even more pronounced in hierarchical DAGs, where the distances between nodes are greater than in Bernoulli DAGs. (c) The stepwise inference gap arises when the training set contains paths that are shorter than the paths required to connect nodes at test time. (d) Stepwise inference is also beneficial when the model must stitch together paths seen during training: the red, green, and blue paths represent subsets of paths seen during training, and at test time the model produces paths that combine these subsets.
  • Figure 4: Diversity vs. accuracy trade-off for different sampling temperatures of the Transformer model: As the sampling temperature increases, the diversity of paths generated by the model also increases, while the accuracy decreases. This tradeoff is captured by measuring the number of unique valid paths (top panel), indicating that there is an optimal temperature for sampling. The dashed line represents the ground truth path diversity.
  • Figure 5: Model outputs are biased toward shorter paths. We compared the average lengths of ground-truth paths for a specific set of node pairs and the paths produced by the model for these same pairs in the Bernoulli DAG. We observe that the model tends to generate shorter paths than the actual ones. This observation points to a "simplicity bias" in the trained model towards favoring shorter over potentially more accurate or realistic paths.
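
Figure 4 varies the sampling temperature T, which rescales next-token logits z before the softmax: p_i = exp(z_i / T) / sum_j exp(z_j / T). The sketch below shows how such a sweep could be measured. It assumes a hypothetical stand-in for a trained model (toy_logits, which scores true children above non-edges and prefers the first listed child) and counts, per temperature, the fraction of sampled walks that are valid start-to-goal paths and the number of unique valid paths. The graph, scores, and function names are illustrative assumptions, and the script reproduces no numbers from the paper.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """p_i = exp(z_i / T) / sum_j exp(z_j / T), shifted by the max for stability."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def toy_logits(adj, n_nodes):
    """Hypothetical stand-in for a trained model: the first listed child scores 3.0,
    other true children 2.0, and every non-edge 0.0."""
    def logits(node):
        scores = [0.0] * n_nodes
        for rank, child in enumerate(adj[node]):
            scores[child] = 3.0 if rank == 0 else 2.0
        return scores
    return logits


def sample_walk(next_logits, start, goal, temperature, rng, max_len=10):
    """Sample one node at a time until the goal appears or max_len nodes are emitted."""
    path = [start]
    while path[-1] != goal and len(path) < max_len:
        probs = softmax_with_temperature(next_logits(path[-1]), temperature)
        path.append(rng.choices(range(len(probs)), weights=probs)[0])
    return path


def is_valid(adj, path, goal):
    """A walk counts as accurate if it ends at the goal using only true edges."""
    return path[-1] == goal and all(b in adj[a] for a, b in zip(path, path[1:]))


if __name__ == "__main__":
    adj = {0: [1, 2], 1: [3], 2: [3], 3: []}
    model = toy_logits(adj, n_nodes=4)
    rng = random.Random(0)
    for t in (0.1, 0.5, 1.0, 2.0):
        walks = [sample_walk(model, 0, 3, t, rng) for _ in range(500)]
        valid = {tuple(w) for w in walks if is_valid(adj, w, 3)}
        accuracy = sum(is_valid(adj, w, 3) for w in walks) / len(walks)
        print(f"T={t}: accuracy={accuracy:.2f}, unique valid paths={len(valid)}")
```

With these hypothetical scores, a low T puts nearly all probability on the highest-scoring path 0 -> 1 -> 3, while a higher T also samples 0 -> 2 -> 3 and, increasingly, non-edge tokens that make walks invalid, qualitatively mirroring the trade-off described in Figure 4.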