Attamba: Attending To Multi-Token States

Yash Akhauri, Safeen Huda, Mohamed S. Abdelfattah

TL;DR

Attamba is a novel architecture that uses state-space models to compress chunks of tokens and applies attention to these compressed key-value representations. This enables a smooth transition between quadratic and linear scaling and offers adaptable efficiency gains.

Abstract

When predicting the next token in a sequence, vanilla transformers compute attention over all previous tokens, so compute scales quadratically with sequence length. State-space models (SSMs) compress the entire token sequence into a fixed-dimensional representation to improve efficiency, while other architectures achieve sub-quadratic complexity via low-rank projections or sparse attention patterns over the sequence. In this paper, we introduce Attamba, a novel architecture that uses SSMs to compress chunks of tokens and applies attention to these compressed key-value representations. We find that replacing the key and value projections in a transformer with SSMs can improve model quality and enables flexible token chunking. This yields a 24% perplexity improvement over a transformer with a similar KV-Cache and attention footprint, and ~4× smaller KV-Cache and attention FLOPs for a 5% perplexity trade-off. Attamba can perform attention on chunked sequences of variable length, enabling a smooth transition between quadratic and linear scaling and offering adaptable efficiency gains.
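To make the core idea concrete, here is a minimal decode-time sketch in Python/NumPy. It is not the authors' implementation. A toy diagonal linear recurrence (h = a*h + b*x) stands in for the paper's SSM blocks, and the class name `ToyAttambaKVCache` and its parameters are hypothetical. Keys and values come from two such recurrences instead of linear projections. Only the states reached at chunk boundaries are cached, and each new query attends to those cached boundary states plus the current state.

```python
import numpy as np


class ToyAttambaKVCache:
    """Hypothetical decode-time cache for one attention head.

    A diagonal linear recurrence h = a * h + b * x (one for keys, one for values)
    stands in for the paper's key/value SSM blocks. Only the states reached at
    chunk boundaries are stored, so the cache grows by one entry per chunk_size tokens.
    """

    def __init__(self, dim, chunk_size, seed=0):
        rng = np.random.default_rng(seed)
        self.chunk_size = chunk_size
        self.a = rng.uniform(0.5, 0.95, size=(2, dim))  # per-channel decay (row 0: K, row 1: V)
        self.b = rng.normal(size=(2, dim))              # per-channel input scale
        self.h = np.zeros((2, dim))                     # running K/V SSM states
        self.keys, self.values = [], []                 # cached chunk-boundary states
        self.t = 0                                      # tokens consumed so far

    def step(self, x, q):
        """Consume token embedding x, then attend with query q over the cached
        boundary states plus the current (possibly mid-chunk) state."""
        self.h = self.a * self.h + self.b * x           # x broadcasts over the K and V rows
        self.t += 1
        k_cur, v_cur = self.h[0].copy(), self.h[1].copy()
        keys = np.stack(self.keys + [k_cur])
        values = np.stack(self.values + [v_cur])
        scores = keys @ q / np.sqrt(q.shape[0])         # scaled dot-product scores
        weights = np.exp(scores - scores.max())         # numerically stable softmax
        weights /= weights.sum()
        if self.t % self.chunk_size == 0:               # chunk boundary: keep this state
            self.keys.append(k_cur)
            self.values.append(v_cur)
        return weights @ values


# Usage: 32 tokens with chunk size 4 leave 8 cached boundary states instead of 32.
dim, chunk_size, seq_len = 8, 4, 32
cache = ToyAttambaKVCache(dim, chunk_size)
rng = np.random.default_rng(1)
for _ in range(seq_len):
    out = cache.step(rng.normal(size=dim), rng.normal(size=dim))
print(len(cache.keys), out.shape)  # 8 (8,)
```

With `chunk_size=1` every state is cached, so the cache grows one entry per token like a standard KV-Cache. Larger chunk sizes shrink it by that factor. The recurrence's own hidden state has a fixed size regardless of sequence length.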

Paper Structure

This paper contains 20 sections, 11 equations, 20 figures, 1 table.

Figures (20)

  • Figure 1: Attamba uses State-Space Models (SSMs) to compress key-value sequences into token chunks (e.g., chunks of $P = 4$). Storing only the chunk boundaries reduces the attention map and KV-Cache size by a factor of $P$.
  • Figure 2: State-Space Models (SSMs) efficiently encode multiple tokens into a single representation. By compressing key ($K$) and value ($V$) sequences into chunked representations, SSMs maintain essential contextual information, enabling efficient query ($Q$) interactions. This approach minimizes KV-Cache size by storing only chunk boundaries and reduces the computational cost of attention. Attamba is robust to randomized chunk boundaries, indicating the potential for flexible computation-quality trade-offs. FLOPs and memory are approximate, with constants ignored. Variables: $L$ (sequence length), $P$ (chunk size), $D_{S}$ (SSM state dimension), $E$ (model dimension).
  • Figure 3: Attamba uses SSM blocks to compress chunks of tokens ($P = 4$ in the example above) into a single token.
  • Figure 4: Full attention uses a purely causal mask, attending to all past tokens. Attamba uses key-value SSM blocks to compress chunks of $P$ tokens (e.g., $P = 4$) into one state, so the SSM-compressed tokens sit at chunk boundaries. This is combined with sliding-window attention when the number of leading tokens $L > 1$. At test time (inference), only the chunk boundaries and the sliding-window tokens need to be preserved, reducing KV-Cache size and attention FLOPs.
  • Figure 5: The number of leading tokens ($L$) controls how many of the newest tokens receive full attention, which resembles sliding-window attention. The chunk size ($P$) controls how many tokens the SSM compresses into one state.
  • ...and 15 more figures
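Figures 4 and 5 describe the attention pattern in words. The NumPy sketch below builds that mask under one reading of the captions: causal attention restricted to chunk-boundary positions plus the `leading` newest tokens. The function names are hypothetical. The code uses `seq_len` and `leading` to avoid the clash between $L$ as sequence length (Figure 2) and $L$ as leading tokens (Figures 4-5). Counting visible entries gives a rough, constant-free proxy for attention FLOPs. With $P = 4$ and one leading token, roughly a quarter of the full causal entries remain.

```python
import numpy as np


def attamba_style_mask(seq_len, chunk_size, leading):
    """mask[i, j] is True if query i may attend to the K/V state at position j.

    Assumed reading of Figures 4-5: attention stays causal, and a position is
    visible if it is a chunk boundary (every chunk_size-th token) or one of the
    `leading` most recent tokens (a sliding window that includes i itself).
    """
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    causal = j <= i
    boundary = (j + 1) % chunk_size == 0
    recent = (i - j) < leading
    return causal & (boundary | recent)


def full_causal_mask(seq_len):
    """Vanilla transformer mask: every query sees all earlier tokens."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return j <= i


# Attended (query, key) pairs, a proxy for attention FLOPs up to constants.
print(int(attamba_style_mask(1024, 4, 1).sum()), int(full_causal_mask(1024).sum()))  # 131584 524800
# chunk_size=1 (or leading >= seq_len) recovers full causal attention.
print(np.array_equal(attamba_style_mask(16, 1, 1), full_causal_mask(16)))  # True
```

The two knobs move the mask between the extremes the captions describe. Raising `chunk_size` thins out the boundary columns, and raising `leading` widens the dense band near the diagonal until the mask becomes fully causal.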