Memory Layers at Scale

Vincent-Pierre Berges, Barlas Oğuz, Daniel Haziza, Wen-tau Yih, Luke Zettlemoyer, Gargi Ghosh

TL;DR

Memory Layers at Scale demonstrates that trainable key-value memory modules can substantially augment transformer models without increasing FLOPs. By introducing scalable mechanisms (product-key lookup, parallel memory, and a shared memory pool), the authors scale memory to up to 128B parameters and show strong gains on factual and knowledge-intensive tasks, outperforming dense models trained with more than twice the compute. They provide optimized implementations (custom CUDA kernels, input gating with a SiLU non-linearity) and a set of ablations to inform design decisions. The work suggests memory-augmented architectures offer a viable, scalable alternative to brute-force parameter growth, with practical benefits for factual accuracy and knowledge retention. Hardware optimization and further learning-method research are identified as key avenues for broader deployment and impact.

Abstract

Memory layers use a trainable key-value lookup mechanism to add extra parameters to a model without increasing FLOPs. Conceptually, sparsely activated memory layers complement compute-heavy dense feed-forward layers, providing dedicated capacity to store and retrieve information cheaply. This work takes memory layers beyond proof-of-concept, proving their utility at contemporary scale. On downstream tasks, language models augmented with our improved memory layer outperform dense models with more than twice the computation budget, as well as mixture-of-expert models when matched for both compute and parameters. We find gains are especially pronounced for factual tasks. We provide a fully parallelizable memory layer implementation, demonstrating scaling laws with up to 128B memory parameters, pretrained to 1 trillion tokens, comparing to base models with up to 8B parameters.
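To make the "trainable key-value lookup" concrete, here is a minimal sketch in plain Python. It assumes the common formulation in which a query is scored against every key, only the top-k keys are kept, and the output is a softmax-weighted sum of the matching values. The function name `memory_lookup`, the toy keys and values, and the choice of k are illustrative assumptions, not the authors' implementation.

```python
import math

def memory_lookup(query, keys, values, k=2):
    """Sparse key-value lookup: score all keys, keep the top-k, mix their values."""
    # Dot-product score between the query and every key.
    scores = [sum(q * c for q, c in zip(query, key)) for key in keys]
    # Indices of the k highest-scoring keys; only these slots are activated.
    top = sorted(range(len(keys)), key=lambda i: scores[i], reverse=True)[:k]
    # Softmax over the selected scores only (shifted by the max for stability).
    peak = max(scores[i] for i in top)
    weights = [math.exp(scores[i] - peak) for i in top]
    total = sum(weights)
    weights = [w / total for w in weights]
    # Weighted sum of the selected value vectors.
    dim = len(values[0])
    output = [sum(w * values[i][d] for w, i in zip(weights, top)) for d in range(dim)]
    return output, top

# Toy memory with four slots (hypothetical numbers, for illustration only).
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.7, 0.7]]
values = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
output, selected = memory_lookup([2.0, 0.1], keys, values, k=2)
```

Only the k selected value rows contribute to the output, which is why the value table can grow without a matching growth in per-token compute. This naive version still scores every key; the product-key lookup mentioned in the TL;DR is the mechanism intended to avoid that exhaustive scan.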

Paper Structure

This paper contains 20 sections, 2 equations, 4 figures, 4 tables.

Figures (4)

  • Figure 1: Scaling the size of the memory for a 1.3 billion parameter base model (zero memory parameters corresponds to a dense model), trained to 1 trillion tokens. Left: factual QA accuracy (exact match on NaturalQuestions and F1 score on TriviaQA). Right: task NLL (lower is better). Dashed lines show the performance of a 7B model trained on 2 trillion tokens with 10x more FLOPs.
  • Figure 2: Illustration of the parallel EmbeddingBag implementation for a "Memory Group" of two GPUs. Each GPU performs the EmbeddingBag operation on all of the indices of the group, but only on the half of the embedding dimensions that it holds.
  • Figure 3: Left: the regular memory layer. Right: the Memory+ block, with the added projection, gating, and SiLU non-linearity.
  • Figure 4: Accuracy vs. Base Parameters for NaturalQuestions and TriviaQA (Memory+ models use 1 million memory embeddings.)
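
The sharding scheme described in the Figure 2 caption can be sketched in plain Python. The sketch assumes that an EmbeddingBag is a weighted sum of selected table rows and that each device holds a contiguous slice of the embedding dimension. The helper names (`embedding_bag`, `shard_columns`, `sharded_embedding_bag`) are hypothetical, devices are simulated with lists, and the communication needed to share indices and gather outputs is not modeled.

```python
def embedding_bag(table, indices, weights):
    """Weighted sum of the selected rows of `table` (one 'bag')."""
    dim = len(table[0])
    return [sum(w * table[i][d] for i, w in zip(indices, weights)) for d in range(dim)]

def shard_columns(table, num_shards):
    """Split every row into contiguous column slices, one per simulated GPU."""
    dim = len(table[0])
    assert dim % num_shards == 0, "embedding dimension must divide evenly"
    width = dim // num_shards
    return [[row[s * width:(s + 1) * width] for row in table] for s in range(num_shards)]

def sharded_embedding_bag(shards, indices, weights):
    """Each shard handles ALL indices, but only its own slice of the dimensions."""
    partials = [embedding_bag(shard, indices, weights) for shard in shards]
    # Concatenating the partial outputs recovers the full-width result.
    return [x for part in partials for x in part]

# Toy table with three rows of width 4 (hypothetical numbers).
table = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]
indices, weights = [0, 2], [0.25, 0.75]

full = embedding_bag(table, indices, weights)
sharded = sharded_embedding_bag(shard_columns(table, 2), indices, weights)
```

Because the weighted sum acts on each output dimension independently, splitting the table by columns changes where the work happens but not the result: `full` and `sharded` are identical here.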