Differential Mamba

Nadav Schneider, Itamar Zimerman, Eliya Nachmani

TL;DR

This paper investigates applying differential design to the Mamba architecture to mitigate over-allocation of attention to irrelevant context, thereby improving robustness and long-context retrieval. It introduces Diff-Mamba, a differential mechanism that subtracts two Mamba blocks (with normalization and a learnable weight) to produce a data-controlled linear operator with reduced noise. Across language modeling benchmarks, ablations, and mechanistic interpretability analyses, Diff-Mamba yields lower perplexity and better retrieval performance, especially at greater depths and longer contexts, while demonstrating improved signal-to-noise characteristics in intermediate representations. The work highlights the broader potential of differential design beyond Transformers, provides empirical guidance on design choices, and releases code publicly, though it notes the need for theoretical justification and scalability considerations for broader domains.
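The subtraction described above can be illustrated with a minimal sketch. It is not the authors' implementation: `toy_mixer` is a hypothetical stand-in for a Mamba block (a fixed scalar recurrence rather than a selective state-space layer), `lam` plays the role of the learnable weight but is a plain float here, and applying an RMS-style normalization after the subtraction is an assumption about where the normalization sits.

```python
import math
from typing import Callable, List

Seq = List[float]  # toy 1-D sequence; real models operate on (batch, length, channels) tensors


def toy_mixer(decay: float, gain: float) -> Callable[[Seq], Seq]:
    """Hypothetical stand-in for a Mamba block: h_t = decay * h_{t-1} + gain * x_t."""
    def run(x: Seq) -> Seq:
        h, out = 0.0, []
        for x_t in x:
            h = decay * h + gain * x_t
            out.append(h)
        return out
    return run


def rms_norm(y: Seq, eps: float = 1e-6) -> Seq:
    """Scale a sequence to (approximately) unit root-mean-square."""
    rms = math.sqrt(sum(v * v for v in y) / len(y) + eps)
    return [v / rms for v in y]


def differential_mix(x: Seq, mixer_a: Callable[[Seq], Seq],
                     mixer_b: Callable[[Seq], Seq], lam: float) -> Seq:
    """Subtract a weighted second branch from the first, then normalize.

    Assumption: normalization is applied to the difference; the paper's
    exact placement may differ.
    """
    y_a = mixer_a(x)
    y_b = mixer_b(x)
    diff = [a - lam * b for a, b in zip(y_a, y_b)]
    return rms_norm(diff)


x = [1.0, 2.0, 3.0]
out = differential_mix(x, toy_mixer(0.5, 1.0), toy_mixer(0.5, 0.5), lam=0.5)
# Components that both branches share are cancelled when lam = 1:
cancelled = differential_mix(x, toy_mixer(0.5, 1.0), toy_mixer(0.5, 1.0), lam=1.0)
```

In this toy setting, two identical branches with `lam = 1` cancel exactly, which is the intuition behind using the second branch to remove signal that both branches assign to irrelevant context. In a trained model the two branches have separate parameters and the weight is learned.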

Abstract

Sequence models like Transformers and RNNs often overallocate attention to irrelevant context, leading to noisy intermediate representations. This degrades LLM capabilities by promoting hallucinations, weakening long-range and retrieval abilities, and reducing robustness. Recent work has shown that differential design can mitigate this issue in Transformers, improving their effectiveness across various applications. In this paper, we explore whether these techniques, originally developed for Transformers, can be applied to Mamba, a recent architecture based on selective state-space layers that achieves Transformer-level performance with greater efficiency. We show that a naive adaptation of differential design to Mamba is insufficient and requires careful architectural modifications. To address this, we introduce a novel differential mechanism for Mamba, empirically validated on language modeling benchmarks, demonstrating improved retrieval capabilities and superior performance over vanilla Mamba. Finally, we conduct extensive ablation studies and empirical analyses to justify our design choices and provide evidence that our approach effectively mitigates the overallocation problem in Mamba-based models. Our code is publicly available: https://github.com/NadavSc/Diff-Mamba

Paper Structure

This paper contains 35 sections, 22 equations, 9 figures, 8 tables.

Figures (9)

  • Figure 1: Comparative illustration of our variants Diff-Mamba and Diff-S6 versus the original Mamba architecture, where $\otimes$ is elementwise multiplication, $\sigma$ is the SiLU activation, Linear and Conv1D are standard linear projection and 1-dimensional convolution layers, and N denotes a normalization layer.
  • Figure 2: Test curves over the course of training for Mamba and Diff-Mamba. The top row shows results for 6-layer models, and the bottom row for 12-layer models. Columns correspond to datasets: Enwik8 (left), Text-8 (center), and WikiText-103 (right).
  • Figure 4: Retrieval Abilities: Comparison of Diff-Mamba and Mamba models across five retrieval tasks from BABILong. X-axis represents the context length, and y-axis corresponds to the task index. Each cell displays the ratio in which one model outperforms the other. Green cells indicate wins by Diff-Mamba, while red cells indicate wins by Mamba.
  • Figure 5: Performance comparison on synthetic token manipulation tasks. We evaluate Mamba and Diff-Mamba architectures across six synthetic capability benchmarks. Diff-Mamba consistently outperforms Mamba, with notable improvements of 80.0% on Fuzzy ICR and 10.6% on Compression. Stars indicate the best model per task.
  • Figure 6: Measuring Signal-to-Noise Ratio: The y-axis represents the probability of predicting the desired needle token, where lower values indicate higher noise. The x-axis denotes the model layers at which intermediate noise is measured. Results show the needle probability at each layer, averaged over 1k BABILong questions of 1k-2k tokens.
  • ...and 4 more figures