Hardware-aligned Hierarchical Sparse Attention for Efficient Long-term Memory Access

Xiang Hu, Jiaqi Leng, Jun Zhao, Kewei Tu, Wei Wu

TL;DR

This work tackles the challenge of efficient long-context modeling by combining RNN-style processing with long-range random access. It introduces Hierarchical Sparse Attention (HSA), a two-stage mechanism that learns token-to-chunk relevance and aggregates information across selected chunks, enabling accurate, hardware-friendly retrieval beyond pretraining lengths. RAMba extends Mamba by integrating HSA with a hardware-aligned kernel design and CPU-GPU memory offloading to achieve near-constant memory during inference and strong length generalization, including perfect 64M-context passkey retrieval. Empirically, RAMba outperforms full-attention and prior sparse-attention baselines on long-range language modeling and downstream tasks, while maintaining favorable efficiency and scalability. The results suggest RAMba/HSA as a practical foundation for language models with persistent memory and robust long-context capabilities.

Abstract

A key advantage of Recurrent Neural Networks (RNNs) over Transformers is that their linear computational and space complexity enables faster training and inference on long sequences. However, RNNs are fundamentally unable to randomly access historical context, and simply integrating attention mechanisms may undermine their efficiency advantages. To overcome this limitation, we propose Hierarchical Sparse Attention (HSA), a novel attention mechanism that enhances RNNs with long-range random access flexibility while preserving their merits in efficiency and length generalization. HSA divides inputs into chunks, selects the top-$k$ chunks, and hierarchically aggregates information. The core innovation lies in learning token-to-chunk relevance based on fine-grained token-level information inside each chunk. This approach enhances the precision of chunk selection across both in-domain and out-of-domain context lengths. To make HSA efficient, we further introduce a hardware-aligned kernel design. By combining HSA with Mamba, we introduce RAMba, which achieves perfect accuracy in passkey retrieval on 64M-length contexts despite pre-training on only 4K-length contexts, and delivers significant improvements on various downstream tasks with a nearly constant memory footprint. These results show RAMba's huge potential in long-context modeling.
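
The chunk, select, and aggregate steps described above can be illustrated with a minimal NumPy sketch for a single query token. This is not the authors' kernel or their exact formulation: the function name `hsa_single_query`, the $1/\sqrt{d}$ scaling, the softmax normalization over the selected chunks, the reuse of one query vector for both stages, and the omission of causal masking are all assumptions made for illustration. Only the mean-pooled chunk key $\mathbf{\bar{K}}_i$, the selection score $\mathbf{Q}_t^\top\mathbf{\bar{K}}_i$, and the per-chunk output $\mathbf{O}_{t,i}$ come from the text.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def hsa_single_query(q, K, V, chunk_size, top_k):
    """Illustrative two-stage sparse attention for one query (hypothetical helper).

    q: (d,) query vector; K, V: (n, d) keys and values, n divisible by chunk_size.
    Returns the aggregated output, the selected chunk indices, and their weights.
    """
    n, d = K.shape
    num_chunks = n // chunk_size
    Kc = K.reshape(num_chunks, chunk_size, d)
    Vc = V.reshape(num_chunks, chunk_size, d)

    # Chunk summary: mean pooling of each chunk's keys (K-bar_i in Figure 1).
    K_bar = Kc.mean(axis=1)                      # (num_chunks, d)

    # Token-to-chunk relevance scores; the sqrt(d) scaling is an assumption.
    scores = K_bar @ q / np.sqrt(d)              # (num_chunks,)

    # Keep the top-k chunks (ascending argsort, so the best chunk is last).
    selected = np.argsort(scores)[-top_k:]

    # Stage 1: token-level attention inside each selected chunk gives O_{t,i}.
    intra = softmax(Kc[selected] @ q / np.sqrt(d), axis=-1)   # (top_k, chunk_size)
    O = np.einsum("kc,kcd->kd", intra, Vc[selected])          # (top_k, d)

    # Stage 2: weight the per-chunk outputs by normalized chunk scores.
    # Normalizing with a softmax over the selected chunks is an assumption.
    w = softmax(scores[selected])                # (top_k,)
    out = w @ O                                  # (d,)
    return out, selected, w


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, chunk_size = 4, 2
    q = np.array([1.0, 0.0, 0.0, 0.0])
    K = rng.normal(size=(8, d)) * 0.1
    K[4:6] += 5.0 * q                            # make chunk 2 clearly relevant
    V = rng.normal(size=(8, d))
    out, selected, w = hsa_single_query(q, K, V, chunk_size, top_k=2)
    print(selected, w, out)
```

Because the chunk weights multiply the per-chunk outputs, a training loss on the final output would also propagate to the chunk scores in a differentiable implementation, which is one way to read the "chunk-aware" feedback described for Figure 1.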

Paper Structure

This paper contains 41 sections, 6 equations, 8 figures, 8 tables, 3 algorithms.

Figures (8)

  • Figure 1: $\mathbf{K}_i,\mathbf{V}_i$ are the $i$-th chunk's key and value, with $\mathbf{\bar{K}}_i$ the mean pooling of $\mathbf{K}_i$. In (A), the chunk selection scores $\mathbf{Q}_t^\top\mathbf{\bar{K}}_i$ are learned from token-to-token interactions (chunk-unaware). In HSA (B), $\mathbf{Q}_t^\top\mathbf{\bar{K}}_i$ are guided by the feedback from the entire chunk (chunk-aware), with $\mathbf{O}_{t,i}$ the chunk-level information obtained from the $i$-th chunk by the $t$-th token.
  • Figure 2: (a) Model architecture for RAMba. (b) Kernel design for HSA.
  • Figure 3: Passkey retrieval results.
  • Figure 4: Comparison of attention computation time, with 3 attention layers per group (lower is better).
  • Figure 5: How unnormalized scores mislead the chunk selection.
  • ...and 3 more figures