Understanding Factual Recall in Transformers via Associative Memories

Eshaan Nichani, Jason D. Lee, Alberto Bietti

TL;DR

This work analyzes how transformers memorize facts using associative memories, showing that shallow transformers can achieve near-optimal factual recall by combining linear or MLP associative memories with attention. The authors prove that the storage capacity of both kinds of associative memory scales linearly with parameter count, and that a one-layer transformer can reach 100% accuracy on a synthetic factual recall task whenever either the self-attention or the MLP parameter count scales linearly, up to logarithmic factors, with the dataset size $SR$, i.e. the number of facts. They also study gradient dynamics, revealing a sequential learning trajectory with an intermediate hallucination stage in which the model predicts from the relation alone before it learns to use the subject, and they provide information-theoretic lower bounds that match their constructions up to log factors. Experiments corroborate the theory: the number of facts stored scales linearly with model size, and for a fixed dataset the model can trade MLP parameters for attention parameters. Overall, the paper advances the understanding of memorization and factual recall in transformers and suggests architectural principles for scalable memory in neural networks.

Abstract

Large language models have demonstrated an impressive ability to perform factual recall. Prior work has found that transformers trained on factual recall tasks can store information at a rate proportional to their parameter count. In our work, we show that shallow transformers can use a combination of associative memories to obtain such near-optimal storage capacity. We begin by proving that the storage capacities of both linear and MLP associative memories scale linearly with parameter count. We next introduce a synthetic factual recall task, and prove that a transformer with a single layer of self-attention followed by an MLP can obtain 100% accuracy on the task whenever either the total number of self-attention parameters or MLP parameters scales (up to log factors) linearly with the number of facts. In particular, the transformer can trade off between using the value matrices or the MLP as an associative memory to store the dataset of facts. We complement these expressivity results with an analysis of the gradient flow trajectory of a simplified linear attention model trained on our factual recall task, where we show that the model exhibits sequential learning behavior.

Paper Structure

This paper contains 56 sections, 30 theorems, 248 equations, 9 figures.

Key Result

Theorem 1

Assume that $f^*$ is injective. If $d^2 \gtrsim N \operatorname{poly}\log N$, then with high probability over the draw of the embeddings, there exists a ${\bm{W}}$ such that [...]. This capacity is obtained by the construction ${\bm{W}} = \sum_{x \in [N]} {\bm{u}}_{f^*(x)}{\bm{e}}_x^\top$. Furthermore, if ${\bm{W}}$ is restricted to be a rank $m$ matrix, then such a ${\bm{W}}$ exists when $md$ [...]
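
The outer-product construction ${\bm{W}} = \sum_{x \in [N]} {\bm{u}}_{f^*(x)}{\bm{e}}_x^\top$ can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the authors' code: it assumes unit-norm Gaussian embeddings, and it assumes that an association counts as recalled when $\arg\max_y {\bm{u}}_y^\top {\bm{W}} {\bm{e}}_x = f^*(x)$. All function names are hypothetical.

```python
import random


def random_unit_vectors(n, d, rng):
    """Draw n random d-dimensional vectors and normalize each to unit length."""
    vectors = []
    for _ in range(n):
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = sum(c * c for c in v) ** 0.5
        vectors.append([c / norm for c in v])
    return vectors


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def build_memory(E, U, f):
    """Return W = sum_x U[f[x]] E[x]^T as a d x d list of lists."""
    d = len(E[0])
    W = [[0.0] * d for _ in range(d)]
    for x, e in enumerate(E):
        u = U[f[x]]
        for i in range(d):
            row, ui = W[i], u[i]
            for j in range(d):
                row[j] += ui * e[j]
    return W


def recall(W, U, e):
    """Decode by argmax over output embeddings (assumed decoding rule)."""
    h = [dot(row, e) for row in W]       # h = W e
    scores = [dot(u, h) for u in U]      # scores[y] = u_y^T W e
    return max(range(len(U)), key=scores.__getitem__)


def recall_accuracy(N, d, seed=0):
    """Fraction of the N associations recovered by the construction."""
    rng = random.Random(seed)
    E = random_unit_vectors(N, d, rng)   # input embeddings e_x
    U = random_unit_vectors(N, d, rng)   # output embeddings u_y
    f = list(range(N))
    rng.shuffle(f)                       # a random permutation, hence injective
    W = build_memory(E, U, f)
    return sum(recall(W, U, E[x]) == f[x] for x in range(N)) / N
```

When $d^2$ is much larger than $N$, the cross terms between distinct random embeddings are small compared with the signal term, so `recall_accuracy` should be at or near 1. Shrinking `d` toward $\sqrt{N}$ makes the interference grow, which is the regime the capacity bound describes.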

Figures (9)

  • Figure 1: We train linear and MLP associative memories to store the association $f^*(x) = x$. (Left) A linear associative memory requires $d^2 \propto N\log N$ parameters to store $N$ associations. (Right) The MLP associative memory requires $md \propto N \log N$ parameters to store $N$ associations, as predicted by \ref{thm:mlp-AM}.
  • Figure 2: A diagram of the synthetic factual recall task.
  • Figure 3: Both the Attention-only and Attention+MLP constructions for the factual recall task.
  • Figure 4: (Left) The number of facts stored scales linearly with the total number of parameters, for a wide range of model sizes. (Right) For a fixed dataset, the model can trade off MLP parameters for attention parameters to obtain 100% accuracy. The heatmap color corresponds to model accuracy.
  • Figure 5: (Left) Loss of the linear attention model with orthogonal embeddings. There is an intermediate hallucination stage where the loss plateaus and the model predicts based on only the relation. (Right) Loss of the softmax attention model with random embeddings. We again observe an intermediate hallucination stage, where the relation-only loss is zero but the total loss is still large.
  • ...and 4 more figures
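
The hallucination stage described in the Figure 5 caption, where the model predicts from the relation alone, can be made concrete with a toy fact table. The sketch below is an illustration only and is not the paper's task definition: it assumes facts are indexed by (subject, relation) pairs, that each relation has its own small answer set, and that a relation-only predictor outputs the most common answer for each relation. Every name and parameter here is hypothetical.

```python
import random
from collections import Counter, defaultdict


def make_facts(num_subjects, num_relations, answers_per_relation, seed=0):
    """Build a toy table mapping (subject, relation) to an answer.

    Answers are tagged with their relation, (relation, k), so that each
    relation has its own answer set (an assumption of this sketch).
    """
    rng = random.Random(seed)
    facts = {}
    for r in range(num_relations):
        for s in range(num_subjects):
            facts[(s, r)] = (r, rng.randrange(answers_per_relation))
    return facts


def relation_only_guess(facts):
    """For each relation, pick its most frequent answer, ignoring the subject."""
    counts = defaultdict(Counter)
    for (_, r), answer in facts.items():
        counts[r][answer] += 1
    return {r: c.most_common(1)[0][0] for r, c in counts.items()}


def accuracy(facts, predict):
    """Fraction of facts for which predict(subject, relation) is correct."""
    hits = sum(predict(s, r) == a for (s, r), a in facts.items())
    return hits / len(facts)


facts = make_facts(num_subjects=50, num_relations=4, answers_per_relation=5)
guess = relation_only_guess(facts)

# Relation-only prediction: a plausible answer type, but blind to the subject.
relation_only_acc = accuracy(facts, lambda s, r: guess[r])
# Full lookup using both subject and relation recovers every fact.
full_acc = accuracy(facts, lambda s, r: facts[(s, r)])
```

In this toy setting the relation-only predictor always returns an answer of the right type for the relation, yet it cannot exceed the frequency of the most common answer, while the full lookup is exact. That gap is one way to read a plateau in total loss while the relation-only loss is already low.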

Theorems & Definitions (53)

  • Theorem 1
  • Theorem 2: Informal
  • Theorem 3: Attention-only, informal
  • Theorem 4: Attention + MLP, informal
  • Theorem 5: Global Convergence
  • Theorem 6: Sequential Learning
  • Theorem 7
  • Corollary 1
  • Theorem 8
  • Proof of \ref{thm:lin-AM}
  • ...and 43 more