Learning What to Remember: Adaptive Probabilistic Memory Retention for Memory-Efficient Language Models
S M Rafiuddin, Muntaha Nujat Khan
TL;DR
This work tackles the memory bottleneck of long-context Transformer processing by introducing Adaptive Retention, a layer-wise probabilistic token selection mechanism that operates under a global memory budget $M$ without changing core attention. Each token is kept or dropped by a Bernoulli gate whose probability comes from a context-aware scoring function. The gates are trained through a variational Hard-Concrete relaxation, and at inference a top-$M$ rule bounds the number of active tokens. Across six benchmarks, Adaptive Retention nearly matches full-sequence performance while retaining only 30-50% of tokens. It reduces memory by roughly 35-45% and increases throughput by up to $1.8\times$, and it remains architecture-agnostic. The method demonstrates practical long-context efficiency for encoder-based tasks and lays the groundwork for extensions to autoregressive decoding and larger-scale models.
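The Hard-Concrete relaxation is what lets discrete keep/drop gates be trained by gradient descent. The plain-Python sketch below follows the standard Hard-Concrete distribution of Louizos et al. (2018). It samples a stretched, clipped relaxed gate for each token and computes the closed-form probability that a gate is nonzero. Several details are illustrative assumptions, not the paper's formulation: the hyperparameter values, the hypothetical scores, the budget `M = 2`, and the hinge-style `budget_penalty`. A training implementation would also use an autodiff framework rather than plain floats.

```python
import math
import random

# Hard-Concrete hyperparameters: the stretch interval (GAMMA, ZETA) and temperature
# BETA are common defaults from Louizos et al. (2018), assumed here for
# illustration; this section does not state the values the paper uses.
GAMMA, ZETA, BETA = -0.1, 1.1, 2.0 / 3.0


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def hard_concrete_sample(log_alpha: float, rng: random.Random) -> float:
    """Sample a relaxed gate z in [0, 1] for one token with score log_alpha."""
    u = rng.uniform(1e-6, 1.0 - 1e-6)                 # keep u away from 0 and 1
    s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / BETA)
    s_bar = s * (ZETA - GAMMA) + GAMMA                # stretch (0, 1) to (GAMMA, ZETA)
    return min(1.0, max(0.0, s_bar))                  # clip: exact 0 or 1 has nonzero mass


def prob_gate_open(log_alpha: float) -> float:
    """Closed-form P(z > 0); smooth in log_alpha, so it can drive a budget term."""
    return sigmoid(log_alpha - BETA * math.log(-GAMMA / ZETA))


# Usage with hypothetical scores standing in for the context-aware scoring function.
rng = random.Random(0)
scores = [2.0, -1.5, 0.3, -3.0, 1.2]
gates = [hard_concrete_sample(a, rng) for a in scores]
expected_kept = sum(prob_gate_open(a) for a in scores)
M = 2                                                 # hypothetical token budget
budget_penalty = max(0.0, expected_kept - M)          # hypothetical hinge on expected count
print(gates, round(expected_kept, 3), round(budget_penalty, 3))
```

Clipping the stretched sample is the key design choice. Gates can land exactly on 0 or 1, matching the discrete keep/drop decision, while `prob_gate_open` still gives a smooth signal for steering the expected number of kept tokens toward the budget.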
Abstract
Transformer self-attention scales quadratically with sequence length $n$, i.e., $O(n^2)$, which limits long-context use. We propose Adaptive Retention, a probabilistic, layer-wise token selection mechanism that learns which representations to keep under a strict global budget $M$. Retention is modeled with Bernoulli gates trained via a variational Hard-Concrete relaxation, which keeps training differentiable. At inference it is enforced with a simple top-$M$ rule, so the method drops into standard encoders without modification. Across classification, extractive QA, and long-document summarization, keeping only 30-50% of tokens preserves at least 95% of full-model performance. It also cuts peak memory by roughly 35-45% and improves throughput by up to $1.8\times$. This architecture-agnostic approach delivers practical long-context efficiency without modifying base attention or task heads.
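The top-$M$ inference rule is simple enough to state directly. The sketch below assumes each token in a layer already has a retention score (for example, a gate logit). It keeps the $M$ highest-scoring tokens and preserves their original order. Three things here are assumptions: the names `top_m_keep` and `retained_score_fraction` are hypothetical, re-sorting kept tokens by position is assumed, and the rule is applied to a single layer, since this section does not say how the global budget is distributed across layers. The second helper shows why retention pays off quadratically. Keeping a fraction $r$ of tokens leaves $r^2$ of the attention-score entries, so 30-50% retention leaves 9-25% of them.

```python
import heapq
from typing import List, Sequence


def top_m_keep(scores: Sequence[float], m: int) -> List[int]:
    """Return indices of the m highest-scoring tokens, in original sequence order."""
    n = len(scores)
    if m <= 0:
        return []
    if m >= n:
        return list(range(n))                          # budget covers every token
    top = heapq.nlargest(m, range(n), key=lambda i: scores[i])
    return sorted(top)                                 # preserve relative token order


def retained_score_fraction(n: int, m: int) -> float:
    """Fraction of the n x n attention-score matrix left after keeping m tokens."""
    return (min(m, n) / n) ** 2


# Usage: hypothetical per-token scores for one layer and a budget of M = 2.
scores = [0.1, 0.9, 0.4, 0.8, 0.2]
keep = top_m_keep(scores, 2)
hidden = ["h0", "h1", "h2", "h3", "h4"]                # placeholder hidden states
kept_hidden = [hidden[i] for i in keep]
print(keep, kept_hidden, retained_score_fraction(10, 5))  # → [1, 3] ['h1', 'h3'] 0.25
```

Peak memory also includes terms that do not scale with $n^2$, such as model weights. End-to-end savings are therefore smaller than the $r^2$ factor alone would suggest.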
