Understanding Transformer Reasoning Capabilities via Graph Algorithms

Clayton Sanford, Bahare Fatemi, Ethan Hall, Anton Tsitsulin, Mehran Kazemi, Jonathan Halcrow, Bryan Perozzi, Vahab Mirrokni

TL;DR

This work analyzes transformer reasoning on graph problems under realistic parameter regimes. It introduces a representational hierarchy that links task hardness to depth, width, and pause tokens, and it connects transformers to the massively parallel computation (MPC) model. It proves that logarithmic-depth transformers can efficiently solve parallelizable tasks such as graph connectivity, while retrieval tasks admit highly compact single-layer solutions and shortest-path-like problems demand broader scaling. Empirically, experiments on the GraphQA benchmark show that transformers can outperform GNNs on global graph reasoning and that trained transformers beat prompt-based LLMs, highlighting the practical potential of parameter-efficient algorithmic reasoning. The study clarifies when transformers excel at global versus local graph reasoning and sets the stage for further scaling analyses and extensions beyond graph tasks.

Abstract

Which transformer scaling regimes are able to perfectly solve different classes of algorithmic problems? While tremendous empirical advances have been attained by transformer-based neural networks, a theoretical understanding of their algorithmic reasoning capabilities in realistic parameter regimes is lacking. We investigate this question in terms of the network's depth, width, and number of extra tokens for algorithm execution. Our novel representational hierarchy separates 9 algorithmic reasoning problems into classes solvable by transformers in different realistic parameter scaling regimes. We prove that logarithmic depth is necessary and sufficient for tasks like graph connectivity, while single-layer transformers with small embedding dimensions can solve contextual retrieval tasks. We also support our theoretical analysis with ample empirical evidence using the GraphQA benchmark. These results show that transformers excel at many graph reasoning tasks, even outperforming specialized graph neural networks.
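The claim that logarithmic depth suffices for connectivity echoes a classical parallel-algorithms idea. The sketch below is not the paper's transformer construction; it is a plain-Python illustration, assuming an undirected graph on vertices 0..n-1 given as an edge list, of why about $\lceil \log_2 n \rceil$ sequential rounds are enough: each round composes the current reachability relation with itself (boolean matrix squaring), doubling the path length it covers. The function name `reachability_by_squaring` is hypothetical.

```python
import math


def reachability_by_squaring(n, edges):
    """Return (reach, rounds) for an undirected graph on vertices 0..n-1.

    reach[u][v] is True iff u and v lie in the same connected component;
    rounds is the number of squaring steps performed.
    """
    # Start with paths of length <= 1: every vertex reaches itself and its neighbors.
    reach = [[u == v for v in range(n)] for u in range(n)]
    for u, v in edges:
        reach[u][v] = reach[v][u] = True

    # After k squarings, reach covers all paths of length <= 2**k. A simple path
    # has at most n - 1 edges, so ceil(log2(n)) squarings always suffice.
    target = math.ceil(math.log2(n)) if n > 1 else 0
    rounds = 0
    while rounds < target:
        # One "parallel round": every pair (u, v) is updated independently.
        reach = [[any(reach[u][w] and reach[w][v] for w in range(n))
                  for v in range(n)] for u in range(n)]
        rounds += 1
    return reach, rounds


# A path 0-1-2-3-4-5-6 plus an isolated vertex 7.
reach, rounds = reachability_by_squaring(8, [(i, i + 1) for i in range(6)])
print(rounds, reach[0][6], reach[0][7])  # 3 True False
```

Each round is fully parallel across vertex pairs, so it is the number of rounds, not the total work, that grows only logarithmically; this is the intuition for why depth can stand in for rounds of parallel computation.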

Paper Structure

This paper contains 37 sections, 34 theorems, 77 equations, 8 figures, and 2 tables.

Key Result

Theorem 1

For constants $\delta, \epsilon > 0$, any $R$-round MPC protocol with $N$ machines, each with $O(N^\delta)$ bits of local memory, can be simulated by a transformer of depth $L = O(R)$ and embedding dimension $m = O(N^{\delta + \epsilon})$.
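To get a feel for the sizes in Theorem 1, the snippet below plugs hypothetical values into its bounds with every constant hidden by $O(\cdot)$ set to 1, which is an assumption made purely for illustration: $N = 2^{20}$ machines, $\delta = 0.5$, $\epsilon = 0.1$, and $R = 20$ rounds. The helper `theorem1_sizes` is hypothetical.

```python
def theorem1_sizes(num_machines: int, rounds: int, delta: float, eps: float):
    """Plug values into Theorem 1's bounds, treating every O(.) constant as 1 (assumption)."""
    local_memory_bits = num_machines ** delta        # O(N^delta) bits per machine
    depth = rounds                                   # L = O(R)
    embedding_dim = num_machines ** (delta + eps)    # m = O(N^(delta + eps))
    # Round away floating-point noise from the fractional powers.
    return round(local_memory_bits), depth, round(embedding_dim)


# Hypothetical setting: N = 2**20, delta = 0.5, eps = 0.1, R = 20.
print(theorem1_sizes(2**20, 20, 0.5, 0.1))  # (1024, 20, 4096)
```

In this setting the embedding dimension (4096) exceeds a single machine's local memory (1024 bits) only by the factor $N^{\epsilon} = 4$, while depth tracks the round count directly.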

Figures (8)

  • Figure 1: The graph encoding scheme employed in our theoretical and empirical analysis that presents a graph reasoning task (e.g., connectivity) as a tokenized input to a standard transformer model.
  • Figure 2: A summary of the theoretical hierarchy from the paper's hierarchy section, visualizing which types of graph reasoning tasks can be solved in which transformer scaling regime ($\mathsf{Depth1}$ ($\mathsf{D1}$), $\mathsf{LogDepth}$ ($\mathsf{LD}$), $\mathsf{LogDepthWide}$ ($\mathsf{LDW}$), and $\mathsf{LogDepthPause}$ ($\mathsf{LDP}$)).
  • Figure 3: Accuracy of a variety of trained transformers and GNNs on the connectivity task.
  • Figure 4: The neighborhood routing structure.
  • Figure 5: The constant-diameter graph construction for $r=3$, $A=(1,0,1)$, and $B=(1,1,0)$. The source $1$ and sink $11$ are connected, which is equivalent to $\textsc{Disj}(A,B)=1$ by construction.
  • ...and 3 more figures
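Figure 1's encoding turns a graph task into a token sequence for a standard transformer. As a rough illustration only, the sketch below assumes one token per vertex, one per edge, an optional run of blank pause tokens (the extra tokens of the $\mathsf{LogDepthPause}$ regime), and a final query token; the token strings, their ordering, and the function `encode_connectivity_task` are hypothetical and not the paper's vocabulary.

```python
from typing import List, Tuple


def encode_connectivity_task(
    num_vertices: int,
    edges: List[Tuple[int, int]],
    query: Tuple[int, int],
    num_pause: int = 0,
) -> List[str]:
    """Hypothetical tokenization of a connectivity question about an undirected graph."""
    tokens = [f"<v{i}>" for i in range(num_vertices)]   # one token per vertex
    tokens += [f"<e{u},{v}>" for u, v in edges]          # one token per edge
    tokens += ["<pause>"] * num_pause                    # blank scratch positions
    u, v = query
    tokens.append(f"<connected?{u},{v}>")                # the task question
    return tokens


print(encode_connectivity_task(3, [(0, 1), (1, 2)], query=(0, 2), num_pause=2))
```

Under these assumptions the sequence length is $n + |E| + p + 1$ for $n$ vertices, $|E|$ edges, and $p$ pause tokens, so the input grows linearly with the size of the graph.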

Theorems & Definitions (64)

  • Theorem 1: Simplified version of the paper's formal MPC simulation theorem
  • Theorem 2
  • Theorem 3
  • Theorem 4
  • Theorem 5
  • Theorem 6
  • Theorem 7
  • Definition 1
  • Theorem 8: Formal version of the informal MPC simulation theorem; transformers simulate MPC
  • Definition 2
  • ...and 54 more