When Can Transformers Count to n?

Gilad Yehudai, Haim Kaplan, Asma Ghandeharioun, Mor Geva, Amir Globerson

TL;DR

This work studies very simple counting tasks that involve counting how many times a token in the vocabulary has appeared in a string, and shows that the task can be solved if the dimension of the transformer state is linear in the context length.

Abstract

Large language models based on the transformer architecture can solve highly complex tasks. But are there simple tasks that such models cannot solve? Here we focus on very simple counting tasks that involve counting how many times a token in the vocabulary has appeared in a string. We show that if the dimension of the transformer state is linear in the context length, this task can be solved. However, the solution we propose does not scale beyond this limit, and we provide theoretical arguments for why it is likely impossible for a size-limited transformer to implement this task. Our empirical results demonstrate the same phase transition in performance as anticipated by the theoretical argument. Our results demonstrate the importance of understanding how transformers can solve simple tasks.

Paper Structure (25 sections, 7 theorems, 9 equations, 3 figures)

Key Result

Theorem 4.1

For the Query Count problem and any context length $n > 0$, if $d > 2m$, there exists a transformer that solves it, which has one layer, one head, and an MLP with $d$ neurons.
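
As a rough illustration of why a dimension on the order of $2m$ can suffice, the following is a minimal NumPy sketch of a histogram-style solution. It is not the paper's construction: it assumes that $m$ is the vocabulary size, that the query is the last token of the sequence (as the Figure 1 caption suggests), and that uniform attention averages one-hot embeddings. The function name `query_count_histogram` and the final rescaling by $n$ are hypothetical choices made for this example.

```python
import numpy as np


def query_count_histogram(tokens, vocab_size):
    """Count how often the last token occurs, via a one-hot histogram.

    Assumptions (not taken from the source): the query is the last token,
    tokens are integers in [0, vocab_size), and the MLP read-out is
    replaced by an explicit dot product.
    """
    tokens = np.asarray(tokens)
    n = len(tokens)

    # Embed every token as a standard basis vector in R^m.
    one_hot = np.eye(vocab_size)[tokens]          # shape (n, m)

    # Uniform attention averages the embeddings: a normalized histogram.
    mean_hist = one_hot.mean(axis=0)              # shape (m,)

    # State of size 2m: query one-hot concatenated with the histogram.
    query = one_hot[-1]
    state = np.concatenate([query, mean_hist])    # shape (2m,)

    # "Extraction" step: pick the histogram entry selected by the query,
    # then undo the averaging by multiplying with n.
    count = n * float(state[:vocab_size] @ state[vocab_size:])
    return int(round(count))
```

For example, `query_count_histogram([4, 1, 4, 2, 4], vocab_size=5)` treats the final `4` as the query and counts its three occurrences.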

Figures (3)

  • Figure 1: (a) Solving QC using a histogram (for $d>m$). To count the number of tokens with $x_i=4$, we assume each token is embedded to the standard basis (this can be done because $d>m$) and sum these vectors across all input tokens. This results in a histogram of the inputs, and the $4^{th}$ element can be extracted using a simple "Extraction MLP". (b) Solving QC using CountAttend: this solution works for all $d$, but requires an MLP for inverting numbers, and we show that this MLP needs to be of size $n$ (which can be prohibitive). To count the number of tokens with $x_i=4$, the last token attends to the others such that only tokens with $x_i=4$ receive large weights. This results in weights that are non-zero only for $x_i=4$, and the resulting weight on each of these is the inverse of the count of $4$ (i.e., 0.5 in this case). Then this inverse is moved to the last element of the value vector, using a positional embedding coordinate that is $1$ only for the last token $n$. Finally, the inverse count needs to be inverted to get the desired count, and this requires the "Inversion MLP".
  • Figure 2: (a) The threshold vocabulary size at which counting accuracy drops below $80\%$. Results are shown for two counting tasks. (b) Results for the QC task when using Gemini 1.5. The x axis is the vocabulary size (i.e., the number of different tokens used in each sequence), and the y axis is the average absolute error over $100$ repetitions. The "Binary Baseline" curve shows results when using just two tokens, but at the same sequence length used for the "Variable Vocab Size" curve. Standard errors are shown as shaded regions.
  • Figure 3: Evaluation of Gemini 1.5 on the MFC task. See Sec. \ref{sec:gemini_on_mfc}.
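
The CountAttend idea described in the Figure 1(b) caption can be sketched numerically as follows. This is an illustrative sketch under stated assumptions, not the paper's construction: the unit-circle embeddings, the `temperature` value, and the function name `count_attend` are hypothetical, and the "Inversion MLP" is replaced by a direct `1 / x` computation.

```python
import numpy as np


def count_attend(tokens, vocab_size, temperature=50.0):
    """Count occurrences of the last token via sharply peaked attention.

    Assumptions (not taken from the source): tokens are embedded as
    distinct unit vectors on a circle (so d = 2 < m is possible), and
    the final inversion is computed exactly rather than by an MLP.
    """
    tokens = np.asarray(tokens)
    n = len(tokens)

    # Hypothetical 2-d embeddings: one unit vector per vocabulary item.
    angles = 2.0 * np.pi * np.arange(vocab_size) / vocab_size
    emb = np.stack([np.cos(angles), np.sin(angles)], axis=1)
    x = emb[tokens]                                # shape (n, 2)

    # The last token (the query) attends to every position. Matching
    # tokens have dot product 1; all others score strictly lower.
    scores = temperature * (x @ x[-1])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                       # softmax

    # Value vector: a positional coordinate that is 1 only at the last
    # position, so the attention output is about 1 / count.
    last_flag = np.zeros(n)
    last_flag[-1] = 1.0
    inverse_count = float(weights @ last_flag)

    # Stand-in for the "Inversion MLP".
    return int(round(1.0 / inverse_count))
```

Calling `count_attend([4, 1, 4, 2, 4], vocab_size=5)` puts attention weight of roughly one third on each of the three matching positions, and inverting that weight recovers the count. The sketch hides the costly part that the caption highlights: a real model has to realize the inversion with an MLP, which is where the size requirement of $n$ comes in.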

Theorems & Definitions (12)

  • Theorem 4.1
  • Theorem 4.2
  • proof
  • Proposition 4.3
  • Lemma 4.4
  • proof : Proof of Lemma \ref{lem:1/x approx}
  • Theorem 5.1
  • Theorem 5.2
  • proof : Proof of Proposition \ref{prop:softmax construction}
  • proof : Proof of Thm. \ref{thm:MFE lower bound}
  • ...and 2 more