Learning What to Remember: Adaptive Probabilistic Memory Retention for Memory-Efficient Language Models

S M Rafiuddin, Muntaha Nujat Khan

TL;DR

This work tackles the memory bottleneck of Transformer long-context processing by introducing Adaptive Retention, a layer-wise probabilistic token selection mechanism that operates under a global memory budget $M$ without changing core attention. Tokens are retained via Bernoulli gates with a context-aware scoring function and trained through a variational Hard-Concrete relaxation, while inference enforces a top-$M$ rule to bound active tokens. Across six benchmarks, Adaptive Retention nearly matches full-sequence performance at 30-50% token retention, achieving substantial memory savings (~35-45%) and throughput gains (up to $1.8\times$) while remaining architecture-agnostic. The method demonstrates practical long-context efficiency for encoder-based tasks and lays groundwork for future extensions to autoregressive decoding and larger-scale models.

Abstract

Transformer attention scales quadratically with sequence length O(n^2), limiting long-context use. We propose Adaptive Retention, a probabilistic, layer-wise token selection mechanism that learns which representations to keep under a strict global budget M. Retention is modeled with Bernoulli gates trained via a Hard-Concrete/variational relaxation and enforced with a simple top-M rule at inference, making the method differentiable and drop-in for standard encoders. Across classification, extractive QA, and long-document summarization, keeping only 30-50% of tokens preserves >= 95% of full-model performance while cutting peak memory by ~35-45% and improving throughput by up to ~1.8x. This architecture-agnostic approach delivers practical long-context efficiency without modifying base attention or task heads.
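The gating idea in the abstract can be sketched in a few lines. The snippet below is a minimal illustration, not the authors' implementation: it assumes the standard Hard-Concrete parameterization (logistic noise, temperature `beta`, stretch limits `gamma` and `zeta`), and the function names `hard_concrete_sample` and `prob_retained` as well as the per-token logit `log_alpha` are hypothetical. The default values are the ones reported as defaults in the hyperparameter sensitivity figure.

```python
import math
import random


def hard_concrete_sample(log_alpha, beta=0.66, gamma=-0.1, zeta=1.1, rng=random):
    """Draw one relaxed gate z in [0, 1] for a token with logit log_alpha.

    Assumption: standard Hard-Concrete sampling; the paper's exact scorer
    and parameterization are not shown in this excerpt.
    """
    # Clip the uniform noise away from 0 and 1 so the logs stay finite.
    u = min(max(rng.random(), 1e-6), 1.0 - 1e-6)
    # Binary-Concrete sample: sigmoid of (logistic noise + logit) / temperature.
    pre = (math.log(u) - math.log(1.0 - u) + log_alpha) / beta
    s = 1.0 / (1.0 + math.exp(-pre))
    # Stretch to (gamma, zeta), then clamp so exact 0 and 1 have nonzero mass.
    s_stretched = s * (zeta - gamma) + gamma
    return min(1.0, max(0.0, s_stretched))


def prob_retained(log_alpha, beta=0.66, gamma=-0.1, zeta=1.1):
    """Closed-form P(z > 0), usable as a differentiable retention probability."""
    shift = beta * math.log(-gamma / zeta)
    return 1.0 / (1.0 + math.exp(-(log_alpha - shift)))


if __name__ == "__main__":
    rng = random.Random(0)
    gates = [hard_concrete_sample(0.5, rng=rng) for _ in range(5)]
    print(gates, prob_retained(0.5))
```

During training the relaxed gate would multiply the token representation so gradients reach the scorer; at inference the sampling is dropped in favor of the deterministic top-M rule described above.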

Paper Structure

This paper contains 15 sections, 7 theorems, 44 equations, 2 figures, 6 tables.

Key Result

Proposition 1

Let $(\boldsymbol{\theta}_\lambda,\mathbf{p}_\lambda)$ be any minimizer of the Lagrangian, and define the slack. Then the following slackness bound holds:

Figures (2)

  • Figure 1: Adaptive Retention: layer-wise probabilistic token selection. At each Transformer block, a lightweight gated scorer produces per-token probabilities trained with a Hard–Concrete relaxation. At inference, we keep the top-$M_l$ tokens per layer ($M_l=\lfloor \rho T_l \rfloor$), forwarding only those to the next block. The active sequence length shrinks with depth, yielding cumulative compute and memory savings while leaving base attention unchanged. Symbols: $H^{l}$ token states; $s^{l}$ scores; $p^{l}$ probabilities; $\rho$ target ratio; $M_l$ retained count.
  • Figure 2: Hyperparameter sensitivity of the Adaptive Retention model across six tasks (SST-2, IMDb, ArXiv, QASPER F1, PubMed R-1, CUAD F1). Each panel shows validation performance under sweeps of three parameters: retention temperature $\beta$ ($\circ$, blue), stretch $\gamma$ ($\bullet$, red), and threshold $\zeta$ ($\times$, brown), with $\star$ marking defaults ($\beta=0.66$, $\gamma=-0.1$, $\zeta=1.1$).
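The per-layer inference rule in the Figure 1 caption, $M_l=\lfloor \rho T_l \rfloor$, can be illustrated with a short sketch. This is an assumption-laden example rather than the paper's code: it treats $T_l$ as the number of tokens entering layer $l$, keeps retained tokens in their original order, and uses hypothetical helper names `retain_top_m` and `active_lengths`.

```python
import math


def retain_top_m(tokens, probs, rho):
    """Keep the floor(rho * T) highest-probability tokens.

    Assumption: retained tokens stay in their original sequence order.
    """
    T = len(tokens)
    M = math.floor(rho * T)
    # Rank indices by retention probability, take the top M, restore order.
    ranked = sorted(range(T), key=lambda i: probs[i], reverse=True)
    keep = sorted(ranked[:M])
    return [tokens[i] for i in keep]


def active_lengths(T0, rho, num_layers):
    """Sequence length entering each layer when M_l = floor(rho * T_l)."""
    lengths = [T0]
    for _ in range(num_layers):
        lengths.append(math.floor(rho * lengths[-1]))
    return lengths


if __name__ == "__main__":
    print(retain_top_m(["a", "b", "c", "d"], [0.9, 0.1, 0.8, 0.3], 0.5))
    print(active_lengths(512, 0.5, 3))
```

Because each layer forwards only its retained tokens, the active length decays geometrically with depth under these assumptions, which is where the cumulative compute and memory savings mentioned in the caption would come from.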

Theorems & Definitions (14)

  • Proposition 1
  • proof
  • Lemma 1: Unbiased Gradient Estimator
  • proof
  • Lemma 2: Variance Bound
  • proof
  • Proposition 2: Two‐Timescale Convergence
  • proof: Proof Sketch
  • Lemma 3: Duality Gap Bounds Slackness
  • proof
  • ...and 4 more