Learning to Forget Attention: Memory Consolidation for Adaptive Compute Reduction

Ibne Farabi Shihab, Sanjeda Akter, Anuj Sharma

TL;DR

This work identifies pervasive attention redundancy in large sequence models and shows that standard training provides no signal to reduce attention usage. It introduces CRAM, a three-tier memory architecture with a consolidation-based router that shifts retrieval from episodic attention to semantic recall as semantic memory learns to approximate episodic results. CRAM achieves a $37.8\times$ reduction in attention compute and reaches 100% retrieval accuracy on SRCD at $1.6\%$ attention compute. Consolidated patterns transfer to unseen tasks (48–52% attention reduction), and the learned dynamics follow a power-law consolidation curve ($\gamma \approx 0.43$) comparable to human memory data. Overall, CRAM demonstrates an adaptive computation paradigm in which memory consolidation reduces attention use as patterns become familiar, with practical impact for efficient deployment and broader cognitive plausibility.

Abstract

Hybrid architectures combining state-space models with attention have achieved strong efficiency-quality tradeoffs, yet existing approaches either apply attention uniformly or learn static sparse patterns. This misses a key opportunity: attention demand should decrease over time as recurring patterns become familiar. We present a surprising finding from analyzing GPT-2 models: 88% of attention operations retrieve information already predictable from the model's hidden state, and this redundancy does not decrease during training. Motivated by this observation, we introduce CRAM (Consolidation-based Routing for Adaptive Memory), a biologically inspired memory consolidation mechanism that gradually distills episodic retrievals into parametric semantic memory. Unlike prior sparse attention methods, CRAM exhibits decreasing attention utilization over training, achieving a $37.8\times$ reduction through a sharp phase transition at approximately 3K steps. We prove that this capability is impossible without consolidation: any static routing scheme requires $\Omega(f \cdot n)$ attention for tasks with recurring patterns of frequency $f$. On our proposed SRCD benchmark, CRAM achieves 100% retrieval accuracy at 1.6% attention compute (vs. 68% for baselines), and consolidated patterns transfer to unseen tasks with 48–52% attention reduction without retraining. Remarkably, the learned consolidation dynamics quantitatively match human episodic-to-semantic memory transition curves from cognitive psychology ($\gamma = 0.43$ vs. $\gamma_{\text{human}} \approx 0.4$–$0.5$). Code and benchmarks are available at [anonymized].
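
To make the power-law comparison at the end of the abstract concrete, the following minimal sketch assumes that the probability of episodic retrieval after $r$ repetitions of a pattern has the form $r^{-\gamma}$, with unit scale at $r = 1$. That functional form, the function names, and the repetition grid are illustrative assumptions, not the authors' fitting procedure. The sketch generates synthetic points at $\gamma = 0.43$ and recovers the exponent with a log-log least-squares fit.

```python
import math

def episodic_probability(repetitions, gamma=0.43, scale=1.0):
    """Assumed power law: P(episodic) = scale * repetitions**(-gamma), capped at 1."""
    return min(1.0, scale * repetitions ** (-gamma))

def fit_exponent(reps, probs):
    """Estimate gamma as the negated least-squares slope of log(prob) against log(rep)."""
    xs = [math.log(r) for r in reps]
    ys = [math.log(p) for p in probs]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    slope = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / sum(
        (x - mean_x) ** 2 for x in xs
    )
    return -slope

# Hypothetical repetition counts; the paper's actual measurement grid is not given here.
reps = [1, 2, 4, 8, 16, 32, 64]
probs = [episodic_probability(r) for r in reps]
gamma_hat = fit_exponent(reps, probs)
print(f"recovered exponent: {gamma_hat:.3f}")
```

On noiseless synthetic data the fit returns the exponent used to generate it; on measured routing probabilities the same regression would give the empirical exponent to compare against the human range quoted above.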

Paper Structure (36 sections, 4 theorems, 18 equations, 4 figures, 7 tables)

Key Result

Theorem 1

Consider a task where a fraction $f$ of positions require correct retrieval from a set of $K$ recurring patterns, each appearing with frequency $f/K$, and correct retrieval is necessary for task success. Then any static routing scheme achieving task accuracy $\geq 1 - \epsilon$ must use $\Omega(f \cdot n)$ attention in expectation, where $n$ is the sequence length.
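
As a rough numerical illustration of the $\Omega(f \cdot n)$ scaling stated in the abstract, the sketch below sets the hidden constant to 1 and compares the resulting floor with a hypothetical consolidation budget that depends only on the number of patterns. Both function names, the constant `c`, and the per-pattern visit count are assumptions made for illustration; they are not quantities taken from the theorem or its corollary.

```python
def static_attention_floor(f, n, c=1.0):
    """Order-of-magnitude floor c * f * n for a static router.

    c stands in for the constant hidden by Omega(.) and is set to 1 as an assumption.
    """
    return c * f * n

def consolidated_attention_budget(num_patterns, visits_to_consolidate):
    """Hypothetical budget if each pattern needs attention only until it is consolidated."""
    return num_patterns * visits_to_consolidate

f = 0.25  # assumed fraction of positions that require retrieval
for n in (1024, 4096, 16384):
    floor = static_attention_floor(f, n)
    budget = consolidated_attention_budget(num_patterns=8, visits_to_consolidate=20)
    print(f"n={n:6d}  static floor ~ {floor:7.0f}  consolidated budget ~ {budget}")
```

The static floor grows linearly with the sequence length, while the hypothetical consolidated budget stays fixed once every recurring pattern has been learned, which is the intuition behind attention use that falls over time.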

Figures (4)

  • Figure 1: The CRAM architecture. Each layer routes tokens through a consolidation-aware router to one of three memory tiers: (i) a continuous-time working memory for local dynamics, (ii) an episodic memory buffer accessed via attention for novel retrieval, and (iii) a semantic memory adapter for consolidated patterns. The consolidation loss (red, dashed) trains semantic memory to approximate episodic retrieval; the quality signal $q_t$ (purple, dashed) feeds back to the router. As $q_t$ increases during training, the router shifts from episodic ($O(n)$) to semantic ($O(1)$) routing, producing a 37.8$\times$ reduction in attention compute.
  • Figure 2: Phase transition in consolidation. CRAM's attention usage remains moderate until approximately 3K steps, then drops sharply as semantic memory begins accurately approximating episodic retrieval. This emergence phenomenon mirrors grokking in neural networks. Prior methods show no such transition because their compute allocation is static by design.
  • Figure 3: Consolidation exhibits a sharp phase transition. Before approximately 3K steps, semantic memory is still learning and the router uses moderate episodic retrieval. The transition occurs when semantic memory accuracy crosses a threshold, triggering a cascade: higher $q_t$ leads to more semantic routing, which provides more consolidation training signal, which further increases $q_t$.
  • Figure 4: Learned consolidation matches human memory dynamics. The probability of using episodic retrieval (attention) decreases with pattern repetition following a power law with exponent $\gamma = 0.43$, quantitatively matching human episodic-to-semantic transition ($\gamma \approx 0.4$--$0.5$).
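
The routing feedback loop described in the captions of Figures 1 and 3 can be sketched as a toy simulation. This is a simplified illustration, not the CRAM implementation: patterns retrieve scalar values, episodic and semantic memory are plain dictionaries, the quality signal is tracked per pattern, the router uses a hard threshold on it, distillation happens only on episodic visits, and the working-memory tier is omitted. All names and hyperparameters are hypothetical.

```python
import random

def simulate_consolidation(num_steps=2000, num_patterns=8, lr=0.1, threshold=0.9, seed=0):
    """Return a 0/1 trace of episodic (attention) calls for a toy consolidating router."""
    rng = random.Random(seed)
    # Value each recurring pattern should retrieve (stand-in for the attention output).
    episodic = {k: rng.uniform(-1.0, 1.0) for k in range(num_patterns)}
    # Parametric approximation of those values; starts uninformed.
    semantic = {k: 0.0 for k in range(num_patterns)}
    # Per-pattern quality signal in [0, 1]; plays the role of q_t in this toy.
    quality = {k: 0.0 for k in range(num_patterns)}
    used_attention = []
    for _ in range(num_steps):
        k = rng.randrange(num_patterns)
        if quality[k] < threshold:
            # Episodic route: pay for attention, then distil the result into semantic memory.
            target = episodic[k]
            semantic[k] += lr * (target - semantic[k])  # gradient step on 0.5 * (s - target)**2
            quality[k] = 1.0 - min(1.0, abs(semantic[k] - target))
            used_attention.append(1)
        else:
            # Semantic route: constant-time recall, no attention call.
            used_attention.append(0)
    return used_attention

trace = simulate_consolidation()
early, late = sum(trace[:200]), sum(trace[-200:])
print(f"episodic calls: first 200 steps = {early}, last 200 steps = {late}")
```

In this toy the episodic calls are concentrated in the early steps and stop once every pattern's quality signal crosses the threshold. The hard threshold makes the drop abrupt by construction, so the sketch illustrates the direction of the feedback loop rather than reproducing the phase transition reported in the figures.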

Theorems & Definitions (8)

  • Definition 1: Attention Redundancy
  • Definition 2: Static Routing Scheme
  • Theorem 1: Lower Bound for Static Routing
  • proof
  • Corollary 1: Consolidation Enables Sub-Linear Attention
  • Theorem 2: Consolidation Convergence
  • Theorem 3: Attention Reduction Guarantee
  • proof: Full proof of Theorem 1 (Lower Bound for Static Routing)