Transformers Struggle to Learn to Search

Abulhair Saparov, Srushti Pawar, Shreyas Pimpalgaonkar, Nitish Joshi, Richard Yuanzhe Pang, Vishakh Padmakumar, Seyed Mehran Kazemi, Najoung Kim, He He

TL;DR

This paper shows that transformers can learn to perform graph search when trained on a carefully constructed, high-coverage distribution over DAGs, using graph connectivity as a testbed. It introduces a mechanistic interpretability approach to extract the underlying algorithm, revealing a parallel path-merging mechanism in which each vertex's embedding encodes a reachable set that expands across layers. However, the authors demonstrate that the model has increasing difficulty learning the task as graphs grow larger, and that neither scaling model size nor performing search in-context (chain-of-thought) resolves this limitation. The work suggests that future progress may require alternative training strategies or architectures, and it provides a framework for revealing the internal algorithms that transformers learn for reasoning tasks.

Abstract

Search is an ability foundational in many important tasks, and recent studies have shown that large language models (LLMs) struggle to perform search robustly. It is unknown whether this inability is due to a lack of data, insufficient model parameters, or fundamental limitations of the transformer architecture. In this work, we use the foundational graph connectivity problem as a testbed to generate effectively limitless high-coverage data to train small transformers and test whether they can learn to perform search. We find that, when given the right training distribution, the transformer is able to learn to search. We analyze the algorithm that the transformer has learned through a novel mechanistic interpretability technique that enables us to extract the computation graph from the trained model. We find that transformers perform search at every vertex in parallel: For each vertex in the input graph, transformers compute the set of vertices reachable from that vertex. Each layer then progressively expands these sets, allowing the model to search over a number of vertices exponential in $n_{\text{layers}}$. However, we find that as the input graph size increases, the transformer has greater difficulty in learning the task. This difficulty is not resolved even as the number of parameters is increased, suggesting that increasing model scale will not lead to robust search abilities. We also find that performing search in-context (i.e., chain-of-thought) does not resolve this inability to learn to search on larger graphs.
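
A rough way to see where the exponential dependence on $n_{\text{layers}}$ comes from is sketched below. It assumes an idealized model in which every layer merges the largest available sets, which the abstract does not state. The notation $R_\ell(v)$, for the set of vertices that the embedding of vertex $v$ records as reachable after $\ell$ layers, and $E$, for the edge set, is introduced here for illustration only.

$$R_0(v) = \{v\} \cup \{u : (v, u) \in E\}, \qquad R_{\ell+1}(v) = \bigcup_{u \in R_\ell(v)} R_\ell(u)$$

By induction, $R_\ell(v)$ contains every vertex within $2^\ell$ edges of $v$: a path of at most $2^{\ell+1}$ edges can be split at a midpoint $u$ that lies within $2^\ell$ edges of $v$, and the endpoint then lies within $2^\ell$ edges of $u$. Under these assumptions, $n_{\text{layers}}$ layers would suffice for paths of length up to $2^{n_{\text{layers}}}$.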

Paper Structure

This paper contains 42 sections, 7 equations, 19 figures, 1 algorithm.

Figures (19)

  • Figure 1: (top left) An example search problem on a directed acyclic graph and (top right) the corresponding transformer input and output label. (bottom) An equivalent proof search problem in implicational propositional logic, rendered in natural language.
  • Figure 2: Accuracy of a model with a maximum input graph size of $41$ vertices, trained on 883M examples from the naïve distribution vs. the star distribution vs. the balanced distribution with lookaheads $L \le 20$ (the maximum for this input size) and, in the last row, the balanced distribution with $L \le 12$. All evaluation is on held-out examples.
  • Figure 3: Overview of method to reconstruct the computation graph from a transformer for a specific input.
  • Figure 4: Visualization of the exponential path-merging algorithm, showcasing the layer-by-layer computation of the reachability of vertex 9 from vertex 1. We hypothesize that transformers learn this algorithm to search. In this algorithm, each token corresponding to a vertex stores information about which other vertices are reachable from this vertex (or from which vertices this vertex is reachable). For example, in layer 3, the model knows that vertex 3 is reachable from 1 and that 5 is reachable from 3, and it computes that 5 is reachable from 1, as shown in the input to layer 4. We posit that the model performs this computation for all vertices simultaneously.
  • Figure 5: (top) The proportion of examples for which the path-merging algorithm was identified in the computation graph, as reconstructed using our mechanistic interpretability analysis. Each cell contains a random held-out sample of 100 examples. We perform our analysis on the same models as in Section \ref{sec:sensitivity_results} (and Figure \ref{fig:sensitivity_results}). A randomly-initialized (untrained) model is shown in the last row as the control. (bottom) The proportion of path-merge operations that are "maximal," averaged over 100 random examples. We say a path-merge operation is maximal if it merges the largest available reachable sets. This is in contrast with a suboptimal path-merge operation, where one or both reachable sets are not the largest available at that layer.
  • ...and 14 more figures
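
The path-merging idea described in the Figure 4 caption can be sketched as a small simulation. This is a minimal illustration under assumptions, not the authors' code and not a transformer: each "layer" is a plain set union, every merge is assumed to be maximal, and the function name `path_merge_layers` and the 9-vertex chain graph are hypothetical choices made here.

```python
from typing import Dict, Iterable, List, Set, Tuple


def path_merge_layers(
    num_vertices: int,
    edges: Iterable[Tuple[int, int]],
    num_layers: int,
) -> List[Dict[int, Set[int]]]:
    """Return the reachable sets of every vertex after each idealized layer.

    Vertices are labeled 1..num_vertices. Entry 0 of the result is the state
    before any merging; entry k is the state after k merge steps.
    """
    # Before any layer: each vertex knows itself and its direct successors.
    reach: Dict[int, Set[int]] = {v: {v} for v in range(1, num_vertices + 1)}
    for src, dst in edges:
        reach[src].add(dst)

    history = [reach]
    for _ in range(num_layers):
        # All vertices update simultaneously: vertex v takes the union of the
        # sets stored at every vertex it already knows to be reachable.
        merged = {v: set().union(*(reach[u] for u in reach[v])) for v in reach}
        history.append(merged)
        reach = merged
    return history


# Hypothetical example: a chain 1 -> 2 -> ... -> 9.
chain_edges = [(i, i + 1) for i in range(1, 9)]
history = path_merge_layers(9, chain_edges, num_layers=3)
for layer, state in enumerate(history):
    print(f"after {layer} merge steps, vertex 1 reaches {sorted(state[1])}")
```

On this chain, the distance covered from vertex 1 doubles with each merge step (1, 2, 4, then 8 edges), so vertex 9 first appears in vertex 1's set after the third step. A sequential search that extends the frontier by one edge per step would need eight steps for the same path.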