Learning to Forget Attention: Memory Consolidation for Adaptive Compute Reduction
Ibne Farabi Shihab, Sanjeda Akter, Anuj Sharma
TL;DR
This work identifies pervasive attention redundancy in large sequence models and shows that standard training provides no signal to reduce attention usage. It introduces CRAM, a three-tier memory architecture with a consolidation-based router that shifts retrieval from episodic attention to semantic recall as semantic memory learns to approximate episodic results. CRAM achieves a $37.8\times$ reduction in attention compute and reaches 100% retrieval accuracy on SRCD while using only $1.6\%$ attention. Consolidated patterns transfer to unseen tasks (48–52% attention reduction), and the consolidation curve follows a power law ($\gamma \approx 0.43$) comparable to human memory data. Together, these results suggest that memory consolidation is a principled route to adaptive computation in neural sequence models, with practical value for efficient deployment and with cognitively plausible dynamics.
Abstract
Hybrid architectures combining state-space models with attention have achieved strong efficiency-quality tradeoffs, yet existing approaches either apply attention uniformly or learn static sparse patterns. This misses a key opportunity: \emph{attention demand should decrease over time as recurring patterns become familiar}. We present a surprising finding from analyzing GPT-2 models: \textbf{88\%} of attention operations retrieve information already predictable from the model's hidden state, and this redundancy does \emph{not} decrease during training. Motivated by this observation, we introduce \textbf{\ours{}} (\textbf{C}onsolidation-based \textbf{R}outing for \textbf{A}daptive \textbf{M}emory), a biologically inspired memory consolidation mechanism that gradually distills episodic retrievals into parametric semantic memory. Unlike prior sparse attention methods, \ours{} exhibits \emph{decreasing attention utilization} over training, achieving a \textbf{37.8$\times$} reduction through a sharp phase transition at approximately 3K steps. We prove that this capability is \emph{impossible} without consolidation: any static routing scheme requires $\Omega(f \cdot n)$ attention for tasks with recurring patterns of frequency $f$. On our proposed SRCD benchmark, \ours{} achieves \textbf{100\% retrieval accuracy} at 1.6\% attention compute (vs.\ 68\% for baselines), and consolidated patterns transfer to unseen tasks with \textbf{48--52\%} attention reduction without retraining. Remarkably, the learned consolidation dynamics quantitatively match human episodic-to-semantic memory transition curves from cognitive psychology ($\gamma = 0.43$ vs.\ $\gamma_{\text{human}} \approx 0.4$--$0.5$). Code and benchmarks are available at [anonymized].
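To make the core idea concrete, the following toy sketch illustrates why attention demand can fall as recurring patterns become familiar. It is not the paper's implementation: the function name `run_toy_consolidation`, the exact-match agreement rule, and the `agree_needed` threshold are all hypothetical simplifications, with a dictionary standing in for parametric semantic memory and a lookup table standing in for episodic attention.

```python
import random

def run_toy_consolidation(num_steps=2000, num_patterns=20, agree_needed=3, seed=0):
    """Toy router (illustrative assumption, not the paper's method).

    A pattern is served by expensive 'episodic' retrieval until the cheap
    'semantic' cache has matched the episodic result `agree_needed` times
    in a row; after that, the pattern is served from the cache alone.
    Returns a list with 1 for steps that used episodic retrieval, else 0.
    """
    rng = random.Random(seed)
    # Stand-in for what episodic attention would retrieve for each pattern.
    episodic_result = {p: rng.random() for p in range(num_patterns)}
    semantic = {}    # pattern -> consolidated value
    agreement = {}   # pattern -> consecutive semantic/episodic matches
    attention_calls = []

    for _ in range(num_steps):
        p = rng.randrange(num_patterns)  # recurring patterns
        if agreement.get(p, 0) >= agree_needed:
            _ = semantic[p]              # cheap semantic recall, no attention
            attention_calls.append(0)
            continue
        value = episodic_result[p]       # expensive episodic retrieval
        attention_calls.append(1)
        if p in semantic and abs(semantic[p] - value) < 1e-9:
            agreement[p] = agreement.get(p, 0) + 1
        else:
            agreement[p] = 0
        semantic[p] = value              # consolidate the retrieved value
    return attention_calls

calls = run_toy_consolidation()
early = sum(calls[:100]) / 100
late = sum(calls[-100:]) / 100
print(f"attention fraction: first 100 steps = {early:.2f}, last 100 steps = {late:.2f}")
```

In this sketch each pattern needs at most `agree_needed + 1` episodic retrievals before it is consolidated, so total attention use is bounded by the number of distinct patterns rather than by sequence length. A static router, by contrast, has no state that lets it stop attending to a pattern it has already seen.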
