LoLA: Low-Rank Linear Attention With Sparse Caching

Luke McDermott, Robert W. Heath, Rahul Parhi

TL;DR

LoLA addresses the memory bottleneck of transformer-style in-context learning by augmenting linear attention with a sparse-caching mechanism that preserves constant memory. It partitions past KV pairs into a local sliding window, a sparse global cache for difficult-to-remember pairs, and a recurrent hidden-state reservoir, guided by a self-recall error metric that identifies memory collisions. The approach raises pass-key retrieval accuracy from 0.6% to 97.4% while using a cache 4.6x smaller than that of Llama-3.1 8B at 4K context, and it outperforms other 1B and 8B parameter subquadratic models on zero-shot commonsense reasoning. Because the strategy is training-free and applied at inference time, it extends subquadratic LLMs toward lifelong in-context learning, with cache size as a tunable trade-off between memory footprint and recall performance.

Abstract

The per-token cost of transformer inference scales with context length, preventing its application to lifelong in-context learning. Linear attention is an efficient alternative that maintains a constant memory footprint, even on infinite context lengths. While this is a potential candidate for lifelong learning, it falls short in memory capacity. In this paper, we propose LoLA, a training-free augmentation to linear attention that boosts associative recall. LoLA distributes past key-value pairs from context into three memory systems: (i) recent pairs in a local sliding window cache; (ii) difficult-to-memorize pairs in a sparse, global cache; and (iii) generic pairs in the recurrent hidden state of linear attention. We show through ablations that our self-recall error metric is crucial to efficiently manage long-term associative memories. On pass-key retrieval tasks, LoLA improves the base model's performance from 0.6% to 97.4% accuracy. This is achieved with a 4.6x smaller cache than Llama-3.1 8B on 4K context length. LoLA also outperforms other 1B and 8B parameter subquadratic models on zero-shot commonsense reasoning tasks.
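The three-way routing described in the abstract can be illustrated with a minimal Python sketch. It is not the authors' implementation: the class `ThreeTierMemory`, the ELU + 1 feature map `phi`, the exact form of the self-recall error, and the rule of evicting the lowest-error pair are all assumptions chosen to make the idea concrete for a single attention head.

```python
import math
from collections import deque


def phi(x):
    # Assumed feature map (ELU + 1): keeps every feature strictly positive.
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]


class ThreeTierMemory:
    """Hypothetical sketch: sliding window + sparse cache + linear-attention state."""

    def __init__(self, d_k, d_v, window_size, cache_size):
        self.H = [[0.0] * d_v for _ in range(d_k)]  # hidden state: sum of phi(k) v^T
        self.s = [0.0] * d_k                        # normalizer: sum of phi(k)
        self.window = deque()                       # (i) most recent pairs
        self.cache = []                             # (ii) hard-to-memorize pairs
        self.window_size = window_size
        self.cache_size = cache_size

    def _absorb(self, k, v):
        # (iii) Fold a pair into the constant-size recurrent state.
        f = phi(k)
        for i, fi in enumerate(f):
            self.s[i] += fi
            for j, vj in enumerate(v):
                self.H[i][j] += fi * vj

    def self_recall_error(self, k, v):
        # Assumed metric: squared error when v is read back with its own key
        # from the state as it would look after absorbing (k, v).
        f = phi(k)
        ff = sum(fi * fi for fi in f)
        denom = sum(fi * (si + fi) for fi, si in zip(f, self.s))
        err = 0.0
        for j, vj in enumerate(v):
            num = sum(fi * self.H[i][j] for i, fi in enumerate(f)) + ff * vj
            err += (vj - num / denom) ** 2
        return err

    def add(self, k, v):
        self.window.append((k, v))
        if len(self.window) <= self.window_size:
            return
        # The oldest pair leaves the window and competes for a cache slot.
        self.cache.append(self.window.popleft())
        if len(self.cache) > self.cache_size:
            # Evict the pair the hidden state recalls best; keep the rest exact.
            errors = [self.self_recall_error(k_c, v_c) for k_c, v_c in self.cache]
            easiest = errors.index(min(errors))
            self._absorb(*self.cache.pop(easiest))


# Toy run: two identical pairs, one pair with a different value, one filler pair.
mem = ThreeTierMemory(d_k=2, d_v=2, window_size=1, cache_size=1)
for k, v in [([1, 0], [1, 0]), ([1, 0], [1, 0]), ([0, 1], [0, 1]), ([0, 0], [0, 0])]:
    mem.add(k, v)
```

In this toy run the two identical pairs end up in the hidden state, where they do not interfere with each other, while the pair carrying a different value stays in the sparse cache because absorbing it would blur its read-out. Total memory stays bounded by the window size, the cache size, and the fixed-size state, regardless of how many pairs are added.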

Paper Structure

This paper contains 31 sections, 23 equations, 10 figures, 8 tables.

Figures (10)

  • Figure 1: LoLA stores past KV pairs in three forms of memory for each attention head.
  • Figure 2: Illustration of where each KV pair is stored at every time step for each method.
  • Figure 3: Visualizing memory collisions by measuring SRE for stored KV pairs.
  • Figure 4: Measuring Time-to-First-Token for various sliding window and sparse cache sizes. This measurement is averaged across 100 trials and assumes data is already loaded into VRAM.
  • Figure 5: Measuring Peak VRAM usage for various sliding window and sparse cache sizes. This measurement includes the base model weights, the data sequence, and online activations such as KV caches.
  • ...and 5 more figures