GradMem: Learning to Write Context into Memory with Test-Time Gradient Descent

Yuri Kuratov, Matvey Kairov, Aydar Bulatov, Ivan Rodkin, Mikhail Burtsev

Abstract

Many large language model applications require conditioning on long contexts. Transformers typically support this by storing a large per-layer KV-cache of past activations, which incurs substantial memory overhead. A desirable alternative is compressive memory: read a context once, store it in a compact state, and answer many queries from that state. We study this in a context removal setting, where the model must generate an answer without access to the original context at inference time. We introduce GradMem, which writes context into memory via per-sample test-time optimization. Given a context, GradMem performs a few steps of gradient descent on a small set of prefix memory tokens while keeping model weights frozen. Unlike forward-only methods, GradMem explicitly optimizes a model-level self-supervised context reconstruction loss, which yields a loss-driven write operation with iterative error correction. On associative key-value retrieval, GradMem outperforms forward-only memory writers with the same memory size, and additional gradient steps scale capacity much more effectively than repeated forward writes. We further show that GradMem transfers beyond synthetic benchmarks: with pretrained language models, it attains competitive results on natural language tasks, including bAbI and SQuAD variants, relying only on information encoded in memory.
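As a rough illustration of the WRITE step, the sketch below runs a few gradient-descent updates on a memory vector while the "model" stays frozen. It is a toy under stated assumptions, not the authors' implementation: a fixed linear map `W` stands in for the frozen language model, the memory is a single vector rather than prefix memory tokens, and the reconstruction loss is squared error rather than a language-modeling loss. All names (`write`, `write_loss`, `matvec_t`, `W`) are hypothetical.

```python
# Toy sketch of a GradMem-style WRITE: test-time gradient descent on memory only.
# Assumptions (not from the paper): a frozen linear map W stands in for the frozen
# model, memory is one vector, and the write loss is squared reconstruction error.

def matvec(W, x):
    """Compute W @ x for a list-of-rows matrix W."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]

def matvec_t(W, v):
    """Compute W^T @ v."""
    return [sum(W[i][j] * v[i] for i in range(len(W))) for j in range(len(W[0]))]

def write_loss(W, M, C):
    """0.5 * ||W M - C||^2: how well the memory M reconstructs the context C."""
    residual = [p - c for p, c in zip(matvec(W, M), C)]
    return 0.5 * sum(e * e for e in residual)

def write(W, M0, C, steps, lr):
    """Run `steps` gradient-descent updates on the memory, starting from M0.

    Returns the written memory and the loss after each step; W is never changed.
    """
    M = list(M0)
    losses = [write_loss(W, M, C)]
    for _ in range(steps):
        residual = [p - c for p, c in zip(matvec(W, M), C)]
        grad = matvec_t(W, residual)               # dL/dM = W^T (W M - C)
        M = [m - lr * g for m, g in zip(M, grad)]
        losses.append(write_loss(W, M, C))
    return M, losses

W = [[1.0, 0.0],      # frozen "model": maps a 2-dim memory
     [0.0, 1.0],      # to a 3-dim reconstruction of the context
     [1.0, 1.0]]
C = [2.0, -1.0, 1.0]  # context to write; exactly reconstructible by M = [2, -1]
M0 = [0.0, 0.0]       # memory initialization (meta-learned in GradMem)

M_hat, losses = write(W, M0, C, steps=20, lr=0.2)
print([round(m, 3) for m in M_hat], round(losses[-1], 6))
```

Each update is computed from the residual of the current memory, so errors left by earlier steps are corrected by later ones; in this toy, every extra step lowers the write loss. A forward-only writer has no such feedback loop: it produces the memory in a single pass.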

Paper Structure (16 sections, 13 equations, 8 figures, 4 tables)

Figures (8)

  • Figure 1: GradMem learns to write context into memory via per-sample test-time optimization. Given a context, GradMem performs a few test-time gradient updates on the memory state to minimize a self-supervised reconstruction loss (WRITE). The memory initialization is meta-learned so that useful context representations can be written with only a few gradient steps. At inference, the model answers queries using only the optimized memory and the query (READ), without access to the original context.
  • Figure 2: GradMem overview. (a) Each task sample is represented as a context $C$, a query $Q$, and a target $Y$. A context encoder $E_{\theta}$ compresses $C$ into a fixed-size memory $\hat{\mathcal{M}}$ in a WRITE phase, and the model predicts $Y$ from $[\hat{\mathcal{M}}; Q]$ in a READ phase, without access to $C$. (b) Meta-learning view: a shared initialization $\mathcal{M}_0$ and model parameters $\theta$ are learned across training examples (outer loop), while at test time each context $C_i$ adapts its own memory $\hat{\mathcal{M}}(C_i)$ via a few gradient steps (dotted trajectory). (c) Test-time gradient descent on memory. Starting from the meta-learned initialization $\mathcal{M}_0$, GradMem updates the per-sample memory state during WRITE with $K$ steps of gradient descent on the context reconstruction loss $\mathcal{L}_{\text{write}}(\mathcal{M}; C)$. At READ, the model predicts the task target using only $[\hat{\mathcal{M}}; Q]$. During training, $\theta$ and $\mathcal{M}_0$ are optimized by backpropagating through the WRITE inner loop, so the model learns to use gradient descent as an operation that writes useful information about $C$ into memory; at inference, only the per-sample memory state is updated.
  • Figure 3: Gradient-based memory updates (GradMem) outperform forward-only updates at the same memory size. With a memory state of 8 vectors, GradMem retrieves any of 16 key-value pairs with 95% accuracy, whereas a forward-only update rule stores only 8 pairs with high accuracy. More gradient steps increase the capacity of the same 8-vector memory to 96 pairs at 88% retrieval accuracy. A Transformer with a full KV-cache (non-compressive memory) serves as an upper bound.
  • Figure 4: More gradient steps at test time lead to better performance without fine-tuning. (a) Results for the $K_{\text{train}}$ values of GradMem setups that did not reach 99% exact match during fine-tuning. Red dashed lines mark the start of extrapolation. (b) Downstream task quality (exact match) correlates with the WRITE objective in the inner loop (reconstruction). Notably, GradMem reconstructs values better than keys, even though the WRITE objective treats them equally. Results are on 96-pair KV retrieval.
  • Figure 5: Comparison of speed and memory consumption for attention backward-over-backward passes in GradMem. Measured on an A100 GPU with GPT-2 as the base model for GradMem, 8 memory tokens, a query size of 24, 1 inner SGD step, and a batch size of 16.
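The outer loop described in Figure 2 differentiates a READ-phase loss through the WRITE steps. The sketch below shows this on a toy problem, only under stated assumptions: a frozen linear decoder `W`, a linear readout `r` standing in for the model reading $[\hat{\mathcal{M}}; Q]$, squared-error losses, and plain SGD inner steps. Only the initialization $\mathcal{M}_0$ is meta-learned here, whereas GradMem also trains $\theta$. The function names and data are hypothetical, and `finite_diff_grad` exists only to check the hand-derived gradient.

```python
# Toy sketch of a GradMem-style outer loop: backpropagate a READ-phase task loss
# through K inner WRITE steps to get a gradient for the memory initialization M0.
# Assumptions (not from the paper): frozen linear decoder W, linear readout r in
# place of the model reading [M_hat; Q], squared-error losses, plain SGD steps.

def matvec(A, x):
    """Compute A @ x for a list-of-rows matrix A."""
    return [sum(a * b for a, b in zip(row, x)) for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def inner_write(W, M0, C, K, lr):
    """WRITE: K gradient steps on 0.5 * ||W M - C||^2 with respect to M only."""
    Wt = transpose(W)
    M = list(M0)
    for _ in range(K):
        residual = [p - c for p, c in zip(matvec(W, M), C)]
        M = [m - lr * g for m, g in zip(M, matvec(Wt, residual))]
    return M

def task_loss(W, r, M0, C, y, K, lr):
    """READ: predict y from the written memory alone and score the prediction."""
    M_hat = inner_write(W, M0, C, K, lr)
    pred = sum(ri * mi for ri, mi in zip(r, M_hat))
    return 0.5 * (pred - y) ** 2

def meta_grad(W, r, M0, C, y, K, lr):
    """Gradient of task_loss w.r.t. M0, backpropagated through every inner step.

    Each step is M <- (I - lr * W^T W) M + lr * W^T C, whose Jacobian
    I - lr * W^T W is symmetric, so the backward pass applies it K times.
    """
    Wt = transpose(W)
    M_hat = inner_write(W, M0, C, K, lr)
    err = sum(ri * mi for ri, mi in zip(r, M_hat)) - y
    g = [err * ri for ri in r]                          # dL/dM_hat
    for _ in range(K):
        hg = matvec(Wt, matvec(W, g))                   # (W^T W) g
        g = [gi - lr * hi for gi, hi in zip(g, hg)]     # (I - lr W^T W) g
    return g

def finite_diff_grad(W, r, M0, C, y, K, lr, eps=1e-6):
    """Central differences, used only to check meta_grad."""
    out = []
    for j in range(len(M0)):
        step = [eps if i == j else 0.0 for i in range(len(M0))]
        plus = [m + s for m, s in zip(M0, step)]
        minus = [m - s for m, s in zip(M0, step)]
        out.append((task_loss(W, r, plus, C, y, K, lr)
                    - task_loss(W, r, minus, C, y, K, lr)) / (2 * eps))
    return out

# Hypothetical toy data: (context C, target y) pairs.
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
r = [1.0, 0.5]
data = [([2.0, -1.0, 1.0], 1.5), ([0.5, 1.0, 1.5], 1.0)]
K, lr, outer_lr = 2, 0.2, 0.5

M0 = [0.0, 0.0]
before = sum(task_loss(W, r, M0, C, y, K, lr) for C, y in data)
for _ in range(100):                                    # outer loop
    grads = [meta_grad(W, r, M0, C, y, K, lr) for C, y in data]
    avg = [sum(g[j] for g in grads) / len(grads) for j in range(len(M0))]
    M0 = [m - outer_lr * a for m, a in zip(M0, avg)]
after = sum(task_loss(W, r, M0, C, y, K, lr) for C, y in data)
print(round(before, 4), round(after, 4))
```

Because the toy decoder is linear, each inner step's Jacobian $I - \text{lr}\, W^\top W$ is constant. With a nonlinear model, that Jacobian contains the Hessian of the write loss, which autodiff frameworks obtain by differentiating the backward pass again; this is the kind of backward-over-backward cost that Figure 5 measures.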
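Why training needs second-order (backward-over-backward) passes can be seen from a single inner step. The derivation below is the generic chain rule for differentiating through one SGD step; the step size $\eta$, the task loss $\mathcal{L}_{\text{task}}$, and the vector $v$ are symbols introduced here for illustration, and $\mathcal{L}_{\text{write}}$ depends on $\theta$ through the model. It sketches the mechanism and is not a transcription of the paper's equations.

```latex
\begin{aligned}
\hat{\mathcal{M}} &= \mathcal{M}_0 - \eta\, \nabla_{\mathcal{M}} \mathcal{L}_{\text{write}}(\mathcal{M}_0; C),
\qquad
v = \nabla_{\hat{\mathcal{M}}} \mathcal{L}_{\text{task}}(\hat{\mathcal{M}}; Q, Y), \\
\nabla_{\mathcal{M}_0} \mathcal{L}_{\text{task}}
  &= v - \eta\, \nabla_{\mathcal{M}_0}\big( \nabla_{\mathcal{M}} \mathcal{L}_{\text{write}}(\mathcal{M}_0; C)^{\top} v \big), \\
\nabla_{\theta} \mathcal{L}_{\text{task}}
  &= \partial_{\theta} \mathcal{L}_{\text{task}}(\hat{\mathcal{M}}; Q, Y)
   - \eta\, \nabla_{\theta}\big( \nabla_{\mathcal{M}} \mathcal{L}_{\text{write}}(\mathcal{M}_0; C)^{\top} v \big).
\end{aligned}
```

Here $v$ is held constant and $\partial_{\theta}$ denotes the direct dependence with $\hat{\mathcal{M}}$ fixed. The last term in each gradient differentiates the write-loss gradient a second time, which is what a backward-over-backward pass computes; with $K$ inner steps, the same pattern repeats once per step.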