M+: Extending MemoryLLM with Scalable Long-Term Memory

Yu Wang, Dmitry Krotov, Yuanzhe Hu, Yifan Gao, Wangchunshu Zhou, Julian McAuley, Dan Gutfreund, Rogerio Feris, Zexue He

TL;DR

M+ extends MemoryLLM by integrating a scalable long-term memory module with a co-trained retriever. This lets the model retrieve and use past information across very long contexts while keeping GPU memory in check by offloading to the CPU. The approach uses per-layer long-term memory tokens, a dual-LoRA design that separates the memory write and read phases, and a compact retriever (f_q, f_k) trained with a contrastive objective. Across long-context QA, event reasoning, and knowledge-retention benchmarks, M+ consistently outperforms MemoryLLM and strong baselines, extending effective memory retention from under 20k to over 160k tokens with similar or lower GPU memory usage. A staged data curriculum (short contexts, then long documents, then long-term memory) and accompanying analysis show that the gains come from the memory component and the learned retrieval, without hurting performance on shorter contexts. The work offers a practical path to scalable long-context reasoning in LLMs and highlights latency and memory-management trade-offs for future work to address.
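To make the retriever idea concrete, here is a minimal sketch of a contrastive (InfoNCE-style) objective for one query. It assumes f_q and f_k are plain linear projections, that relevance is scored with a scaled dot product, and that the other memory tokens serve as negatives. The function names, the `temperature` parameter, and the loss form are illustrative assumptions; M+'s actual projector architecture, negative sampling, and loss may differ.

```python
import math


def project(weights, vec):
    """Apply a linear projection given as a list of rows (weights @ vec)."""
    return [sum(w * x for w, x in zip(row, vec)) for row in weights]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def contrastive_loss(f_q, f_k, query, memory_tokens, positive_idx, temperature=1.0):
    """InfoNCE-style loss for a single query (hypothetical sketch).

    The query hidden state is projected with f_q and every memory token with f_k.
    Scores are scaled dot products. The loss is the cross-entropy of the softmax
    over scores, with the relevant memory token (positive_idx) as the target and
    all other tokens acting as negatives.
    """
    q = project(f_q, query)
    scores = [dot(q, project(f_k, tok)) / temperature for tok in memory_tokens]
    m = max(scores)  # subtract the max before exponentiating for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_sum_exp - scores[positive_idx]


identity = [[1.0, 0.0], [0.0, 1.0]]
memory = [[1.0, 0.0], [0.0, 1.0]]
print(contrastive_loss(identity, identity, [1.0, 0.0], memory, 0))  # ≈ 0.313
```

With identity projections, the loss is small when the target token aligns with the query and larger when it does not. Training f_q and f_k to minimize this loss pushes relevant memory tokens to the top of the ranking.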

Abstract

Equipping large language models (LLMs) with latent-space memory has attracted increasing attention because such memory can extend the context window of existing language models. However, retaining information from the distant past remains a challenge. For example, MemoryLLM (Wang et al., 2024a), a representative latent-space memory model, compresses past information into hidden states across all layers, forming a memory pool of 1B parameters. While effective for sequence lengths up to 16k tokens, it struggles to retain knowledge beyond 20k tokens. In this work, we address this limitation by introducing M+, a memory-augmented model based on MemoryLLM that significantly enhances long-term information retention. M+ integrates a long-term memory mechanism with a co-trained retriever, dynamically retrieving relevant information during text generation. We evaluate M+ on diverse benchmarks, including long-context understanding and knowledge retention tasks. Experimental results show that M+ significantly outperforms MemoryLLM and recent strong baselines, extending knowledge retention from under 20k to over 160k tokens with similar GPU memory overhead. We open-source our code at https://github.com/wangyu-ustc/MemoryLLM
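One way to see why a fixed-size memory pool forgets is a quick calculation. Suppose each update writes K new tokens into a pool of N tokens and evicts K old ones. If eviction is uniformly random, a given token survives one update with probability 1 − K/N, and the expected fraction of a chunk's tokens left after t updates is (1 − K/N)^t. The sketch below computes this quantity. The uniform-eviction assumption, the function name, and the values N = 7680 and K = 256 are illustrative only; check the MemoryLLM paper for its actual configuration.

```python
def expected_retention(pool_size, tokens_per_update, num_updates):
    """Expected fraction of one chunk's memory tokens still in a fixed-size pool.

    Assumes (hypothetically) that each update evicts `tokens_per_update` of the
    `pool_size` tokens uniformly at random, so a token survives a single update
    with probability 1 - K/N, and survival compounds across updates.
    """
    if not 0 < tokens_per_update <= pool_size:
        raise ValueError("need 0 < K <= N")
    survive = 1.0 - tokens_per_update / pool_size
    return survive ** num_updates


print(expected_retention(7680, 256, 30))  # about 0.36
```

Under these assumptions, retention decays exponentially with the number of updates, so information written tens of thousands of tokens earlier is mostly gone. Archiving evicted tokens instead of discarding them targets exactly this failure mode.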

Paper Structure

This paper contains 54 sections, 6 equations, 9 figures, 4 tables.

Figures (9)

  • Figure 1: The left side shows the Update and Generation Process of MemoryLLM (Wang et al., 2024a). During Update, we process the chunk with $\phi_l$ to obtain $K$ new tokens, which are attended to by $\phi$ via cross-attention during Generation. The right side shows the Update and Generation Process of M+. For layer $l$, during Update, the old memory pool $\theta_l$ is split into two parts: $K$ dropped tokens and $N-K$ remaining tokens. The dropped tokens are stored in the long-term memory $\Theta_l$, while the remaining tokens and the $K$ new tokens are combined to form the new memory pool $\theta_l'$. During Generation, our co-trained retriever retrieves tokens from $\Theta_l$, which are fed into the transformer layer $\phi_l$ along with the short-term memory $\theta_l$ and the query hidden states. The major difference between MemoryLLM and M+ is the introduction of the long-term memory $\Theta_l$.
  • Figure 2: Overall performance comparison on Longbook Question Answering. Best viewed in color.
  • Figure 3: Knowledge Retention Results on SQuAD.
  • Figure 4: Validation loss comparison on a held-out subset of SlimPajama consisting of 1,000 examples. The three models, MemoryLLM-8B, MemoryLLM-8B-Long, and M+, are obtained after Stages 1, 2, and 3, respectively (see the Data Curriculum section).
  • Figure 5: Ablation Study on SQuAD dataset.
  • ...and 4 more figures
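The per-layer Update and Generation steps described in the Figure 1 caption can be sketched roughly as follows. This is a simplified illustration, not the authors' implementation: memory tokens are plain Python lists, the K dropped tokens are chosen at random (an assumption borrowed from MemoryLLM; M+'s actual selection rule may differ), and retrieval ranks precomputed keys by dot product with a query key. In M+, that query key would come from f_q applied to the query hidden states, and the keys from f_k applied to the memory tokens. The names `update_layer` and `retrieve` are hypothetical.

```python
import random


def update_layer(short_term, long_term, new_tokens, rng):
    """One M+-style update for a single layer l (hypothetical sketch).

    short_term: list of N memory vectors (theta_l).
    long_term:  list acting as Theta_l (kept off-GPU in M+); mutated in place.
    new_tokens: K vectors produced from the incoming chunk.
    Returns the new short-term pool theta_l' without modifying short_term.
    """
    n, k = len(short_term), len(new_tokens)
    dropped = set(rng.sample(range(n), k))  # choose K tokens to leave theta_l
    long_term.extend(short_term[i] for i in sorted(dropped))  # archived, not discarded
    remaining = [tok for i, tok in enumerate(short_term) if i not in dropped]
    return remaining + list(new_tokens)  # (N - K) remaining + K new = N tokens


def retrieve(query_key, long_term_keys, top_k):
    """Return indices of the top_k long-term tokens ranked by dot-product score."""
    scores = [sum(a * b for a, b in zip(query_key, key)) for key in long_term_keys]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:top_k]


rng = random.Random(0)
old_pool = [[float(i)] for i in range(8)]  # N = 8 one-dimensional tokens
archive = []
pool = update_layer(old_pool, archive, [[100.0], [101.0]], rng)  # K = 2
print(len(pool), len(archive))  # 8 2
```

The short-term pool stays at a constant size N, so GPU cost does not grow with context length. Only the archive grows, which is why storing it off-GPU and reading it through a cheap retriever keeps memory usage bounded.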