Attention and Compression is all you need for Controllably Efficient Language Models

Jatin Prakash, Aahlad Puli, Rajesh Ranganath

TL;DR

Transformer attention typically incurs $O(N^2)$ compute in sequence length $N$, which limits long-context modeling. CAT pairs a simple, end-to-end trainable compression step that summarizes past chunks with a dense, autoregressive decoder that attends to these compressed representations, enabling test-time control of the quality–compute trade-off via the chunk size $C$. A single adaptive CAT model (trained across multiple values of $C$) matches dense transformers in language modeling while generating faster and using less memory, and it outperforms efficient baselines on in-context recall, long-context understanding, and QA tasks across varied budgets. The approach scales with model size and can be implemented in pure PyTorch without custom kernels, which makes it practical for memory-bound generation workloads and budget-aware inference.

Abstract

The quadratic cost of attention in transformers motivated the development of efficient approaches, namely sparse and sliding window attention, convolutions, and linear attention. Although these approaches yield impressive reductions in compute and memory, they often trade away quality, specifically in-context recall performance. Moreover, fixing this quality-compute trade-off a priori means being suboptimal from the get-go: some downstream applications require more memory for in-context recall, while others require lower latency and memory. Further, these approaches rely on heuristic choices that artificially restrict attention, or require handcrafted and complex recurrent state update rules, or must be carefully composed with attention at specific layers to form a hybrid architecture, which complicates the design process, especially at scale. To address these issues, we propose the Compress & Attend Transformer (CAT), a conceptually simple architecture employing only two simple ingredients: dense attention and compression. CAT decodes chunks of tokens by attending to compressed chunks of the sequence so far. Compression means decoding from a reduced sequence length, which yields compute and memory savings, while the choice of chunk size trades off quality for efficiency. Moreover, CAT can be trained with multiple chunk sizes at once, unlocking control of quality-compute trade-offs directly at test time without any retraining, all in a single adaptive architecture. In exhaustive evaluations on common language modeling tasks, in-context recall, and long-context understanding, a single adaptive CAT model outperforms existing efficient baselines, including hybrid architectures, across different compute-memory budgets. Further, a single CAT model matches the dense transformer in language modeling across model scales while being 1.4-3x faster and using 2-9x less total memory.
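To make the chunk-size knob concrete, the following is a minimal counting sketch rather than the authors' implementation. It assumes each completed chunk is summarized by exactly one compressed representation, and that a token also attends to the tokens of its own chunk up to and including itself; the abstract does not spell out either detail. The helper names `split_into_chunks`, `cat_context_size`, and `dense_context_size` are hypothetical and chosen only for illustration.

```python
from typing import List, Sequence


def split_into_chunks(tokens: Sequence[int], chunk_size: int) -> List[List[int]]:
    """Split a token sequence into consecutive chunks of `chunk_size` tokens.

    The final chunk may be shorter when len(tokens) is not a multiple of chunk_size.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    return [list(tokens[i:i + chunk_size]) for i in range(0, len(tokens), chunk_size)]


def cat_context_size(position: int, chunk_size: int) -> int:
    """Entries attended to when decoding the token at 0-indexed `position`.

    Assumption (not stated in the source): one compressed representation per
    completed past chunk, plus the tokens of the current chunk up to and
    including the token itself.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    past_chunks = position // chunk_size        # completed chunks before this one
    within_chunk = position % chunk_size + 1    # tokens seen so far in this chunk
    return past_chunks + within_chunk


def dense_context_size(position: int) -> int:
    """Entries a dense causal transformer attends to at 0-indexed `position`."""
    return position + 1


if __name__ == "__main__":
    n = 4096
    for c in (4, 8, 16, 32):
        print(f"C={c}: compressed context {cat_context_size(n - 1, c)} "
              f"vs dense {dense_context_size(n - 1)}")
```

Under these assumptions, setting the chunk size to 1 recovers the dense count, and larger chunk sizes shrink the attended context roughly by a factor of $C$. The sketch only counts attended entries; it does not model the cost of the compression step and is not meant to reproduce the measured speedups or memory figures reported in the paper.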

Paper Structure

This paper contains 47 sections, 2 equations, 8 figures, 15 tables.

Figures (8)

  • Figure 1: CAT unlocks test-time control of quality-efficiency trade-offs: a single adaptive CAT model (all red dots come from a single model) outperforms nearly every popular efficient architecture on real-world in-context recall tasks across varying compute-memory budgets.
  • Figure 2: The Compress and Attend Transformer (CAT) architecture. CAT splits a sequence of length $N$ into $N/C$ chunks of $C$ tokens (illustrated for $C=3$). Each chunk is compressed in parallel into a chunk representation. CAT then decodes each chunk by attending to past chunk representations. Compression reduces the sequence length, which saves compute and memory during decoding. Chunk size acts as a knob, offering test-time control of quality-efficiency trade-offs: larger chunk sizes yield greater efficiency.
  • Figure 3: A single CAT model generates $1.4$-$3.2\times$ faster than the dense transformer while using $2.2$-$9.5\times$ less memory. Per the results table (label tab:swde_fda_results), CAT-8 outperforms GDN-Hybrid on real-world recall tasks while being faster and requiring similar memory; CAT-16 outperforms Mamba2 and GDN and is $1.15\times$ faster but costs slightly ($\sim15\%$) more memory.
  • Figure 4: Comparison of architectures on the MQAR task (up to $4\times$ standard length). All models except the dense transformer are memory-matched in bytes at every point; CAT outperforms the baselines, especially at longer sequences, while using the same memory.
  • Figure 5: CATs scale like their dense transformer counterparts while being up to $3\times$ faster and $9\times$ more memory-efficient. All CAT points at a given scale come from a single model, evaluated at different chunk sizes.
  • ...and 3 more figures