
When can transformers reason with abstract symbols?

Enric Boix-Adsera, Omid Saremi, Emmanuel Abbe, Samy Bengio, Etai Littwin, Joshua Susskind

TL;DR

The paper formalizes relational reasoning with abstract symbols through template tasks and proves that transformer architectures can learn abstract relations and generalize to unseen symbols when trained with sufficiently diverse data, in contrast to classical MLPs, which fail to generalize. It provides a kernel-theoretic analysis of transformers via a transformer random-features kernel $K_{\text{trans}}$, shows universality under disjoint-template and data-diversity conditions, and then introduces a simple per-head reparametrization of the query-key matrix, $W_K W_Q^T + aI$, to improve data efficiency. The authors also address the copying problem in next-token-prediction settings by adding an attention-modulated skip connection, which improves generalization to unseen symbols. Empirically, the proposed modifications yield substantial data-efficiency gains on template tasks and improvements on language-modeling tasks such as GPT-2 fine-tuning. Overall, the work characterizes when transformers can robustly reason with abstract symbols and offers practical architectural tweaks that improve data efficiency for relational reasoning.
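The two per-head modifications can be sketched in NumPy as a single attention head in which the query-key matrix $W_K W_Q^T$ becomes $W_K W_Q^T + aI$ and the value-output matrix $W_V W_O^T$ becomes $W_V W_O^T + bI$ (the form shown in Figure 4; the $bI$ term acts as an attention-modulated skip connection). This is a minimal sketch under our own conventions: the function name `modified_attention_head`, the weight shapes, and the omission of scaling, causal masking, and the residual stream are assumptions, not the authors' implementation.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted for numerical stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def modified_attention_head(X, W_Q, W_K, W_V, W_O, a=0.0, b=0.0):
    """One softmax attention head on X with shape (seq_len, d_emb).

    Hypothetical sketch: W_K W_Q^T is replaced by W_K W_Q^T + a*I and
    W_V W_O^T by W_V W_O^T + b*I, where a and b are the two extra
    trainable scalars per head. Weight matrices have shape (d_emb, d_head).
    No 1/sqrt(d_head) scaling, causal mask, or residual stream is modelled.
    """
    d_emb = X.shape[1]
    I = np.eye(d_emb)
    qk = W_K @ W_Q.T + a * I  # (d_emb, d_emb) query-key matrix
    vo = W_V @ W_O.T + b * I  # (d_emb, d_emb) value-output matrix
    # logits[i, j] = x_i (W_K W_Q^T + a I)^T x_j^T for query i and key j;
    # the a*I term adds a * <x_i, x_j>, favouring keys equal to the query.
    logits = X @ qk.T @ X.T
    attn = softmax(logits)
    # The b*I term adds b * (attention-weighted inputs): a skip path that
    # copies attended embeddings, including ones unseen during training.
    return attn @ X @ vo


# Usage: one-hot tokens with token 0 repeated; all learned weights are zero.
X = np.eye(5)[[0, 3, 0]]
Z = np.zeros((5, 2))
out = modified_attention_head(X, Z, Z, Z, Z, a=50.0, b=1.0)
print(np.allclose(out, X))  # True
```

With zero learned weights, a large $a$, and $b = 1$, each position attends to positions holding the same token and the head returns that token's embedding, whatever the embedding is. The scalars only shift existing matrices, which is why the change costs two parameters per head.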

Abstract

We investigate the capabilities of transformer models on relational reasoning tasks. In these tasks, models are trained on a set of strings encoding abstract relations, and are then tested out-of-distribution on data that contains symbols that did not appear in the training dataset. We prove that for any relational reasoning task in a large family of tasks, transformers learn the abstract relations and generalize to the test set when trained by gradient descent on sufficiently large quantities of training data. This is in contrast to classical fully-connected networks, which we prove fail to learn to reason. Our results inspire modifications of the transformer architecture that add only two trainable parameters per head, and that we empirically demonstrate improve data efficiency for learning to reason.
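The contrast with fully connected networks can be illustrated with a simplified toy (not the paper's proof): kernel regression on concatenated one-hot encodings with a kernel that depends only on $\langle x, x'\rangle$, loosely mimicking the inner-product structure of wide-MLP kernels. A test string made entirely of unseen symbols has inner product 0 with every training string, so its kernel row, and therefore its prediction, is the same for every such string, whether its label is "same" or "different". The kernel $k(x, x') = \exp(\langle x, x'\rangle)$ and the same/different setup below are assumptions chosen for illustration.

```python
import itertools

import numpy as np

V, L = 20, 2  # vocabulary size, string length


def encode(strings):
    """Concatenate per-position one-hot vectors: shape (n, L * V)."""
    X = np.zeros((len(strings), L * V))
    for row, s in enumerate(strings):
        for pos, sym in enumerate(s):
            X[row, pos * V + sym] = 1.0
    return X


def kernel(A, B):
    # Illustrative dot-product kernel k(x, x') = exp(<x, x'>).
    return np.exp(A @ B.T)


def same_different(symbols):
    """All length-2 strings over `symbols`, labelled +1 if same, -1 if different."""
    strings = list(itertools.product(symbols, repeat=2))
    labels = np.array([1.0 if s[0] == s[1] else -1.0 for s in strings])
    return strings, labels


train_s, y_train = same_different(range(0, 10))  # training uses symbols 0..9
test_s, y_test = same_different(range(10, 20))   # test uses unseen symbols 10..19

X_train, X_test = encode(train_s), encode(test_s)
K = kernel(X_train, X_train)
alpha = np.linalg.solve(K, y_train)  # ridgeless kernel regression
train_pred = K @ alpha
test_pred = kernel(X_test, X_train) @ alpha

# The training labels are fit exactly, but every unseen-symbol string has
# <x, x_i> = 0 with all training strings, so all test predictions coincide.
print(np.all(np.sign(train_pred) == y_train))  # True
print(np.allclose(test_pred, test_pred[0]))    # True
```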

Paper Structure (63 sections, 29 theorems, 127 equations, 24 figures)

Key Result

Theorem 1.1

For any regression template task, a wide-enough transformer architecture trained by gradient flow on sufficiently many samples generalizes on unseen symbols.
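Roughly, a template task builds each string from a template, a word over wildcards such as $\alpha$ and $\beta$, by substituting symbols for the wildcards, and the label depends only on the template. The sketch below recovers a string's template by renaming symbols in order of first appearance; it assumes substitutions are injective (distinct wildcards receive distinct symbols), and the helper `template_of` and the `LABEL` table are hypothetical illustrations, not code from the paper. It shows how training strings over upper-case symbols and test strings over unseen lower-case symbols share the same templates, so a learner that tracks relations between positions, rather than the symbols themselves, can generalize.

```python
WILDCARDS = "αβγδεζηθ"  # supports up to 8 distinct symbols per string


def template_of(tokens):
    """Rename symbols to wildcards in order of first occurrence.

    Assumes an injective substitution, so the template is recoverable
    up to a renaming of the wildcards.
    """
    names = {}
    out = []
    for tok in tokens:
        if tok not in names:
            if len(names) == len(WILDCARDS):
                raise ValueError("too many distinct symbols for WILDCARDS")
            names[tok] = WILDCARDS[len(names)]
        out.append(names[tok])
    return "".join(out)


# Same/different task: the label is a function of the template only.
LABEL = {"αα": "same", "αβ": "different"}

train = [["A", "A"], ["B", "C"]]  # upper-case symbols in training
test = [["x", "x"], ["y", "z"]]   # lower-case symbols never seen in training
for s in train + test:
    print(s, template_of(s), LABEL[template_of(s)])
```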

Figures (24)

  • Figure 1: (a,b) Variable names in the test data never appear in the training data (indicated by lower- vs. upper-case names). (c) Remarkably, as the training set size increases, the LLM's ability to reason outside of its training data improves: it learns to use the relations between the variable names to classify, instead of simply memorizing the training data. Our theory motivates a modified transformer architecture (see the observation on reparametrizing $W_K W_Q^T$), which solves the reasoning task with less training data. Details in the appendix on the teaser experiments.
  • Figure 2: (a,b) The labels are symbols. (c) We propose a modified transformer that learns the reasoning task with less data (see the observation on reparametrizing $W_K W_Q^T$ and Theorem 1.4). Details in the appendix on the teaser experiments.
  • Figure 3: Illustration of structure of $\hat{{\boldsymbol K}}$ and ${\boldsymbol N}$ for the same/different task, which has $r = 2$ templates ${\boldsymbol z}_1 = \alpha\alpha$ and ${\boldsymbol z}_2 = \alpha\beta$. As the sample diversity $\rho$ increases and the number of samples $n$ increases, the empirical kernel matrix $\hat{{\boldsymbol K}} \in \mathbb{R}^{n \times n}$ becomes approximately $(r \times r)$-block-structured, and within each block most of the entries are given by ${\boldsymbol N} \in \mathbb{R}^{r \times r}$; exceptions where this is not true, including the diagonals, are drawn in black. Furthermore, the spectrum of $\hat{{\boldsymbol K}}$ is increasingly determined by the spectrum of ${\boldsymbol N}$, and if ${\boldsymbol N}$ is nonsingular then the top eigenspace increasingly aligns with the span of the indicator vectors on $\mathcal{I}_1,\ldots,\mathcal{I}_r$.
  • Figure 4: (a) Transformers fail on the copying task as the embedding dimension $d_{emb}$ grows (see Theorem 1.3 for the informal statement); (b) Success when reparametrizing ${\boldsymbol W}_V{\boldsymbol W}_O^T$ as ${\boldsymbol W}_V{\boldsymbol W}_O^T+b{\boldsymbol I}$ (see Theorem 1.4 for the informal statement). Details in the appendix on the teaser experiments.
  • Figure 5: Perplexity of GPT-2 trained from random initialization with Adam learning rate 3e-4 for 20 epochs on Wikitext (smaller perplexity is better). GPT-2 has 117M parameters, and we add an extra 288 parameters (2 per head). Interestingly, even though the task is Wikipedia modeling, and therefore is not a pure reasoning task, the transformer modifications still give an improvement.
  • ...and 19 more figures
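The count of 288 extra parameters in Figure 5 matches GPT-2 small, which has 12 layers with 12 attention heads each, i.e., 144 heads. A quick check of the arithmetic, assuming the two extra scalars per head are the $a$ and $b$ of the $W_K W_Q^T + aI$ and $W_V W_O^T + bI$ reparametrizations:

```python
# GPT-2 small (117M) configuration: 12 transformer layers, 12 heads per layer.
n_layers = 12
n_heads_per_layer = 12
extra_per_head = 2  # assumed: one scalar a (query-key) and one scalar b (value-output)

extra_params = n_layers * n_heads_per_layer * extra_per_head
base_params = 117e6

print(extra_params)                                # 288
print(f"{100 * extra_params / base_params:.6f}%")  # 0.000246%
```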

Theorems & Definitions (79)

  • Theorem 1.1: Informal version of the main theorem (transformers succeed on template tasks)
  • Theorem 1.3: Informal version of the copying-failure theorem
  • Theorem 1.4: Informal version of the copying-success theorem
  • Definition 2.1
  • Definition 2.2
  • Definition 2.3
  • Proposition 3.1: How kernel gradient flow generalizes; see, e.g., Welling (2013).
  • Definition 3.2
  • Definition 3.3
  • Theorem 3.4: Transformers generalize on unseen symbols
  • ...and 69 more
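As background for Proposition 3.1, here is the standard kernel gradient flow computation for the squared loss, in our own notation (the paper's exact statement and normalization may differ). With training inputs $X = (x_1, \ldots, x_n)$, labels $\boldsymbol y$, kernel matrix $\boldsymbol K_{ij} = k(x_i, x_j)$ assumed nonsingular, and $k(x, X) = [k(x, x_1), \ldots, k(x, x_n)]$:

```latex
\begin{aligned}
\partial_t f_t(x) &= -\sum_{i=1}^{n} k(x, x_i)\,\bigl(f_t(x_i) - y_i\bigr), \qquad f_0 \equiv 0,\\
f_t(x) &= k(x, X)\,\boldsymbol K^{-1}\bigl(\boldsymbol I - e^{-t\boldsymbol K}\bigr)\boldsymbol y
\;\xrightarrow{\;t \to \infty\;}\; k(x, X)\,\boldsymbol K^{-1}\boldsymbol y .
\end{aligned}
```

On the training points, $\boldsymbol u_t = (f_t(x_i))_i$ solves $\partial_t \boldsymbol u_t = -\boldsymbol K(\boldsymbol u_t - \boldsymbol y)$, so $\boldsymbol u_t = (\boldsymbol I - e^{-t\boldsymbol K})\boldsymbol y$; integrating the flow at a new $x$ gives the second line. The prediction on a new input therefore depends on the training data only through $k(x, X)$ and $\boldsymbol K$, which is why the paper studies the structure of the empirical kernel matrix $\hat{\boldsymbol K}$ (Figure 3).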