Retrieval-Aware Distillation for Transformer-SSM Hybrids

Aviv Bick, Eric P. Xing, Albert Gu

TL;DR

The paper tackles the performance gap between Transformers and state-space models (SSMs) on tasks requiring in-context retrieval. It introduces retrieval-aware distillation, which identifies a small set of retrieval-critical attention heads (Gather-and-Aggregate, or G&A, heads), preserves only those during distillation, and replaces the rest with recurrent SSM heads, all trained via the MOHAWK framework. Empirically, retaining as few as 10 attention heads (about 2% of heads in a 1B model) recovers over 95% of the teacher's performance on retrieval-heavy tasks, while allowing an 8x reduction in SSM state dimension and 5–6x lower memory use than comparable hybrids. The result is a leaner hybrid that keeps its retrieval capability: because retrieval is localized to a small subset of heads, it can be delegated to a few attention heads while a compact SSM handles the rest of the sequence modeling. In practice, this means a smaller inference memory footprint for long-sequence processing, with competitive performance on retrieval-centric benchmarks.

Abstract

State-space models (SSMs) offer efficient sequence modeling but lag behind Transformers on benchmarks that require in-context retrieval. Prior work links this gap to a small set of attention heads, termed Gather-and-Aggregate (G&A), which SSMs struggle to reproduce. We propose *retrieval-aware distillation*, which converts a pretrained Transformer into a hybrid student by preserving only these retrieval-critical heads and distilling the rest into recurrent heads. We identify the essential heads via ablation on a synthetic retrieval task, producing a hybrid with sparse, non-uniform attention placement. We show that preserving **just 2% of attention heads recovers over 95% of teacher performance on retrieval-heavy tasks** (10 heads in a 1B model), requiring far fewer heads than hybrids that retain at least 25%. We further find that large recurrent states often compensate for missing retrieval: once retrieval is handled by these heads, the SSM backbone can be simplified with limited loss, even with an 8× reduction in state dimension. By reducing both the attention cache and the SSM state, the resulting hybrid is 5–6× more memory-efficient than comparable hybrids, closing the Transformer–SSM gap at a fraction of the memory cost.
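The memory claim combines two terms: the attention KV cache, which grows with sequence length and with the number of retained attention heads, and the recurrent state, which is fixed in size but scales with the state dimension. The sketch below is a hypothetical back-of-the-envelope calculator, not the paper's accounting. It assumes a Mamba-2-style state of head_dim × state_dim per recurrent head, ignores convolution buffers and activations, and uses an illustrative 512-head configuration rather than the paper's actual 1B model; `HybridConfig` and the helper functions are made-up names.

```python
from dataclasses import dataclass


@dataclass
class HybridConfig:
    """Hypothetical hybrid layout; the values used below are illustrative only."""
    attn_heads: int          # retained attention heads (each caches K and V)
    ssm_heads: int           # heads replaced by recurrent (SSM) heads
    head_dim: int            # channels per head
    state_dim: int           # SSM state size per channel (assumed Mamba-2-style)
    bytes_per_elem: int = 2  # 2 bytes for fp16/bf16


def kv_cache_bytes(cfg: HybridConfig, seq_len: int) -> int:
    # K and V each store (seq_len, head_dim) per attention head: linear in seq_len.
    return 2 * cfg.attn_heads * seq_len * cfg.head_dim * cfg.bytes_per_elem


def ssm_state_bytes(cfg: HybridConfig) -> int:
    # Assumed (head_dim, state_dim) recurrent state per SSM head: constant in seq_len.
    return cfg.ssm_heads * cfg.head_dim * cfg.state_dim * cfg.bytes_per_elem


def inference_state_bytes(cfg: HybridConfig, seq_len: int) -> int:
    # Per-sequence decoding state: attention cache plus recurrent state.
    return kv_cache_bytes(cfg, seq_len) + ssm_state_bytes(cfg)


if __name__ == "__main__":
    # Illustrative 512-head model: 10 retained attention heads (~2%), and the
    # remaining 502 recurrent heads with state_dim 64 versus an 8x smaller 8.
    large_state = HybridConfig(attn_heads=10, ssm_heads=502, head_dim=64, state_dim=64)
    small_state = HybridConfig(attn_heads=10, ssm_heads=502, head_dim=64, state_dim=8)
    for name, cfg in [("state_dim=64", large_state), ("state_dim=8", small_state)]:
        mib = inference_state_bytes(cfg, seq_len=4096) / 2**20
        print(f"{name}: {mib:.2f} MiB per sequence")
```

In this toy setting the KV-cache term already dominates at 4,096 tokens, which is why cutting the number of attention heads matters most; the 5–6× figure in the abstract comes from the paper's own comparison, not from this calculator.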

Paper Structure

This paper contains 33 sections, 11 equations, 3 figures, and 7 tables.

Figures (3)

  • Figure 1: Retrieval-Aware Attention Placement During Distillation. We add a retrieval-guided step before standard distillation: (1) ablate each attention head in the pretrained Transformer and measure the accuracy drop on a synthetic KV-retrieval probe to obtain a retrieval-importance score; (2) retain only heads whose score exceeds a threshold and replace the rest with recurrent heads; (3) distill into a hybrid student. We use KV-retrieval only for head ranking; it identifies the same G&A heads that drive performance on retrieval-intensive tasks [gather_and_aggregate]. Unlike heuristic hybrids, this yields non-uniform attention placement that preserves performance with far fewer heads and reduces the KV-cache bandwidth cost of attention [dao2022flashattention].
  • Figure 2: Transformer-SSM Integration. We selectively replace attention heads with SSMs. To ensure distributional alignment, we normalize the remaining concatenated head states to match the mean and variance of the SSM output before the final projection (Static LayerNorm in the figure). This parameter-free step stabilizes hybrid integration. Code is provided in the appendix (app:adapter).
  • Figure 3: Perplexity during distillation reveals retrieval as the primary bottleneck. Perplexity is tracked over the first 5,000 training steps for hybrid models with varying numbers of retained G&A heads. Most of the perplexity reduction comes from the first 10–20 heads (those ranked highest by retrieval importance), with diminishing returns beyond that. This confirms that retrieval is the dominant factor in early language-modeling improvements and supports our core claim: targeted retention of retrieval heads is both necessary and sufficient for closing the performance gap.
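Steps (1) and (2) of the Figure 1 pipeline amount to a leave-one-out ablation loop followed by thresholding. In this minimal sketch the probe is abstracted as a hypothetical callable `eval_accuracy(ablated_heads)` that returns KV-retrieval accuracy with the given heads disabled; how a head is disabled (for example, zeroing versus mean-ablating its output) and the threshold value are assumptions, not details stated in the caption. `toy_probe` is a made-up stand-in for running a real model.

```python
from typing import Callable, Dict, FrozenSet, List, Tuple

Head = Tuple[int, int]  # (layer index, head index)


def retrieval_importance(
    heads: List[Head],
    eval_accuracy: Callable[[FrozenSet[Head]], float],
) -> Dict[Head, float]:
    """Score each head by the probe-accuracy drop when that head alone is ablated."""
    baseline = eval_accuracy(frozenset())
    return {h: baseline - eval_accuracy(frozenset({h})) for h in heads}


def split_heads(
    scores: Dict[Head, float], threshold: float
) -> Tuple[List[Head], List[Head]]:
    """Heads scoring above `threshold` stay attention; the rest become recurrent."""
    keep = sorted((h for h, s in scores.items() if s > threshold),
                  key=lambda h: (-scores[h], h))  # most important first
    replace = sorted(h for h, s in scores.items() if s <= threshold)
    return keep, replace


# Hypothetical stand-in for a real KV-retrieval probe: a few "G&A-like" heads
# carry most of the retrieval accuracy; every other head contributes nothing.
_TOY_CRITICAL = {(5, 3): 0.40, (9, 12): 0.25, (11, 0): 0.05}


def toy_probe(ablated: FrozenSet[Head]) -> float:
    return 0.95 - sum(_TOY_CRITICAL.get(h, 0.0) for h in ablated)


if __name__ == "__main__":
    all_heads = [(layer, head) for layer in range(16) for head in range(32)]
    scores = retrieval_importance(all_heads, toy_probe)
    keep, replace = split_heads(scores, threshold=0.1)
    print("attention heads kept:", keep)
    print("heads replaced by SSMs:", len(replace))
```

Because `keep` is sorted by score, the same ranking can be used to sweep the number of retained heads, as in the Figure 3 experiment.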
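The parameter-free normalization described in the Figure 2 caption can be sketched in NumPy. This is one reading of the caption, not the appendix code: it assumes statistics are computed per token over the feature axis, that the concatenated outputs of the retained attention heads are standardized and then rescaled to the SSM output's mean and variance, and that a hypothetical output projection `w_o` acts on the concatenation of both head groups. The function names are illustrative.

```python
import numpy as np


def match_ssm_statistics(attn_out: np.ndarray, ssm_out: np.ndarray,
                         eps: float = 1e-5) -> np.ndarray:
    """Parameter-free rescaling of attention-head outputs to SSM output statistics.

    attn_out: (tokens, d_attn) concatenated outputs of the retained attention heads.
    ssm_out:  (tokens, d_ssm)  concatenated outputs of the recurrent (SSM) heads.
    Statistics are taken per token over the feature axis (an assumption).
    """
    mu_a = attn_out.mean(axis=-1, keepdims=True)
    var_a = attn_out.var(axis=-1, keepdims=True)
    mu_s = ssm_out.mean(axis=-1, keepdims=True)
    var_s = ssm_out.var(axis=-1, keepdims=True)
    standardized = (attn_out - mu_a) / np.sqrt(var_a + eps)  # zero mean, ~unit variance
    return standardized * np.sqrt(var_s + eps) + mu_s          # adopt SSM mean/variance


def hybrid_output(attn_out: np.ndarray, ssm_out: np.ndarray,
                  w_o: np.ndarray) -> np.ndarray:
    """Concatenate normalized attention heads with SSM heads, then project."""
    mixed = np.concatenate([match_ssm_statistics(attn_out, ssm_out), ssm_out], axis=-1)
    return mixed @ w_o


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    attn = rng.normal(3.0, 10.0, size=(4, 128))  # e.g. 2 attention heads x 64 dims
    ssm = rng.normal(0.0, 0.5, size=(4, 896))    # e.g. 14 SSM heads x 64 dims
    w_o = rng.normal(size=(128 + 896, 256)) / np.sqrt(128 + 896)
    out = hybrid_output(attn, ssm, w_o)
    print(out.shape)
```

Since there is no learned gain or bias, the step adds no parameters, matching the caption's description; whether the real adapter uses per-token or fixed statistics should be checked against the appendix code.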