HMT: Hierarchical Memory Transformer for Efficient Long Context Language Processing
Zifan He, Yingqi Cao, Zongyue Qin, Neha Prakriya, Yizhou Sun, Jason Cong
TL;DR
HMT introduces a brain-inspired hierarchical memory transformer that augments decoder-only LMs with a memory retrieval mechanism and segment-wise recurrence, enabling long-context processing beyond the backbone's context window. By encoding segment representations, retrieving relevant past memory via cross-attention, and combining sensory, short-term, and long-term memory, HMT provides memory-efficient long-context capabilities across diverse backbones. Empirical results on language modeling and QA show consistent gains with modest parameter overhead and substantially lower inference memory than long-context baselines, which makes HMT practical for resource-constrained settings. As a plug-and-play framework, HMT handles unbounded histories efficiently without modifying the backbone architecture, making it suitable for scalable lifelong language tasks.
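The cross-attention retrieval step can be illustrated with a minimal sketch. It makes three assumptions. First, a single query vector stands in for the current segment's representation. Second, the history is a list of cached memory embeddings. Third, the query, key, and value projections are identities, whereas a trained model would learn them. `retrieve_memory` is a hypothetical name and not part of the HMT-pytorch code. The function computes scaled dot-product attention weights over the cached memories and returns their weighted sum.

```python
import math
from typing import List

Vector = List[float]


def retrieve_memory(query: Vector, memories: List[Vector]) -> Vector:
    """Scaled dot-product attention of one query over cached memory embeddings.

    Assumption: identity query/key/value projections; a real model learns them.
    """
    if not memories:
        # Nothing cached yet (e.g. the first segment): return a zero vector.
        return [0.0] * len(query)
    d = len(query)
    # Similarity between the segment query and each cached memory.
    scores = [sum(q * m for q, m in zip(query, mem)) / math.sqrt(d) for mem in memories]
    peak = max(scores)  # subtract the max score for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    # Weighted sum of memories: the "recalled" information for this segment.
    return [sum(w * mem[i] for w, mem in zip(weights, memories)) for i in range(d)]
```

For example, `retrieve_memory([10.0, 0.0], [[1.0, 0.0], [0.0, 1.0]])` puts almost all of its weight on the first memory because that memory is the most similar to the query.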
Abstract
Transformer-based large language models (LLMs) have been widely used in language processing applications. However, because of device memory constraints, most of them restrict the context window. Even though recurrent models in previous works can memorize past tokens to enable unlimited context while maintaining effectiveness, they have "flat" memory architectures, which are limited in their ability to select and filter information. Since humans are good at learning and self-adjustment, we believe that imitating the brain's memory hierarchy is beneficial for model memorization. Thus, we propose the Hierarchical Memory Transformer (HMT), a novel framework that enhances a model's long-context processing ability by imitating human memorization behavior. Leveraging memory-augmented segment-level recurrence, we organize the memory hierarchy by preserving tokens from early input segments, passing memory embeddings along the sequence, and recalling relevant information from history. Through evaluations on general language modeling, question-answering, and summarization tasks, we show that HMT consistently improves the long-context processing ability of existing models. Furthermore, HMT achieves comparable or superior generation quality to long-context LLMs with 2 to 57× fewer parameters and 2.5 to 116× less inference memory, significantly outperforming previous memory-augmented models. Code is available on GitHub: https://github.com/OswaldHe/HMT-pytorch.
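The memory hierarchy described above can be sketched as a segment-level loop. This toy version is only an illustration under stated assumptions, not the authors' implementation:

- Tokens are plain float vectors.
- The hypothetical helpers `mean_pool` and `attend` stand in for the backbone model and the learned cross-attention.
- The last `sensory_len` tokens of the previous segment stand in for the preserved tokens from earlier segments.
- Each segment's memory embedding is passed forward by appending it to a bounded cache, which is then searched to recall relevant history.

```python
import math
from collections import deque
from typing import Deque, List, Tuple

Vector = List[float]


def mean_pool(vectors: List[Vector]) -> Vector:
    """Hypothetical stand-in for the backbone: average the input embeddings."""
    dim = len(vectors[0])
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]


def attend(query: Vector, memories: List[Vector]) -> Vector:
    """Softmax-weighted sum of cached memories (identity projections assumed)."""
    if not memories:
        return [0.0] * len(query)
    scale = math.sqrt(len(query))
    scores = [sum(q * m for q, m in zip(query, mem)) / scale for mem in memories]
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(w * mem[i] for w, mem in zip(weights, memories)) / total
            for i in range(len(query))]


def segment_recurrence(tokens: List[Vector], segment_len: int, sensory_len: int,
                       max_memories: int) -> Tuple[List[int], List[Vector]]:
    """Process a long sequence segment by segment with a three-tier memory (toy)."""
    long_term: Deque[Vector] = deque(maxlen=max_memories)  # cached memory embeddings
    sensory: List[Vector] = []  # preserved tokens from the previous segment
    input_lengths: List[int] = []
    for start in range(0, len(tokens), segment_len):
        segment = tokens[start:start + segment_len]
        summary = mean_pool(segment)                 # segment representation
        recalled = attend(summary, list(long_term))  # recall relevant history
        augmented = [recalled] + sensory + segment   # bounded backbone input
        input_lengths.append(len(augmented))
        long_term.append(mean_pool(augmented))       # memory embedding passed forward
        sensory = segment[-sensory_len:] if sensory_len > 0 else []
    return input_lengths, list(long_term)
```

The returned input lengths show the design point. Each backbone call sees at most one recalled embedding, `sensory_len` preserved tokens, and one segment, so the per-step input size stays bounded however long the sequence grows.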
