Retrieval-Aware Distillation for Transformer-SSM Hybrids
Aviv Bick, Eric P. Xing, Albert Gu
TL;DR
The paper tackles the performance gap between Transformers and state-space models (SSMs) on tasks that require in-context retrieval. It introduces retrieval-aware distillation, which identifies a small set of retrieval-critical attention heads (Gather-and-Aggregate, or G&A, heads), keeps only those as attention during distillation, and replaces the rest with recurrent SSM heads, all trained via the MOHAWK framework. Empirically, retaining as few as 10 attention heads (about 2% of heads in a 1B model) recovers over 95% of the teacher's performance on retrieval-heavy tasks. Once retrieval is handled by these heads, the SSM state dimension can be cut by 8x with limited loss, and together these reductions make the hybrid 5–6x more memory-efficient than comparable hybrids. The results indicate that retrieval is localized in a small subset of heads, so keeping just those as attention lets the rest of the model be compact and recurrent while largely preserving retrieval capability. In practice, this means faster inference and lower memory footprints for long-sequence processing, with competitive performance on retrieval-centric benchmarks.
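The head-selection step can be illustrated with a minimal sketch, assuming heads are scored by one-at-a-time ablation: mask a single head, measure the teacher's accuracy on a synthetic retrieval task, and rank heads by how much accuracy drops. The top-k heads stay as attention and the rest are marked for distillation into recurrent heads. The functions `rank_heads_by_ablation` and `partition_heads`, the `toy_evaluate` scorer, and the `IMPORTANCE` values are all hypothetical stand-ins; the paper's exact masking and scoring protocol may differ.

```python
from typing import Callable, Dict, FrozenSet, List, Tuple

Head = Tuple[int, int]  # (layer index, head index within the layer)


def rank_heads_by_ablation(
    heads: List[Head],
    evaluate: Callable[[FrozenSet[Head]], float],
) -> List[Tuple[Head, float]]:
    """Ablate one head at a time and record the drop in retrieval accuracy."""
    baseline = evaluate(frozenset())
    drops = [(head, baseline - evaluate(frozenset({head}))) for head in heads]
    # Largest drop first: these heads matter most for retrieval.
    return sorted(drops, key=lambda item: item[1], reverse=True)


def partition_heads(
    heads: List[Head],
    evaluate: Callable[[FrozenSet[Head]], float],
    k: int,
) -> Tuple[List[Head], List[Head]]:
    """Keep the top-k heads as attention; mark the rest for SSM distillation."""
    ranked = rank_heads_by_ablation(heads, evaluate)
    keep = [head for head, _ in ranked[:k]]
    keep_set = set(keep)
    recurrent = [head for head in heads if head not in keep_set]
    return keep, recurrent


# Hypothetical stand-in for "teacher accuracy on a synthetic retrieval task
# with these heads masked out". Only three heads carry retrieval here.
IMPORTANCE: Dict[Head, float] = {(1, 3): 0.4, (2, 5): 0.3, (3, 0): 0.1}
ALL_HEADS: List[Head] = [(layer, h) for layer in range(4) for h in range(8)]


def toy_evaluate(ablated: FrozenSet[Head]) -> float:
    return max(0.0, 1.0 - sum(IMPORTANCE.get(head, 0.0) for head in ablated))


attention_heads, recurrent_heads = partition_heads(ALL_HEADS, toy_evaluate, k=2)
print("keep as attention:", attention_heads)
print("distill to SSM:", len(recurrent_heads), "heads")
```

Single-head ablation costs one evaluation per head, which is cheap, but it can underrate heads that back each other up; ablating heads jointly or greedily would catch such redundancy at a higher evaluation cost.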
Abstract
State-space models (SSMs) offer efficient sequence modeling but lag behind Transformers on benchmarks that require in-context retrieval. Prior work links this gap to a small set of attention heads, termed Gather-and-Aggregate (G&A), which SSMs struggle to reproduce. We propose *retrieval-aware distillation*, which converts a pretrained Transformer into a hybrid student by preserving only these retrieval-critical heads and distilling the rest into recurrent heads. We identify the essential heads via ablation on a synthetic retrieval task, producing a hybrid with sparse, non-uniform attention placement. We show that preserving **just 2% of attention heads recovers over 95% of teacher performance on retrieval-heavy tasks** (10 heads in a 1B model), requiring far fewer heads than hybrids that retain at least 25%. We further find that large recurrent states often compensate for missing retrieval: once retrieval is handled by these heads, the SSM backbone can be simplified with limited loss, even with an $8\times$ reduction in state dimension. By reducing both the attention cache and the SSM state, the resulting hybrid is $5$--$6\times$ more memory-efficient than comparable hybrids, closing the Transformer--SSM gap at a fraction of the memory cost.
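The memory argument can be made concrete with a rough accounting sketch, assuming each attention head caches keys and values of shape `[seq_len, head_dim]` and each recurrent head keeps a fixed Mamba-2-style state of shape `[head_dim, state_dim]`, all stored in 16-bit precision. `HybridConfig` and the helper functions are hypothetical, and the 512-head layout is an illustrative assumption rather than the paper's model; the sketch shows where savings come from and does not attempt to reproduce the reported 5–6× figure.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class HybridConfig:
    """Hypothetical hybrid layout; head counts are summed over all layers."""
    n_attn_heads: int        # attention heads kept from the teacher
    n_ssm_heads: int         # heads replaced by recurrent (SSM) heads
    head_dim: int            # channels per head
    state_dim: int           # SSM state size per channel (Mamba-2-style, assumed)
    bytes_per_elem: int = 2  # bf16/fp16 storage


def kv_cache_bytes(cfg: HybridConfig, seq_len: int) -> int:
    # Each attention head caches keys and values: 2 x [seq_len, head_dim].
    return 2 * cfg.n_attn_heads * seq_len * cfg.head_dim * cfg.bytes_per_elem


def ssm_state_bytes(cfg: HybridConfig) -> int:
    # Each recurrent head keeps a fixed [head_dim, state_dim] state,
    # independent of sequence length.
    return cfg.n_ssm_heads * cfg.head_dim * cfg.state_dim * cfg.bytes_per_elem


def inference_state_bytes(cfg: HybridConfig, seq_len: int) -> int:
    return kv_cache_bytes(cfg, seq_len) + ssm_state_bytes(cfg)


# Hypothetical 512-head model: 10 attention heads kept (~2%), the rest recurrent.
wide = HybridConfig(n_attn_heads=10, n_ssm_heads=502, head_dim=64, state_dim=128)
narrow = replace(wide, state_dim=wide.state_dim // 8)  # 8x smaller SSM state

for seq_len in (4_096, 32_768):
    for name, cfg in (("state_dim=128", wide), ("state_dim=16", narrow)):
        total_mib = inference_state_bytes(cfg, seq_len) / 2**20
        print(f"{name:>14} seq_len={seq_len:>6}: {total_mib:8.2f} MiB")
```

The two terms respond to different levers: the KV cache grows linearly with sequence length and with the number of retained attention heads, so keeping only a handful of heads bounds it, while the SSM state is constant in sequence length and shrinks in direct proportion to `state_dim`.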
