Transformers meet Neural Algorithmic Reasoners

Wilfried Bounsi, Borja Ibarz, Andrew Dudzik, Jessica B. Hamrick, Larisa Markeeva, Alex Vitvitskyi, Razvan Pascanu, Petar Veličković

TL;DR

Transformers struggle with precise, algorithmic reasoning, especially out-of-distribution. The authors propose TransNAR, a hybrid model that fuses a decoder-only Transformer with a frozen, pre-trained graph neural network-based NAR via cross-attention to NAR embeddings, enabling robust algorithmic computation on CLRS-Text. Empirically, TransNAR yields significant out-of-distribution gains and better shape/generalisation properties over Transformer baselines, while revealing limitations and avenues for distillation to unimodal models. This work demonstrates a viable path to combining strong language understanding with robust algorithmic reasoning by integrating specialized symbolic-like modules into LLMs, with broad implications for reasoning-heavy tasks.

Abstract

Transformers have revolutionized machine learning with their simple yet effective architecture. Pre-training Transformers on massive text datasets from the Internet has led to unmatched generalization for natural language understanding (NLU) tasks. However, such language models remain fragile when tasked with algorithmic forms of reasoning, where computations must be precise and robust. To address this limitation, we propose a novel approach that combines the Transformer's language understanding with the robustness of graph neural network (GNN)-based neural algorithmic reasoners (NARs). Such NARs proved effective as generic solvers for algorithmic tasks, when specified in graph form. To make their embeddings accessible to a Transformer, we propose a hybrid architecture with a two-phase training procedure, allowing the tokens in the language model to cross-attend to the node embeddings from the NAR. We evaluate our resulting TransNAR model on CLRS-Text, the text-based version of the CLRS-30 benchmark, and demonstrate significant gains over Transformer-only models for algorithmic reasoning, both in and out of distribution.

Paper Structure (14 sections, 3 equations, 6 figures, 1 table)

Figures (6)

  • Figure 1: Our TransNAR architecture, with its direct synergy of Transformers and Neural Algorithmic Reasoners, yields clear improvements in out-of-distribution reasoning across wide categories of algorithmic tasks in CLRS-Text (Markeeva et al., 2024), a textual version of the CLRS-30 benchmark (Veličković et al., 2022). Here, the $x$-axis indicates one of the eight algorithmic families of CLRS-30, and the $y$-axis spans the average execution accuracy across a dataset of out-of-distribution examples. TransNAR enables emerging capabilities in the particular out-of-distribution regime depicted here, with over 20% absolute improvement in several of the algorithm classes.
  • Figure 2: Augmenting LLMs with algorithmic reasoning: a bird's eye view of TransNAR. A large language model (LLM) consumes input tokens and produces output tokens, as is common for a unimodal Transformer. The neural algorithmic reasoner (NAR) module is a graph neural network (GNN) pre-trained to execute various algorithmic computations on a collection of graph-based inputs (Ibarz et al., 2022); the pre-training pipeline is denoted by faded arrows. Throughout its forward pass, the Transformer may access the embeddings computed by the NAR, by leveraging cross-attention (trained by learnable "glue" weights).
  • Figure 3: TransNAR hybrid architecture. Similar to Flamingo (Alayrac et al., 2022), we interleave existing Transformer layers with gated cross-attention layers which enable information to flow from the NAR to the Transformer. We generate queries from tokens while we obtain keys and values from nodes and edges of the graph. The node and edge embeddings are obtained by running the NAR on the graph version of the reasoning task to be solved. When experimenting with pre-trained Transformers, we initially close the cross-attention gate, in order to fully preserve the language model's internal knowledge at the beginning of training.
  • Figure 4: TransNAR significantly outperforms the baseline Transformer. We compare TransNAR to its corresponding Transformer baseline on various algorithms and for various input sizes: $12$ is the largest size in-distribution. The other two sizes tested, $10$ and $14$, are out-of-distribution, with the former testing interpolation and the latter extrapolation. Note that in-distribution generalisation is much easier for Transformers, and as such, we have modified the $y$-axis for this setting only to the $[0.7, 1.0]$ range. It is evident that, on most algorithmic tasks of interest, the TransNAR is capable of outperforming its baseline Transformer. Additionally, we see that this advantage is consistent across both training regimes: initial training and finetuning. The metric used is the CLRS score. Each model was trained with 4 random seeds. Error bars indicate $\pm 1$ standard deviation.
  • Figure 5: Shape Score: The TransNAR significantly outperforms its baseline in terms of producing correct shapes. This score sheds light on an obvious failure mode of regular Transformers out-of-distribution: they fail to capture the seemingly trivial dependency between input size and output size, irrespective of the complexity of the algorithm itself. The TransNAR model considerably alleviates this problem (with many emerging gains), although these gains do not always lead to perfect scores, implying a fruitful direction for future research.
  • ...and 1 more figure
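
The gated cross-attention described in the Figure 3 caption can be illustrated with a minimal single-head sketch in plain Python. It is not the authors' implementation. Several details are assumptions made for illustration: the `tanh` gate parameterisation (borrowed from Flamingo-style gating), the residual form `tokens + gate * update`, the weight names `w_q`, `w_k`, `w_v`, and the use of one flat list of NAR embeddings standing in for both node and edge embeddings. Queries come from the tokens, while keys and values come from the NAR embeddings, as the caption states.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def gated_cross_attention(tokens, nar_embeddings, w_q, w_k, w_v, gate):
    """Single-head gated cross-attention (illustrative sketch).

    tokens:         [T, d]      token embeddings from the Transformer
    nar_embeddings: [N, d_nar]  embeddings from the NAR (assumed flat list)
    gate:           scalar; tanh(gate) scales the cross-attention update
                    (assumed Flamingo-style parameterisation)
    Returns a [T, d] matrix.
    """
    q = matmul(tokens, w_q)          # queries from tokens:        [T, d_k]
    k = matmul(nar_embeddings, w_k)  # keys from NAR embeddings:   [N, d_k]
    v = matmul(nar_embeddings, w_v)  # values from NAR embeddings: [N, d]
    scale = 1.0 / math.sqrt(len(q[0]))
    k_t = [list(col) for col in zip(*k)]  # transpose of k: [d_k, N]
    scores = matmul(q, k_t)               # [T, N]
    attn = [softmax([s * scale for s in row]) for row in scores]
    update = matmul(attn, v)              # [T, d]
    g = math.tanh(gate)
    # Residual connection: with gate = 0 the tokens pass through unchanged.
    return [[t + g * u for t, u in zip(t_row, u_row)]
            for t_row, u_row in zip(tokens, update)]


def random_matrix(rows, cols, rng):
    """Hypothetical helper: small random matrix for the demo."""
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
T, N, d, d_nar, d_k = 3, 5, 4, 6, 4  # arbitrary demo sizes
tokens = random_matrix(T, d, rng)
nar = random_matrix(N, d_nar, rng)
w_q = random_matrix(d, d_k, rng)
w_k = random_matrix(d_nar, d_k, rng)
w_v = random_matrix(d_nar, d, rng)

closed = gated_cross_attention(tokens, nar, w_q, w_k, w_v, gate=0.0)
opened = gated_cross_attention(tokens, nar, w_q, w_k, w_v, gate=1.0)
```

With `gate=0.0` the layer returns the token embeddings unchanged, which mirrors the caption's point about initially closing the gate to preserve the language model's knowledge. As the gate moves away from zero during training, information from the NAR is mixed into the token stream.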