Attention and Compression is all you need for Controllably Efficient Language Models
Jatin Prakash, Aahlad Puli, Rajesh Ranganath
TL;DR
Transformer attention typically incurs $O(N^2)$ compute in sequence length $N$, which limits long-context modeling. CAT pairs a simple, end-to-end trainable compression step that summarizes past chunks with a dense autoregressive decoder that attends to these compressed representations, so the quality-compute trade-off can be controlled at test time through the chunk size $C$. A single adaptive CAT model (trained across multiple values of $C$) matches dense transformers in language modeling while being faster and using less memory, and it outperforms efficient baselines on in-context recall, long-context understanding, and QA tasks across varied budgets. The approach scales with model size and can be implemented in pure PyTorch without custom kernels, which makes it practical for memory-bound generation workloads and budget-aware inference.
Abstract
The quadratic cost of attention in transformers has motivated the development of efficient approaches, namely sparse and sliding-window attention, convolutions, and linear attention. Although these approaches yield impressive reductions in compute and memory, they often sacrifice quality, specifically in-context recall performance. Moreover, fixing this quality-compute trade-off a priori means being suboptimal from the start: some downstream applications require more memory for in-context recall, while others require lower latency and memory. Further, these approaches rely on heuristic choices that artificially restrict attention, require handcrafted and complex recurrent state update rules, or must be carefully composed with attention at specific layers to form a hybrid architecture, which complicates the design process, especially at scale. To address these issues, we propose the Compress & Attend Transformer (CAT), a conceptually simple architecture that uses only two ingredients: dense attention and compression. CAT decodes chunks of tokens by attending to compressed chunks of the sequence so far. Compression means decoding from a reduced sequence length, which yields compute and memory savings, while the choice of chunk size trades off quality for efficiency. Moreover, CAT can be trained with multiple chunk sizes at once, unlocking control of the quality-compute trade-off directly at test time without any retraining, all in a single adaptive architecture. In exhaustive evaluations on common language modeling tasks, in-context recall, and long-context understanding, a single adaptive CAT model outperforms existing efficient baselines, including hybrid architectures, across different compute-memory budgets. Further, a single CAT matches a dense transformer in language modeling across model scales while being 1.4-3x faster and requiring 2-9x lower total memory usage.
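As a rough illustration of why decoding from compressed chunks shortens the effective sequence, the sketch below counts how many entries a decoder would attend to at a given position. It rests on simplifying assumptions that the abstract does not spell out: each completed chunk is summarized by a single compressed vector, and a token attends causally to the raw tokens of its own chunk. The function names are hypothetical, and the count ignores the cost of the compressor itself, so it is not a reproduction of the reported speed or memory numbers.

```python
def dense_entries(t: int) -> int:
    """Entries attended to at 0-indexed position t under causal dense attention."""
    return t + 1


def chunked_entries(t: int, chunk_size: int) -> int:
    """Entries attended to at position t under the simplified chunked model.

    Assumption (illustrative only): every completed chunk contributes ONE
    compressed vector, and the current chunk contributes its raw tokens
    up to and including position t.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    completed_chunks = t // chunk_size      # one compressed vector each
    within_chunk = t % chunk_size + 1       # raw tokens in the current chunk
    return completed_chunks + within_chunk


if __name__ == "__main__":
    last = 4096 - 1  # final position of a 4096-token sequence
    for c in (1, 4, 16, 64):
        print(f"C={c:>2}: {chunked_entries(last, c)} entries "
              f"(dense: {dense_entries(last)})")
```

Under these assumptions, a chunk size of 1 reduces to dense attention, and larger chunk sizes shrink the number of attended entries for long sequences, which is the knob that a model trained on multiple chunk sizes could turn at test time.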
