ReMamba: Equip Mamba with Effective Long-Sequence Modeling

Danlong Yuan, Jiahao Liu, Bei Li, Huishuai Zhang, Jingang Wang, Xunliang Cai, Dongyan Zhao

TL;DR

ReMamba addresses the long-context degradation of Mamba, a linear-time state-space model, by introducing a two-stage re-forward mechanism that selectively compresses and adapts distant information. Stage 1 selects salient final-layer hidden states and uses them as a compact replacement for the long context. Stage 2 integrates these compressed representations into Mamba’s state updates, using a trainable control to mitigate information loss. Empirical results on LongBench and L-Eval show gains of 3.2 and 1.6 points over the baselines, respectively, and performance almost on par with same-size transformers, with transfer gains observed on Mamba2 as well. The approach achieves these improvements with only a modest inference overhead, offering a practical path to enhancing the long-context capabilities of memory-efficient models while outlining limitations and avenues for future refinement in state-space updates.

Abstract

While the Mamba architecture demonstrates superior inference efficiency and competitive performance on short-context natural language processing (NLP) tasks, empirical evidence suggests its capacity to comprehend long contexts is limited compared to transformer-based models. In this study, we investigate the long-context efficiency issues of the Mamba models and propose ReMamba, which enhances Mamba's ability to comprehend long contexts. ReMamba incorporates selective compression and adaptation techniques within a two-stage re-forward process, incurring minimal additional inference overhead. Experimental results on the LongBench and L-Eval benchmarks demonstrate ReMamba's efficacy, improving over the baselines by 3.2 and 1.6 points, respectively, and attaining performance almost on par with same-size transformer models.

Paper Structure (27 sections, 6 equations, 8 figures, 5 tables)

This paper contains 27 sections, 6 equations, 8 figures, 5 tables.

Figures (8)

  • Figure 1: A comparison of pretrained Mamba models and Transformers of equivalent size across speed, short-context, and long-context performance metrics. Speed is measured with 6k input tokens and 1k output tokens. "short scores" is the average accuracy across six tasks (HellaSwag, PIQA, Arc-E, Arc-C, WinoGrande, OpenbookQA) evaluated with the LM evaluation harness. "long scores" is the average score on the LongBench-E benchmark (Bai et al., 2024). Notably, all LongBench evaluations use a maximum token length of 2k to align with the model's training configuration.
  • Figure 2: ReMamba architecture. Only one layer is shown, and $A$, $B$, and the discretization method are omitted. In Stage 2, only the selected value vectors go through selective adaptation; normal token embeddings flow through as usual. We select the top-$K$ hidden states (top-2 here) in the last layer according to their importance scores, which are calculated with respect to the last hidden state $h_L$. We incorporate the scores into the gradient by utilizing the selective mechanism in Mamba.
  • Figure 3: Ablation study: average scores on LongBench as the max length varies from 2k to 9k. "Mamba(SFT)" is the finetuned Mamba. "fix_select" is Fix Selection, "random_select" is Random Selection, and "multiplicative_select" is Multiplicative Selection.
  • Figure 4: Average scores on LongBench as the max length varies from 2k to 9k. "Pre" denotes a pretrained model and "SFT" a finetuned model. The performance of Llama-3b (SFT) and Llama-3b (Pre) is shown for reference, using a max length of 6k.
  • Figure 5: Average scores on L-Eval as the max length varies from 2k to 9k. The performance of Llama-3b (SFT) and Llama-3b (Pre) is shown for reference, using a max length of 6k.
  • ...and 3 more figures
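
The top-$K$ selection described in the Figure 2 caption can be illustrated with a minimal sketch. It is not the authors' implementation. It assumes cosine similarity to the last hidden state as the importance score (the caption does not specify the scoring function), and the names `cosine` and `select_top_k` are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def select_top_k(hidden_states, k):
    """Pick the k hidden states most similar to the last one.

    Assumption: importance = cosine similarity to the last hidden state h_L.
    The last state itself is excluded from the candidates.
    Returns (indices, scores), with indices kept in original sequence order.
    """
    h_last = hidden_states[-1]
    scores = [cosine(h, h_last) for h in hidden_states[:-1]]
    # Rank candidate positions by score, keep the best k, then restore order.
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    kept = sorted(top)
    return kept, [scores[i] for i in kept]


# Toy final-layer hidden states for a 5-token sequence (made-up values).
states = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [-1.0, 0.0], [1.0, 0.2]]
indices, scores = select_top_k(states, k=2)
print(indices)
```

With these toy values, positions 0 and 2 point in nearly the same direction as the last state, so they are the two that survive. Keeping the selected indices in sequence order preserves the relative position of the retained states, which is a design choice of this sketch rather than something stated in the caption.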