
HMT: Hierarchical Memory Transformer for Efficient Long Context Language Processing

Zifan He, Yingqi Cao, Zongyue Qin, Neha Prakriya, Yizhou Sun, Jason Cong

TL;DR

HMT introduces a brain-inspired hierarchical memory transformer that augments decoder-only LMs with a memory retrieval mechanism and segment-wise recurrence to enable long-context processing. By encoding segment representations, retrieving relevant past memory via cross-attention, and combining sensory, short-term, and long-term memory, HMT provides memory-efficient long-context capabilities across diverse backbones. Empirical results on language modeling and QA show consistent gains with modest parameter overhead and substantially lower inference memory than long-context baselines, underscoring practical benefits for resource-constrained settings. This plug-and-play framework supports scalable lifelong language tasks by efficiently handling unbounded histories without changing the backbone model's architecture.
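As an illustration of the long-term tier, the cache of past memory embeddings can be modeled as a fixed-capacity queue. This is a minimal sketch, not the authors' implementation: the class name `LongTermMemory`, the `capacity` parameter, and first-in-first-out eviction are assumptions made for illustration.

```python
from collections import deque

import numpy as np


class LongTermMemory:
    """Hypothetical fixed-capacity cache of per-segment memory embeddings (M_n).

    Assumption: once `capacity` is reached, the oldest embedding is evicted (FIFO).
    """

    def __init__(self, capacity: int, dim: int) -> None:
        self.dim = dim
        # deque(maxlen=...) silently drops the oldest item when full.
        self._cache = deque(maxlen=capacity)

    def push(self, memory_embedding: np.ndarray) -> None:
        if memory_embedding.shape != (self.dim,):
            raise ValueError(f"expected shape ({self.dim},), got {memory_embedding.shape}")
        self._cache.append(memory_embedding)

    def as_matrix(self) -> np.ndarray:
        # Stack cached embeddings into an (n_cached, dim) matrix for retrieval.
        if not self._cache:
            return np.zeros((0, self.dim))
        return np.stack(list(self._cache))


mem = LongTermMemory(capacity=3, dim=4)
for step in range(5):
    mem.push(np.full(4, float(step)))  # segment `step` yields a constant toy embedding
print(mem.as_matrix()[:, 0])  # -> [2. 3. 4.]
```

Under this assumption, memory cost stays constant no matter how many segments have been processed, because only the most recent `capacity` embeddings are kept.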

Abstract

Transformer-based large language models (LLMs) have been widely used in language processing applications. However, due to device memory constraints, most of them restrict the context window. Even though recurrent models in previous works can memorize past tokens to enable unlimited context and maintain effectiveness, they have "flat" memory architectures, which are limited in selecting and filtering information. Since humans are good at learning and self-adjustment, we believe that imitating the brain's memory hierarchy is beneficial for model memorization. Thus, we propose the Hierarchical Memory Transformer (HMT), a novel framework that improves a model's long-context processing ability by imitating human memorization behavior. Leveraging memory-augmented segment-level recurrence, we organize the memory hierarchy by preserving tokens from early input segments, passing memory embeddings along the sequence, and recalling relevant information from history. Evaluating on general language modeling, question-answering, and summarization tasks, we show that HMT consistently improves the long-context processing ability of existing models. Furthermore, HMT achieves generation quality comparable or superior to long-context LLMs with $2$ to $57\times$ fewer parameters and $2.5$ to $116\times$ less inference memory, significantly outperforming previous memory-augmented models. Code is available on GitHub: https://github.com/OswaldHe/HMT-pytorch.
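Figure 1 describes augmenting each segment with the last $k$ embeddings from the previous segment (the sensory memory). The sketch below shows only that carry-over, on token ids for readability, whereas HMT carries embeddings and also adds a memorization prompt. The function `split_with_sensory_memory` and its parameters are hypothetical.

```python
from typing import List, Sequence


def split_with_sensory_memory(
    tokens: Sequence[int], segment_len: int, k: int
) -> List[List[int]]:
    """Split `tokens` into segments of `segment_len`, prefixing each segment with
    the last `k` tokens of the previous segment (hypothetical sensory memory)."""
    if segment_len <= 0 or k < 0:
        raise ValueError("segment_len must be positive and k non-negative")
    segments: List[List[int]] = []
    previous: List[int] = []
    for start in range(0, len(tokens), segment_len):
        current = list(tokens[start:start + segment_len])
        # Guard k == 0: previous[-0:] would copy the whole previous segment.
        carry = previous[-k:] if k > 0 else []
        segments.append(carry + current)
        previous = current
    return segments


print(split_with_sensory_memory(list(range(10)), segment_len=4, k=2))
# -> [[0, 1, 2, 3], [2, 3, 4, 5, 6, 7], [6, 7, 8, 9]]
```

Each step only ever sees at most `segment_len + k` positions, which is what keeps per-step memory bounded regardless of total input length.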

Paper Structure

This paper contains 29 sections, 10 equations, 17 figures, and 9 tables.

Figures (17)

  • Figure 1: Overall workflow of HMT. For each segment, (1) HMT first performs representation encoding, using the segment summarization prompt embedding ($T$) to summarize part of the segment. (2) The resulting segment summary embedding ($S_n$) is used with the cached memory embeddings for memory search via cross-attention. The output is a memorization prompt embedding ($P_n$) containing information relevant to the current segment. (3) The memorization prompt embedding and the last $k$ embeddings from the previous segment augment the segment. (4) The backbone model (BBM) processes the augmented segment and generates hidden embeddings for the logits ($H_n^{out}$) and the memory embedding ($M_n$), which is pushed into the long-term memory.
  • Figure 2: Test perplexity of HMT, RMT, and three baseline models (OPT 2.7B, RWKV 3B, OpenLlamaV2 3B) on the Wikitext-103 dataset. HMT outperforms RMT by 13.0% for OPT and 10.8% for OpenLlamaV2. For RWKV, HMT improves effectiveness by 16.5%, whereas RMT degrades it.
  • Figure 3: Test perplexity of HMT, RMT, and three baseline models (OPT 2.7B, RWKV 3B, OpenLlamaV2 3B), evaluated on the PG-19 dataset. HMT outperforms RMT by 3.98% for OPT and 6.85% for OpenLlamaV2. For RWKV, HMT improves effectiveness by 9.96%.
  • Figure 4: Test perplexity of HMT, RMT, and the Qwen 2.5 14B baseline model on the PG-19 dataset. HMT improves the baseline model's effectiveness by 10.0%, whereas RMT degrades it.
  • Figure 5: Long-answer quality of RMT and HMT applied to Llama-2 7B, evaluated on the PubMedQA dataset. HMT is 8.98% more effective than RMT.
  • ...and 12 more figures
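Step (2) of Figure 1, the memory search, can be pictured as ordinary scaled dot-product cross-attention in which the segment summary $S_n$ attends over the cached memory embeddings. This is a minimal single-head sketch under stated assumptions: `memory_search`, the projections `w_q` and `w_k`, and the choice to use the raw memory embeddings as values are illustrative and not taken from the paper or the HMT-pytorch repository.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    # Subtract the max before exponentiating for numerical stability.
    shifted = x - x.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def memory_search(
    summary: np.ndarray,   # S_n, shape (d,)
    memories: np.ndarray,  # cached memory embeddings, shape (n_cached, d)
    w_q: np.ndarray,       # hypothetical query projection, shape (d, d)
    w_k: np.ndarray,       # hypothetical key projection, shape (d, d)
) -> np.ndarray:
    """Single-head cross-attention: S_n is the query, cached memories are keys and values.

    Assumption: values are the raw memory embeddings (no value projection).
    """
    query = summary @ w_q                             # (d,)
    keys = memories @ w_k                             # (n_cached, d)
    scores = keys @ query / np.sqrt(query.shape[-1])  # (n_cached,)
    weights = softmax(scores)                         # non-negative, sums to 1
    return weights @ memories                         # P_n, shape (d,)


d = 4
memories = np.eye(d)          # four orthogonal toy memory embeddings
summary = 10.0 * memories[2]  # a summary that closely matches memory 2
p_n = memory_search(summary, memories, np.eye(d), np.eye(d))
print(int(np.argmax(p_n)))    # -> 2
```

Because the output is a convex combination of cached embeddings, a segment whose summary strongly matches one past memory retrieves mostly that memory; in this toy case about 98% of the attention weight falls on memory 2.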