Pretraining with hierarchical memories: separating long-tail and common knowledge

Hadi Pouransari, David Grangier, C Thomas, Michael Kirchhof, Oncel Tuzel

TL;DR

The paper tackles the inefficiency of storing all world knowledge in model parameters by introducing hierarchical parametric memories that separate long-tail knowledge from common knowledge. It couples a small anchor transformer with a large, hierarchically organized memory bank; a clustering-based memory retriever selects a context-dependent memory block that augments the model during both pretraining and inference. FFN-Memories emerge as the most effective memory type, with deeper, larger memory banks yielding robust gains across common-knowledge (CK) and specific-knowledge (SK) tasks and reaching performance comparable to larger models while using fewer runtime parameters. The approach offers practical benefits for on-device deployment, privacy, and post-hoc memory augmentation, and it is complementary to retrieval-augmented generation (RAG).

Abstract

The impressive performance gains of modern language models currently rely on scaling parameters: larger models store more world knowledge and reason better. Yet compressing all world knowledge into parameters is unnecessary, as only a fraction is used per prompt, and impractical for edge devices with limited inference-time memory and compute. We address this shortcoming with a memory-augmented architecture and a pretraining strategy aligned with existing hardware paradigms. We introduce small language models that access large hierarchical parametric memory banks encoding world knowledge. During pretraining and inference, we fetch a small, context-dependent memory block and add it to the model. Our pretraining learns to store long-tail world knowledge in the memory parameters, while the small language model acts as an anchor capturing common knowledge and general reasoning abilities. Through trillion-token-scale experiments, we show significant gains: a 160M-parameter model augmented with an 18M-parameter memory fetched from a 4.6B-parameter memory bank obtains comparable performance to a regular model with more than 2x the parameters. Through extensive experiments, we study the optimal type and size of parametric memories in transformers, scaling them to over 21B parameters. We find that our proposed hierarchical feed-forward memories work robustly across transformer architectures, whether added during pretraining or post-hoc.
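To make "fetch a memory block and add it to the model" concrete, here is a minimal sketch of one way a feed-forward memory could be attached to an anchor FFN. It assumes the memory contributes extra hidden units (extra rows of the up-projection, extra columns of the down-projection); the abstract does not spell out this mechanism, and the function names, the GELU activation, and the toy weights are all hypothetical.

```python
import math


def gelu(v):
    """Exact GELU activation via the error function."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def matvec(w, x):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def ffn(x, w_up, w_down):
    """Two-layer feed-forward block: w_down @ gelu(w_up @ x)."""
    hidden = [gelu(h) for h in matvec(w_up, x)]
    return matvec(w_down, hidden)


def ffn_with_memory(x, w_up, w_down, m_up, m_down):
    """Anchor FFN widened by a fetched memory block.

    Assumption (not taken from the source): the memory adds hidden units,
    so m_up supplies extra rows and m_down supplies extra columns.
    """
    up = w_up + m_up                                     # (h + h_mem) x d
    down = [wr + mr for wr, mr in zip(w_down, m_down)]   # d x (h + h_mem)
    return ffn(x, up, down)


# Toy sizes: model width d = 2, anchor hidden size h = 3, memory hidden size 1.
x = [0.5, -1.0]
w_up = [[0.1, 0.2], [-0.3, 0.4], [0.5, -0.6]]    # anchor parameters (always used)
w_down = [[0.1, 0.2, 0.3], [-0.1, 0.0, 0.2]]
m_up = [[0.7, -0.2]]                             # memory parameters (fetched per context)
m_down = [[0.5], [-0.4]]

anchor_only = ffn(x, w_up, w_down)
memory_only = ffn(x, m_up, m_down)
augmented = ffn_with_memory(x, w_up, w_down, m_up, m_down)
```

Under this assumption the augmented output is the anchor output plus the memory's own contribution, so a different memory block can be swapped in per context without touching the anchor weights. For scale, the abstract's numbers work out to 18M of 178M runtime parameters (about 10%) held in the fetched memory, and about 0.4% of the 4.6B bank fetched at a time.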

Paper Structure

This paper contains 24 sections, 2 equations, 20 figures, 8 tables.

Figures (20)

  • Figure 1: Left: Schematic of pretraining-with-memories: some parameters are always used (anchor parameters), others are fetched per input document (memory parameters). Middle: Accuracy improvement over baseline when $\simeq 10\%$ of parameters are allocated as memories for a knowledge-intensive task (predicting the atomic numbers of elements), using models with 160M, 410M, and 1.4B parameters, corresponding to rows A2, B2, and C2 in \ref{tab:cotraining}. Right: Elements sorted by their frequency of appearance in the DCLM-Baseline dataset (5 buckets, each with $\simeq 24$ elements). With the proposed pretraining-with-memories, we observe significant improvements, especially on long-tail data. While the baseline 1.4B model has only 17% accuracy on the least frequent element bucket, augmenting it with only 10% memory parameters increases the accuracy to 83%.
  • Figure 2: Proposed architecture: For a given context $x$ (such as a question text), the memory retriever module selects relevant parameters from a large set of memory bank parameters. These memory parameters are organized hierarchically based on the hierarchical clustering of the pretraining data. The anchor model, together with the retrieved memories, then responds to the question.
  • Figure 3: Effect of memory type on Specific-Knowledge benchmarks (a) and Wikipedia perplexity (b). Effect of memory level on performance as a function of fetched memory size (c) and bank size (d).
  • Figure 4: (a) Avg-SK accuracy for different hierarchical memories, demonstrating performance gain with larger bank size and fetched memory size. (b) Wiki-En perplexity for different fetched memory–to–anchor model size ratios, with the optimal point at 1:10. The purple curve shows the perplexity of anchor models without memory. The green curves show the perplexity of models with memory, with different shades of green corresponding to the progress of memory training.
  • Figure 4: OpenLM-160M
  • ...and 15 more figures
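
The retrieval step described in the Figure 2 caption can be sketched as a walk down a cluster tree. This is a minimal illustration, not the paper's retriever: it assumes each cluster has a centroid in some embedding space, that the nearest centroid is chosen at every level, and that the fetched memory is the set of blocks along that path. The `ClusterNode` class, the `retrieve` function, the 2-D embeddings, and the block names are all hypothetical.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ClusterNode:
    """One cluster in a hypothetical hierarchical clustering of the pretraining data."""
    centroid: List[float]
    memory_block: str  # stand-in for this cluster's memory parameters
    children: List["ClusterNode"] = field(default_factory=list)


def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def retrieve(top_level, context_embedding):
    """Pick the nearest centroid at each level and collect its memory block.

    Returns the blocks along the chosen path, ordered coarse to fine.
    """
    path = []
    candidates = top_level
    while candidates:
        best = min(candidates, key=lambda n: sq_dist(n.centroid, context_embedding))
        path.append(best.memory_block)
        candidates = best.children
    return path


# Toy two-level hierarchy over a made-up 2-D embedding space.
bank = [
    ClusterNode([1.0, 0.0], "L1/science", [
        ClusterNode([1.0, 0.2], "L2/chemistry"),
        ClusterNode([1.0, -0.2], "L2/physics"),
    ]),
    ClusterNode([0.0, 1.0], "L1/sports", [
        ClusterNode([0.2, 1.0], "L2/football"),
        ClusterNode([-0.2, 1.0], "L2/tennis"),
    ]),
]

fetched = retrieve(bank, [0.9, 0.3])
print(fetched)  # ['L1/science', 'L2/chemistry']
```

In this sketch only the blocks on one root-to-leaf path are loaded for a given context, which is why the fetched memory stays small even as the bank grows with more levels or more clusters per level.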