
Lattice: Learning to Efficiently Compress the Memory

Mahdi Karami, Razvan Pascanu, Vahab Mirrokni

TL;DR

Lattice tackles the memory and compute bottlenecks of attention by compressing an unbounded key-value cache into a fixed set of memory slots and updating them via an online gradient step that writes only information orthogonal to each slot's current state. By normalizing memory slots and using a decoding/encoding framework with latent codes, it yields a stable, interpretable orthogonal state recurrence with sub-quadratic complexity. The approach is grounded in online optimization and connects to dictionary learning and Riemannian optimization. It improves perplexity on long-context language modeling tasks, and ablations highlight normalization as a key factor. These results suggest that Lattice is a scalable alternative or complement to Transformers for long-context sequence modeling, with potential for test-time adaptation and efficient fine-tuning.

Abstract

Attention mechanisms have revolutionized sequence learning but suffer from quadratic computational complexity. This paper introduces Lattice, a novel recurrent neural network (RNN) mechanism that leverages the inherent low-rank structure of key-value (K-V) matrices to efficiently compress the cache into a fixed number of memory slots, achieving sub-quadratic complexity. We formulate this compression as an online optimization problem and derive a dynamic memory update rule from a single gradient descent step. The resulting recurrence features a state- and input-dependent gating mechanism, offering an interpretable memory update process. The core innovation is the orthogonal update: each memory slot is updated exclusively with information orthogonal to its current state, so that only novel, non-redundant data is incorporated, which minimizes interference with previously stored information. Experimental results show that Lattice achieves lower perplexity than all baselines across diverse context lengths, with the improvement becoming more pronounced as the context length increases.
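The orthogonal update described above can be made concrete with the per-slot rule given in the Figure 1 caption, s_t = s_{t-1} + α_{i,t} h_t^{⊥ s_{t-1}}. The NumPy snippet below is a minimal sketch under stated assumptions, not the authors' implementation: it vectorizes that rule over the m columns of a d × m state matrix and optionally renormalizes each slot onto the unit sphere. The function name `orthogonal_slot_update` is hypothetical; the gates `alpha` are passed in rather than computed from the state and input; and the same token vector h is written to every slot, which is a simplification (the paper's encoding/decoding with latent codes may make the written signal slot-specific). It also implements only the resulting recurrence, not its derivation from the online gradient step.

```python
import numpy as np


def orthogonal_slot_update(S, h, alpha, normalize=True, eps=1e-8):
    """One orthogonal write step for all m memory slots (illustrative sketch).

    S:     (d, m) state matrix; column i is memory slot s_i.
    h:     (d,) incoming token representation (assumed shared by all slots).
    alpha: (m,) per-slot writing intensities (given here, not computed).
    """
    # Squared norm of each slot, shape (m,); eps guards against zero slots.
    sq_norms = np.sum(S * S, axis=0) + eps
    # Coefficient of h along each slot direction, shape (m,).
    coeffs = (S.T @ h) / sq_norms
    # Component of h orthogonal to each slot, h - coeff_i * s_i, shape (d, m).
    H_perp = h[:, None] - S * coeffs[None, :]
    # Write only the orthogonal (novel) part, scaled per slot.
    S_new = S + H_perp * alpha[None, :]
    if normalize:
        # Map each column back onto the unit sphere.
        S_new = S_new / (np.linalg.norm(S_new, axis=0, keepdims=True) + eps)
    return S_new


# Usage: random unit-norm slots, one token, a uniform hypothetical gate.
rng = np.random.default_rng(0)
d, m = 8, 4
S = rng.standard_normal((d, m))
S /= np.linalg.norm(S, axis=0, keepdims=True)
h = rng.standard_normal(d)
alpha = np.full(m, 0.5)

S_raw = orthogonal_slot_update(S, h, alpha, normalize=False)
# Each write S_raw[:, i] - S[:, i] is (numerically) orthogonal to S[:, i].
print(np.abs(np.sum((S_raw - S) * S, axis=0)).max())
S_next = orthogonal_slot_update(S, h, alpha)
print(np.linalg.norm(S_next, axis=0))  # unit norms after normalization
```

Because each write is orthogonal to its slot, the component of a unit-norm slot along its previous direction is left unchanged before renormalization; only the new direction is added.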

Paper Structure

This paper contains 27 sections, 2 theorems, 35 equations, 4 figures, 4 tables.

Key Result

Proposition 3.1

Let $\mathcal{C} = \{ {\boldsymbol{s}} \in \mathbb{R}^d \mid \|{\boldsymbol{s}}\| = 1 \}$ be the unit sphere. Then, the projected gradient update of the form ${\boldsymbol{s}}_{i,t} = \mathcal{P}_{\mathcal{C}}({\boldsymbol{s}}_{i,t-1} + \Delta {\boldsymbol{s}}_{i,t})$ (as in the normalized recurrence equation),
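The statement of Proposition 3.1 is cut off in this summary, so its exact conclusion is not reproduced here. As background only, the following standard facts about the unit sphere indicate why a normalized orthogonal write resembles a Riemannian gradient step; this is not the paper's proof, and pairing the projector with the Figure 1 update rule is our illustration.

```latex
% Unit sphere and its tangent space at a unit vector s
\mathcal{C} = \{\boldsymbol{s} \in \mathbb{R}^d : \|\boldsymbol{s}\| = 1\},
\qquad
T_{\boldsymbol{s}}\mathcal{C} = \{\boldsymbol{v} \in \mathbb{R}^d : \boldsymbol{s}^\top \boldsymbol{v} = 0\}.

% Orthogonal projector onto the tangent space (valid for \|s\| = 1)
\mathbf{P}_{\boldsymbol{s}} = \mathbf{I} - \boldsymbol{s}\boldsymbol{s}^\top,
\qquad
\boldsymbol{h}_t^{\perp \boldsymbol{s}_{t-1}}
  = \mathbf{P}_{\boldsymbol{s}_{t-1}}\boldsymbol{h}_t
  = \boldsymbol{h}_t - (\boldsymbol{s}_{t-1}^\top \boldsymbol{h}_t)\,\boldsymbol{s}_{t-1}.

% Riemannian gradient of a smooth loss f restricted to C
\operatorname{grad} f(\boldsymbol{s}) = \mathbf{P}_{\boldsymbol{s}}\,\nabla f(\boldsymbol{s}).

% Metric projection onto C, which acts as a retraction for tangent steps
\mathcal{P}_{\mathcal{C}}(\boldsymbol{x}) = \frac{\boldsymbol{x}}{\|\boldsymbol{x}\|},
\qquad
\boldsymbol{s}_{i,t} = \mathcal{P}_{\mathcal{C}}\!\left(\boldsymbol{s}_{i,t-1}
  + \alpha_{i,t}\,\boldsymbol{h}_t^{\perp \boldsymbol{s}_{i,t-1}}\right).
```

Under the assumption (not stated in this summary) that the Euclidean descent direction for a slot is proportional to $\boldsymbol{h}_t$, projecting it onto the tangent space gives exactly the orthogonal component, and normalization is the standard retraction back onto $\mathcal{C}$. Since the step is tangent, $\|\boldsymbol{s}_{i,t-1} + \alpha_{i,t}\boldsymbol{h}_t^{\perp}\| = \sqrt{1 + \alpha_{i,t}^2\|\boldsymbol{h}_t^{\perp}\|^2} \ge 1$, so the normalization is always well defined.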

Figures (4)

  • Figure 1: A geometric visualization of the proposed update rule. (a) A single current state vector, ${\boldsymbol{s}}_{t-1} = \mathbf{S}_{t-1}[: \, , i]$, an incoming token representation, ${\boldsymbol{h}}_t$, and its component orthogonal to the current state, ${\boldsymbol{h}}_t^{\perp {{\boldsymbol{s}}}_{t-1}}$. (b) Comparison of the state updated by the proposed rule (${\boldsymbol{s}}_{t}= {\boldsymbol{s}}_{t-1} + \alpha_{i,t} \, {\boldsymbol{h}}_t^{\perp {{\boldsymbol{s}}}_{t-1}}$) with the state produced by the superposition recurrence of standard linear attention ($\hat{{\boldsymbol{s}}}_{t}= {\boldsymbol{s}}_{t-1} + \alpha_{i,t} \, {\boldsymbol{h}}_t$, dashed arrow). For simplicity, a unit writing intensity ($\alpha_{i,t}=1$) is assumed in both update rules. (c) Visualization of the relationships between the ${d \times m}$ state matrices over time in state-dependent compression, depicted as interconnected nodes in a 3D lattice. Each memory slot (state vector) is shown in a distinct color.
  • Figure 2: (Left) Block diagram of the language model. (Right) The Lattice block. Following the architecture used in Mamba (Gu and Dao, 2023), each sequence-mixing block applies a short $\texttt{Conv1D}$ to each of $q$ and $k$, and the Lattice layer is followed by a $\texttt{GeLU}$ post-gate.
  • Figure 3: Model perplexity as a function of context length for models with 110M parameters. (Left) Results on the Books dataset for context lengths $\{512, 1024, 2048, 4096, 8192, 16384\}$; (Right) results on The Pile for context lengths $\{2048, 8192\}$. Transformers pre-trained from scratch often perform poorly on very long contexts (e.g., 16k); the common approach is to fine-tune models pre-trained on shorter contexts (Touvron et al., 2023). Therefore, the pre-trained Transformer baseline results shown here are limited to context lengths $T \le 8k$.
  • Figure 4: An illustration of the proposed update rule. (a) A single memory slot state, ${\boldsymbol{s}}_{t-1}$, an incoming token representation, ${\boldsymbol{h}}_t$, and its component orthogonal to the current state, ${\boldsymbol{h}}_t^{\perp {{\boldsymbol{s}}}_{t-1}}$. (b) The state updated by the proposed rule, ${\boldsymbol{s}}_{t}= {\boldsymbol{s}}_{t-1} + \alpha_{i,t} \, {\boldsymbol{h}}_t^{\perp {{\boldsymbol{s}}}_{t-1}}$, contrasted with the state produced by the superposition recurrence of standard linear attention, $\hat{{\boldsymbol{s}}}_{t}= {\boldsymbol{s}}_{t-1} + \alpha_{i,t} \, {\boldsymbol{h}}_t$ (dashed arrow). For simplicity, a unit writing intensity ($\alpha_{i,t}=1$) is assumed in both update rules.
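To make the geometry of Figures 1 and 4 concrete, the snippet below works through a two-dimensional case with the captions' unit writing intensity ($\alpha_{i,t}=1$). The vectors s_{t-1} = (1, 0) and h_t = (1, 1) are illustrative choices, not values from the paper, and the helper `orthogonal_component` is a hypothetical name.

```python
import numpy as np


def orthogonal_component(h, s):
    """Component of h orthogonal to s: h - (s.h / s.s) s."""
    return h - (s @ h) / (s @ s) * s


s_prev = np.array([1.0, 0.0])  # current unit-norm slot s_{t-1} (illustrative)
h_t = np.array([1.0, 1.0])     # incoming token representation h_t (illustrative)
alpha = 1.0                    # unit writing intensity, as in the captions

s_orth = s_prev + alpha * orthogonal_component(h_t, s_prev)  # proposed rule
s_sup = s_prev + alpha * h_t  # superposition rule of standard linear attention

for name, s_new in [("orthogonal", s_orth), ("superposition", s_sup)]:
    # Component along the old slot direction and the resulting norm.
    print(name, s_new, s_new @ s_prev, np.linalg.norm(s_new))
```

Here the orthogonal rule gives s_t = (1, 1): its component along s_{t-1} stays at 1 and only the new direction is added, so the norm grows to √2. The superposition rule gives (2, 1): the part of h_t already aligned with s_{t-1} is written again, doubling that component and growing the norm to √5.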
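The Figure 2 caption mentions a short Conv1D applied to q and k, following Mamba. In Mamba-style blocks this is usually a causal depthwise convolution: each channel has its own short kernel, and position t sees only positions up to t. The NumPy sketch below implements that operation under stated assumptions: the function name `causal_depthwise_conv1d` is hypothetical, the kernel width K = 4 is assumed (Mamba's default short-convolution width), the weights are random placeholders, and the GeLU post-gate is omitted because the caption does not specify its exact form.

```python
import numpy as np


def causal_depthwise_conv1d(x, w, b=None):
    """Short causal depthwise 1D convolution (illustrative sketch).

    x: (T, C) sequence of C-channel features (e.g. queries or keys).
    w: (K, C) one length-K kernel per channel; w[K-1] weights the current step.
    b: optional (C,) bias.
    y[t, c] = sum_j w[j, c] * x[t - (K-1) + j, c], with zeros before t = 0.
    """
    T, C = x.shape
    K = w.shape[0]
    # Left-pad with K-1 zeros so output t depends only on x[0..t].
    x_pad = np.concatenate([np.zeros((K - 1, C), dtype=x.dtype), x], axis=0)
    y = np.zeros((T, C), dtype=np.result_type(x, w))
    for j in range(K):
        # Rows j..j+T-1 of the padded input are x delayed by (K-1-j) steps.
        y += w[j] * x_pad[j:j + T]
    if b is not None:
        y += b
    return y


# Usage: convolve a toy query sequence with an assumed width-4 kernel.
rng = np.random.default_rng(0)
T, C, K = 6, 3, 4
q = rng.standard_normal((T, C))
w_q = rng.standard_normal((K, C))
q_conv = causal_depthwise_conv1d(q, w_q)
print(q_conv.shape)
```

The explicit loop over the K kernel taps keeps the causal indexing visible; in practice a framework primitive would be used instead, for example PyTorch's `torch.nn.Conv1d` with `groups` equal to the number of channels and left-only padding.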

Theorems & Definitions (5)

  • Proposition 3.1: Equivalence to Gradient Descent on Riemannian Manifold
  • Remark 3.2
  • Remark 3.3: Delta Rule
  • Remark 3.4: Parallel and Hardware Efficient Implementation
  • Proposition B.1: Proof of Proposition 3.1