Why Attend to Everything? Focus is the Key

Hengshuai Yao, Xing Chen, Ahmed Murtadha, Jin Li, Shuai Shao, Yasin Abbasi Yadkori, Guan Wang, Mingli Yuan, William Chen, Sen Song

Abstract

We introduce Focus, a method that learns which token pairs matter rather than approximating all of them. Learnable centroids assign tokens to groups; distant attention is restricted to same-group pairs while local attention operates at full resolution. Because all model weights stay frozen, Focus is purely additive: centroid-only training (as few as 148K parameters) improves domain perplexity with zero degradation on downstream benchmarks--from 124M to 70B parameters, across five attention architectures. No existing efficient attention method achieves this in the retrofit setting. At 124M, Focus surpasses full attention (30.3 vs 31.4 PPL); trained from scratch at 7B scale (2B tokens), Focus again beats full attention (13.82 vs 13.89 PPL). At inference, restricting each token to its top-k highest-scoring groups discretizes the soft routing into a hard sparsity pattern, yielding 2x speedup while beating the pretrained baseline (41.3 vs 42.8 PPL); decomposing this pattern into two standard FlashAttention calls reaches 8.6x wall-clock speedup at 1M tokens with no custom kernels. Unlike LoRA, centroid routing preserves alignment: instruction-tuned models retain TruthfulQA scores after adaptation, while LoRA degrades at every learning rate and rank. Sinkhorn normalization enforces balanced groups as a hard constraint, and the resulting groups discover interpretable linguistic categories without supervision.
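
Concretely, the routing step described above can be pictured with a short sketch. The following is a minimal, hypothetical PyTorch illustration (not the authors' code): tokens are scored against learnable centroids, a few Sinkhorn iterations balance the group assignment, and the hardened assignment decides which distant token pairs remain attendable. The function names, cosine scoring, window size, and hyperparameters are assumptions for illustration.

```python
# Minimal sketch of Focus-style centroid routing. Function names, cosine
# scoring, the Sinkhorn temperature, and the local window size are all
# illustrative assumptions, not the paper's implementation.
import torch
import torch.nn.functional as F

def sinkhorn_balanced_assign(scores: torch.Tensor, n_iters: int = 5, eps: float = 0.05) -> torch.Tensor:
    """Turn token-to-centroid scores [T, G] into near-balanced soft assignments
    by alternating row/column normalization of exp(scores / eps)."""
    P = torch.exp(scores / eps)
    for _ in range(n_iters):
        P = P / P.sum(dim=1, keepdim=True)   # each token's mass sums to 1
        P = P / P.sum(dim=0, keepdim=True)   # each group receives equal total mass
    return P / P.sum(dim=1, keepdim=True)    # final row normalization -> routing probabilities

def focus_mask(hidden: torch.Tensor, centroids: torch.Tensor, window: int = 128) -> torch.Tensor:
    """Boolean [T, T] attention mask: local pairs stay at full resolution,
    distant pairs are kept only when query and key fall in the same group."""
    T = hidden.shape[0]
    scores = F.normalize(hidden, dim=-1) @ F.normalize(centroids, dim=-1).T   # [T, G]
    group = sinkhorn_balanced_assign(scores).argmax(dim=-1)                   # hard group per token
    idx = torch.arange(T)
    causal = idx[None, :] <= idx[:, None]
    local = (idx[:, None] - idx[None, :]) < window
    same_group = group[:, None] == group[None, :]
    return causal & (local | same_group)
```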

Paper Structure

This paper contains 125 sections, 2 theorems, 21 equations, 3 figures, 28 tables, 2 algorithms.

Key Result

Theorem 1 (Exact decomposition)

Attention under the Focus sparsity pattern decomposes exactly into two standard attention computations, $\mathcal{A}$ and $\mathcal{B}$, each of which can be served by an off-the-shelf FlashAttention call.
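
The decomposition can be sanity-checked with a tiny numerical example. The sketch below is hypothetical (single head, plain PyTorch rather than FlashAttention, and a random stand-in for the centroid routing): it shows that attention over the union of the disjoint local and distant-same-group masks is recovered exactly by merging the two per-mask attentions with their softmax normalizers, which is what lets the pattern be served by two standard attention calls.

```python
# Hypothetical numerical check of the decomposition idea (not the paper's proof):
# attention over the union of two disjoint masks equals the two per-mask attentions
# merged by their softmax normalizers. Mask construction and names are assumptions.
import torch

def masked_attention_with_lse(q, k, v, mask):
    """Single-head attention restricted to `mask` [T, T]; also returns the
    per-query log-sum-exp normalizer needed to merge partial results."""
    s = (q @ k.T) / q.shape[-1] ** 0.5
    s = s.masked_fill(~mask, float("-inf"))
    lse = torch.logsumexp(s, dim=-1, keepdim=True)
    p = torch.nan_to_num(torch.exp(s - lse))   # rows with an empty mask contribute nothing
    return p @ v, lse

torch.manual_seed(0)
T, d, window = 64, 32, 8
q, k, v = (torch.randn(T, d) for _ in range(3))
idx = torch.arange(T)
causal = idx[None, :] <= idx[:, None]
local = (idx[:, None] - idx[None, :]) < window
group = torch.randint(0, 4, (T,))                         # stand-in for centroid routing
A = causal & local                                        # local component
B = causal & ~local & (group[:, None] == group[None, :])  # distant same-group component

out_A, lse_A = masked_attention_with_lse(q, k, v, A)
out_B, lse_B = masked_attention_with_lse(q, k, v, B)
w_A = torch.sigmoid(lse_A - lse_B)            # = Z_A / (Z_A + Z_B); equals 1 where B is empty
merged = w_A * out_A + (1.0 - w_A) * out_B

out_union, _ = masked_attention_with_lse(q, k, v, A | B)
print(torch.allclose(merged, out_union, atol=1e-5))       # True: the split is exact
```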

Figures (3)

  • Figure 1: Quality--speed Pareto frontier of efficient attention retrofits. All methods start from pretrained GPT-2 124M, fine-tuned 4000 steps on PG-19. Speedups measured on H100-80GB at 1M tokens (Table \ref{tab:speedup}).
  • Figure 2: The forgetting--adaptation tradeoff. Left: LoRA ranks $r \in \{1, 2, 4, 8, 16\}$ trade domain PPL improvement for benchmark degradation. Centroids at both 148K ($d_g{=}16$, star) and 7.1M (full rank, diamond) achieve domain adaptation with exactly zero degradation---they are not on the LoRA tradeoff curve because they operate in a fundamentally different space (routing vs. weights). Right: Degradation across all four benchmarks for each LoRA rank. Both centroid lines (orange) are flat at zero regardless of projection dimension.
  • Figure 3: From-scratch 7B training curves: Focus (black, solid) vs full attention (blue, dashed). Focus leads at every checkpoint; the inset zooms into the final 13K steps where the gap stabilizes at $\sim$0.1 PPL.

Theorems & Definitions (2)

  • Theorem 1: Exact decomposition
  • Theorem 2: Exactness