
Scatterbrain: Unifying Sparse and Low-rank Attention Approximation

Beidi Chen, Tri Dao, Eric Winsor, Zhao Song, Atri Rudra, Christopher Ré

TL;DR

Scatterbrain addresses the challenge of efficient Transformer attention by unifying sparse and low-rank approximations into an unbiased, low-variance estimator. By combining a sparse component identified via LSH with a kernel-based low-rank component, Scatterbrain adaptively captures both large and diffuse attention patterns, outperforming individual baselines and approaching an oracle Robust PCA reference. The method saves up to ~98% of attention memory in vision models at the cost of about a 1% accuracy drop, and it improves end-to-end metrics on language modeling and long-range tasks (e.g., up to ~5-point gains on LRA). This work provides a principled, hardware-friendly framework for robust attention approximation with broad applicability to autoregressive and bidirectional Transformers, enabling scalable training and inference for long sequences.
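Written out, the sparse + low-rank decomposition described above looks roughly as follows. This is a sketch under stated assumptions, not the paper's exact statement: $\phi$ is assumed to be an unbiased feature map for the exponential kernel (for example, Performer-style positive random features, so that $\mathbb{E}[\phi(q)^\top \phi(k)] = \exp(q^\top k)$), $\mathcal{S}$ is the set of query-key pairs that collide under LSH, $\exp$ is applied entrywise, and the $1/\sqrt{d}$ scaling is absorbed into $Q$ and $K$.

$$
\exp(QK^\top) \approx \phi(Q)\,\phi(K)^\top + S,
\qquad
S_{ij} =
\begin{cases}
\exp(q_i^\top k_j) - \phi(q_i)^\top \phi(k_j), & (i,j) \in \mathcal{S},\\
0, & \text{otherwise.}
\end{cases}
$$

Entries in $\mathcal{S}$ are therefore exact and every other entry is an unbiased random-feature estimate, so the unnormalized approximation is unbiased entry by entry. The output is then normalized row-wise, and neither term requires forming an $n \times n$ dense matrix:

$$
\mathrm{softmax}(QK^\top)V \approx D^{-1}\Big(SV + \phi(Q)\big(\phi(K)^\top V\big)\Big),
\qquad
D = \mathrm{diag}\Big(S\mathbf{1} + \phi(Q)\big(\phi(K)^\top \mathbf{1}\big)\Big).
$$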

Abstract

Recent advances in efficient Transformers have exploited either the sparsity or the low-rank properties of attention matrices to reduce the computational and memory bottlenecks of modeling long sequences. However, it remains challenging to balance model quality and efficiency with a one-size-fits-all approximation across different tasks. To better understand this trade-off, we observe that sparse and low-rank approximations excel in different regimes, determined by the softmax temperature in attention, and that sparse + low-rank can outperform each individually. Inspired by the classical robust-PCA algorithm for sparse and low-rank decomposition, we propose Scatterbrain, a novel way to unify sparse (via locality sensitive hashing) and low-rank (via kernel feature map) attention for accurate and efficient approximation. The estimation is unbiased with provably low error. We empirically show that Scatterbrain can achieve 2.1x lower error than baselines when serving as a drop-in replacement in BigGAN image generation and pre-trained T2T-ViT. On a pre-trained T2T Vision Transformer, even without fine-tuning, Scatterbrain can reduce attention memory by 98% at the cost of only a 1% drop in accuracy. In end-to-end training, Scatterbrain achieves up to 4 points better perplexity and 5 points better average accuracy than sparse or low-rank efficient Transformers on language modeling and Long Range Arena tasks.
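The sparse (LSH) + low-rank (kernel feature map) estimator can be sketched in NumPy as follows. Everything here is an illustrative assumption rather than the authors' implementation: the feature map is Performer-style positive random features, the LSH is plain random-hyperplane (sign) hashing applied to queries and keys with shared hyperplanes, and the helper names (`positive_random_features`, `lsh_buckets`, `scatterbrain_attention`) are hypothetical. For readability the collision mask is built densely; an efficient version would sort rows by bucket instead.

```python
import numpy as np


def positive_random_features(x, w):
    """Positive random features (assumed, Performer-style).

    With rows of w drawn i.i.d. from N(0, I), E[phi(q) @ phi(k)] = exp(q @ k).
    """
    m = w.shape[0]
    half_sq_norm = 0.5 * np.sum(x * x, axis=-1, keepdims=True)
    return np.exp(x @ w.T - half_sq_norm) / np.sqrt(m)


def lsh_buckets(x, planes):
    """Random-hyperplane (sign) LSH: one integer bucket id per row of x."""
    bits = (x @ planes.T > 0).astype(np.int64)
    return bits @ (1 << np.arange(planes.shape[0]))


def scatterbrain_attention(q, k, v, num_features=256, num_planes=4, seed=0):
    """Hypothetical sparse + low-rank estimate of softmax(q @ k.T) @ v."""
    rng = np.random.default_rng(seed)
    d = q.shape[-1]
    w = rng.standard_normal((num_features, d))
    planes = rng.standard_normal((num_planes, d))

    # Low-rank part: phi(Q) (phi(K)^T V), without forming an n x n matrix.
    phi_q = positive_random_features(q, w)
    phi_k = positive_random_features(k, w)
    num = phi_q @ (phi_k.T @ v)
    den = phi_q @ phi_k.sum(axis=0)

    # Sparse part: on LSH collisions, swap the low-rank estimate for the exact value.
    # (Dense collision mask for readability only.)
    same_bucket = lsh_buckets(q, planes)[:, None] == lsh_buckets(k, planes)[None, :]
    rows, cols = np.nonzero(same_bucket)
    exact = np.exp(np.sum(q[rows] * k[cols], axis=-1))
    estimate = np.sum(phi_q[rows] * phi_k[cols], axis=-1)
    corr = exact - estimate                      # the nonzero entries of S
    np.add.at(num, rows, corr[:, None] * v[cols])
    den = den + np.bincount(rows, weights=corr, minlength=q.shape[0])

    # Row-normalize, as softmax would.
    return num / den[:, None]


def exact_attention(q, k, v):
    """Reference softmax(q @ k.T) @ v."""
    s = q @ k.T
    a = np.exp(s - s.max(axis=-1, keepdims=True))
    return (a / a.sum(axis=-1, keepdims=True)) @ v


rng = np.random.default_rng(1)
n, d = 128, 16
q = rng.standard_normal((n, d)) / d ** 0.25      # folds 1/sqrt(d) into q and k
k = rng.standard_normal((n, d)) / d ** 0.25
v = rng.standard_normal((n, d))

out = scatterbrain_attention(q, k, v)
ref = exact_attention(q, k, v)
print("relative error:", np.linalg.norm(out - ref) / np.linalg.norm(ref))
```

Because every weight in this sketch is positive (the exact $\exp(q_i^\top k_j)$ on collisions, a positive feature product elsewhere), each output row is a convex combination of the rows of $V$ and the normalizer stays positive. Setting `num_planes=0` puts every pair into one bucket, in which case the sketch reduces to exact softmax attention.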

Paper Structure

This paper contains 53 sections, 10 theorems, 47 equations, 7 figures, 8 tables, 1 algorithm.

Key Result

Theorem 1

Let $M_\beta$ be the attention matrix of the generative process (visualized in Figure 2). Fix $\epsilon\in (0,1)$. Let $R \in \mathbb{R}^{n \times n}$ be a matrix. Consider low-rank, sparse, and sparse + low-rank approximations to $M_\beta$.
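Theorem 1 compares three families of approximations to $M_\beta$. The toy experiment below is only a hedged illustration of that comparison, not the theorem's construction. Its assumptions: the clustered-query generator is a stand-in for the paper's generative process (random unit-norm cluster centers plus Gaussian noise of scale `delta`, projected back onto the unit sphere, in the spirit of Figure 2), $M_\beta$ is taken to be $\exp(\beta QQ^\top)$ entrywise, and the sparse + low-rank fit is a single greedy step (best rank-$r$ fit, then the $k$ largest residual entries), much cruder than Robust PCA. All function names are hypothetical.

```python
import numpy as np


def clustered_queries(num_clusters, per_cluster, d, delta, rng):
    """Hypothetical generator: noisy copies of random unit vectors, renormalized."""
    centers = rng.standard_normal((num_clusters, d))
    centers /= np.linalg.norm(centers, axis=1, keepdims=True)
    noise = delta * rng.standard_normal((num_clusters * per_cluster, d)) / np.sqrt(d)
    pts = np.repeat(centers, per_cluster, axis=0) + noise
    return pts / np.linalg.norm(pts, axis=1, keepdims=True)


def best_lowrank(m, r):
    """Optimal rank-r approximation in Frobenius norm (Eckart-Young, via SVD)."""
    u, s, vt = np.linalg.svd(m)
    return (u[:, :r] * s[:r]) @ vt[:r]


def best_sparse(m, k):
    """Optimal k-sparse approximation in Frobenius norm: keep the k largest |entries|."""
    out = np.zeros_like(m)
    idx = np.argsort(np.abs(m), axis=None)[::-1][:k]
    out.flat[idx] = m.flat[idx]
    return out


def lowrank_plus_sparse(m, r, k):
    """One greedy step: rank-r fit, then the k largest entries of the residual."""
    low = best_lowrank(m, r)
    return low + best_sparse(m - low, k)


def rel_err(m, approx):
    return np.linalg.norm(m - approx) / np.linalg.norm(m)


rng = np.random.default_rng(0)
q = clustered_queries(num_clusters=8, per_cluster=16, d=32, delta=0.5, rng=rng)
r, k = 8, 8 * q.shape[0]          # rank budget and number of nonzeros
for beta in (1.0, 4.0, 16.0):
    m = np.exp(beta * (q @ q.T))
    print(f"beta={beta:5.1f}  low-rank={rel_err(m, best_lowrank(m, r)):.3f}  "
          f"sparse={rel_err(m, best_sparse(m, k)):.3f}  "
          f"sparse+low-rank={rel_err(m, lowrank_plus_sparse(m, r, k)):.3f}")
```

By construction the greedy combination can never do worse than its own rank-$r$ part, since the sparse step only zeroes residual entries; whether it also beats the pure sparse fit depends on $\beta$ and the budgets, which is the kind of regime dependence the theorem formalizes.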

Figures (7)

  • Figure 1: Left: regimes in which sparse+low-rank approximation is more accurate, based on the entropy of the attention matrices. Right: Scatterbrain workflow. For the attention layer in Transformers, after computing Query $Q$, Key $K$, and Value $V$ matrices, we approximate $\mathrm{softmax}(QK^\top)V$ with two components: (i) sparse $SV$ (ii) low-rank $\phi(Q)(\phi(K)^\top V)$.
  • Figure 2: Visualization of the generative process, for three different values of the intra-cluster distance $\Delta$ (small, medium, and large). The vectors from the input sequence (rows of $Q$) form clusters that lie approximately on the unit sphere. Different colors represent different clusters.
  • Figure 3: Qualitative comparison of approx. accuracy and efficiency, among Robust PCA, sparse (Reformer) and low-rank (Performer) attention, and Scatterbrain. Scatterbrain is more accurate while being efficient.
  • Figure 4: Per-entry MSE for different approximations, across a range of magnitudes of $q^\top k$. Scatterbrain has low MSE for both small and large entries, thus outperforming its sparse (Reformer) and low-rank (Performer) counterparts.
  • Figure 5: First: approximation comparison between Scatterbrain and its "lower bound", Robust PCA. Second: comparison of error vs. entropy among SMYRF, Performer, and Scatterbrain, three representatives of sparse, low-rank, and sparse+low-rank approximations. Third and fourth: Inception score (higher is better) and FID score (lower is better) of different attention variants for pretrained BigGAN.
  • ...and 2 more figures
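Figures 1 (left) and 5 organize the comparison by the entropy of the attention matrix, which the abstract ties to the softmax temperature. The sketch below computes that quantity under assumptions of its own: random unit-norm queries and keys, entropy defined as the mean Shannon entropy (in nats) of the rows of $\mathrm{softmax}(\beta QK^\top)$, and a hypothetical helper name. At $\beta = 0$ every row is uniform, giving the maximum value $\log n$, and the entropy can only fall as $\beta$ grows, because for a row with scores $s$ and probabilities $p \propto e^{\beta s}$ one has $\frac{dH}{d\beta} = -\beta\,\mathrm{Var}_p(s) \le 0$.

```python
import numpy as np


def mean_row_entropy(q, k, beta):
    """Mean Shannon entropy (nats) of the rows of softmax(beta * q @ k.T)."""
    s = beta * (q @ k.T)
    s = s - s.max(axis=-1, keepdims=True)                        # stabilize exp
    log_p = s - np.log(np.exp(s).sum(axis=-1, keepdims=True))    # log-softmax
    return float(-(np.exp(log_p) * log_p).sum(axis=-1).mean())


rng = np.random.default_rng(0)
n, d = 256, 32
q = rng.standard_normal((n, d))
q /= np.linalg.norm(q, axis=1, keepdims=True)
k = rng.standard_normal((n, d))
k /= np.linalg.norm(k, axis=1, keepdims=True)

betas = [0.0, 1.0, 4.0, 16.0, 64.0]
ents = [mean_row_entropy(q, k, b) for b in betas]
for b, h in zip(betas, ents):
    print(f"beta={b:5.1f}  mean row entropy={h:.3f}  (max log n = {np.log(n):.3f})")
```

Sweeping $\beta$ this way moves the attention matrix from diffuse, high-entropy rows toward peaked, low-entropy rows, which is the axis along which Figure 1 (left) separates the regimes where each approximation family is most accurate.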

Theorems & Definitions (22)

  • Theorem 1
  • Theorem 2
  • Example 1
  • Theorem 3
  • Proof of thm:sparse_lowrank_1
  • Example 2
  • Theorem 4
  • Proof of thm:sparse_lowrank_2
  • Lemma 5
  • Proof
  • ...and 12 more