
Associative Recurrent Memory Transformer

Ivan Rodkin, Yuri Kuratov, Aydar Bulatov, Mikhail Burtsev

TL;DR

ARMT addresses long-context processing by combining segment-level recurrence with an associative memory mechanism, so that each new segment is processed in constant time. It extends the Recurrent Memory Transformer (RMT) framework with layerwise associative memory and achieves superior performance on associative retrieval tasks and on long-context benchmarks such as BABILong, where it answers single-fact (QA1) questions over contexts of up to 50M tokens with about 79.9% accuracy. The paper introduces a memory-capacity estimator and uses curriculum learning to train on progressively longer contexts, demonstrating robust memory recall and notable length generalization. Limitations include training difficulties, slower runtime on short sequences, and language-modeling performance that needs further optimization and scaling before the approach can realize its full potential in larger models.
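
Curriculum learning over context length can be pictured as a loop that trains on short contexts first and lengthens them once the model copes. The following is a minimal sketch under assumptions that this summary does not state: the doubling schedule, the accuracy threshold, the per-stage step cap, and the names `curriculum_lengths`, `run_curriculum`, `train_step` and `evaluate` are all hypothetical, not the paper's training code.

```python
from typing import Callable, List


def curriculum_lengths(max_segments: int, start: int = 1) -> List[int]:
    """Segment counts growing geometrically up to max_segments.

    Hypothetical doubling schedule; the paper's actual stages may differ.
    """
    if start < 1 or max_segments < start:
        raise ValueError("need 1 <= start <= max_segments")
    lengths = []
    n = start
    while n < max_segments:
        lengths.append(n)
        n *= 2
    lengths.append(max_segments)
    return lengths


def run_curriculum(
    max_segments: int,
    train_step: Callable[[int], None],
    evaluate: Callable[[int], float],
    threshold: float = 0.9,
    max_steps_per_stage: int = 1000,
) -> List[int]:
    """Train on short contexts first; advance once accuracy reaches threshold.

    train_step and evaluate are hypothetical callbacks that receive the
    current number of segments. Returns the steps spent at each stage.
    """
    steps_per_stage = []
    for num_segments in curriculum_lengths(max_segments):
        steps = 0
        while steps < max_steps_per_stage:
            train_step(num_segments)
            steps += 1
            if evaluate(num_segments) >= threshold:
                break
        steps_per_stage.append(steps)
    return steps_per_stage


print(curriculum_lengths(10))  # [1, 2, 4, 8, 10]
```

Stage-wise advancement on a validation metric is only one possible design; a fixed number of steps per length would fit the same description.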

Abstract

This paper addresses the challenge of creating a neural architecture for very long sequences that requires constant time to process new information at each time step. Our approach, the Associative Recurrent Memory Transformer (ARMT), is based on transformer self-attention for local context and segment-level recurrence for storing task-specific information distributed over a long context. We demonstrate that ARMT outperforms existing alternatives in associative retrieval tasks and sets a new performance record on the recent BABILong multi-task long-context benchmark by answering single-fact questions over 50 million tokens with an accuracy of 79.9%. The source code for training and evaluation is available on GitHub.
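
Segment-level recurrence is what keeps the cost of each new segment constant: the model only ever processes the current segment together with a fixed-size memory state carried over from the previous one. The minimal Python sketch that follows illustrates only that control flow; `toy_step` is a hypothetical stand-in for a transformer pass and does not reproduce ARMT's computation, and all function names are illustrative.

```python
from typing import Callable, Iterator, List, Sequence, Tuple

Memory = List[float]
StepFn = Callable[[Sequence[int], Memory], Tuple[float, Memory]]


def split_into_segments(tokens: Sequence[int], segment_len: int) -> Iterator[Sequence[int]]:
    """Yield consecutive fixed-length segments (the last one may be shorter)."""
    for start in range(0, len(tokens), segment_len):
        yield tokens[start:start + segment_len]


def process_long_sequence(
    tokens: Sequence[int], segment_len: int, init_memory: Memory, step: StepFn
) -> Tuple[List[float], Memory]:
    """Run `step` over segments, threading a fixed-size memory state through."""
    memory = init_memory
    outputs: List[float] = []
    for segment in split_into_segments(tokens, segment_len):
        out, memory = step(segment, memory)
        outputs.append(out)
    return outputs, memory


def toy_step(segment: Sequence[int], memory: Memory) -> Tuple[float, Memory]:
    """Hypothetical stand-in for one transformer pass over [memory; segment].

    Its cost depends only on len(segment) and len(memory), never on how
    many segments were processed before.
    """
    new_memory = list(memory)
    for tok in segment:
        new_memory[tok % len(new_memory)] += 1.0
    return sum(new_memory), new_memory


outputs, final_memory = process_long_sequence(list(range(10)), 4, [0.0] * 3, toy_step)
print(outputs)       # [4.0, 8.0, 10.0]
print(final_memory)  # [4.0, 3.0, 3.0]
```

Because the memory state never grows, total work scales linearly with the number of segments, and the work per new segment stays fixed.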

Paper Structure (17 sections, 10 equations, 7 figures, 2 tables)

Figures (7)

  • Figure 1: ARMT augments the transformer's layers with associative memory. (a) RMT architecture. (b) ARMT adds associative memory processing to each layer. (c) Associative memory is updated with layerwise memory representations.
  • Figure 2: ARMT demonstrates strong performance on associative memory tasks. (a) The estimated number of pairs stored in memory after processing a context of key-value pairs. (b) ARMT is more accurate at memory operations. Although it is trained only on 50 key-value pairs from the Associative Retrieval Rewrite task, ARMT remains accurate even after 500 memory updates, an observed generalization factor of 10 (500 pairs / 50 pairs). All results are averaged over 3 runs, except RMT and PRMT, which are averaged over 2 runs.
  • Figure 3: ARMT sets a record in long-context processing, with reasonable performance on 50 million tokens. Accuracy of models at different context lengths on the BABILong benchmark: panels a-e correspond to tasks QA1-QA5.
  • Figure 4: (a) $\gamma$-correction fixes the quasi-linear attention memory: without it, quasi-linear attention with the delta rule struggles to extrapolate to numbers of memory updates not seen in training. (b) Parallel memory does not solve the capacity issue, which indicates that the associative memory plays an important role in increasing memory capacity.
  • Figure 5: Parallel recurrent memory transformer. In contrast to RMT, in PRMT memory tokens are passed to the next segment in each layer.
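
Figures 1 and 4 describe a per-layer associative memory updated with a delta rule, plus a $\gamma$-correction of the quasi-linear attention normalizer. The sketch below shows one plausible form of such a memory in plain Python. It assumes a ReLU feature map phi, a delta-rule update of a matrix A and a normalizer vector z, reads of the form A phi(k) / (z . phi(k)), and the correction gamma = 1 - (z . phi(k)) / |phi(k)|^2. These equations are an assumption modeled on delta-rule fast-weight memories, not a transcription of the paper. In ARMT the keys and values would come from each layer's memory representations; here they are plain vectors. `DeltaRuleMemory` and `estimate_capacity` are hypothetical names, and the latter is a crude count of recoverable pairs, not the paper's capacity estimator.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def relu_features(x: Sequence[float]) -> Vector:
    """Placeholder non-negative feature map phi (assumed; not ARMT's)."""
    return [max(0.0, xi) for xi in x]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


class DeltaRuleMemory:
    """Hypothetical quasi-linear associative memory: matrix A plus normalizer z."""

    def __init__(
        self,
        d_key: int,
        d_value: int,
        phi: Callable[[Sequence[float]], Vector] = relu_features,
        gamma_correction: bool = True,
        eps: float = 1e-6,
    ) -> None:
        self.phi = phi
        d_phi = len(phi([0.0] * d_key))
        self.A = [[0.0] * d_phi for _ in range(d_value)]  # d_value x d_phi
        self.z = [0.0] * d_phi  # running sum used to normalize reads
        self.gamma_correction = gamma_correction
        self.eps = eps

    def read(self, key: Sequence[float]) -> Vector:
        """Return A phi(k) / (z . phi(k)), or zeros if nothing is stored under k."""
        f = self.phi(key)
        denom = dot(self.z, f)
        if denom <= self.eps:
            return [0.0] * len(self.A)
        return [dot(row, f) / denom for row in self.A]

    def write(self, key: Sequence[float], value: Sequence[float], beta: float = 1.0) -> None:
        """Delta rule: move the value currently read for `key` towards `value`."""
        f = self.phi(key)
        old = self.read(key)
        for row, v_new, v_old in zip(self.A, value, old):
            delta = beta * (v_new - v_old)
            for j, fj in enumerate(f):
                row[j] += delta * fj
        # Assumed gamma-correction: shrink the normalizer update by the part of
        # phi(k) that z already covers, so rewriting a key does not inflate z.
        norm_sq = dot(f, f)
        gamma = 1.0
        if self.gamma_correction and norm_sq > self.eps:
            gamma = 1.0 - dot(self.z, f) / norm_sq
        for j, fj in enumerate(f):
            self.z[j] += gamma * fj


def estimate_capacity(
    memory: DeltaRuleMemory,
    keys: Sequence[Sequence[float]],
    values: Sequence[Sequence[float]],
    tol: float = 0.1,
) -> int:
    """Crude probe (hypothetical): write all pairs, count those read back within tol."""
    for k, v in zip(keys, values):
        memory.write(k, v)
    recalled = 0
    for k, v in zip(keys, values):
        out = memory.read(k)
        if max(abs(o - vi) for o, vi in zip(out, v)) <= tol:
            recalled += 1
    return recalled


# Usage: overwrite one key, with and without the correction.
with_gamma = DeltaRuleMemory(d_key=2, d_value=2)
without_gamma = DeltaRuleMemory(d_key=2, d_value=2, gamma_correction=False)
for mem in (with_gamma, without_gamma):
    mem.write([1.0, 0.0], [1.0, 0.0])
    mem.write([1.0, 0.0], [0.0, 1.0])  # rewrite the same key
print(with_gamma.read([1.0, 0.0]))     # [0.0, 1.0]
print(without_gamma.read([1.0, 0.0]))  # [0.0, 0.5]
```

With the assumed correction, the rewritten key returns its latest value. Without it, z keeps growing on every rewrite and the read-out shrinks, a toy analogue of the extrapolation failure described for Figure 4(a).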