Multipole Semantic Attention: A Fast Approximation of Softmax Attention for Pretraining

Rupert Mitchell, Kristian Kersting

TL;DR

MuSe tackles the quadratic bottleneck of softmax attention in long-context pretraining with a two-level approximation based on semantic clustering. It clusters queries and keys separately in their learned representation space, builds query-specific, exponentially tilted summaries of the key clusters, and augments these summaries with exact retrieval from the most relevant clusters. The resulting fast approximation stays compatible with pretrained models at test time. Empirically, MuSe delivers a 36% wall-clock speedup at 64k context on 1B-scale models while preserving training quality and long-context utilization, and it applies to existing pretrained Llama models (3.1-8B and 3.2-1B) without retraining. Because it requires no architectural changes, it offers a practical route to scalable long-context pretraining. Ablations show that both query clustering and the retrieval-based corrections contribute to quality. The framework also leaves room for future kernel optimizations and broader applicability.

Abstract

Pretraining transformers on long sequences (entire code repositories, collections of related documents) is bottlenecked by quadratic attention costs. We present Multipole Semantic Attention (MuSe), which accelerates 64k-context pretraining by 36% while matching baseline loss, requiring no architectural changes. MuSe clusters queries and keys separately in representation space. This yields query-specific summaries that substantially outperform spatial blocking at matched sparsity, while also enabling drop-in compatibility with existing pretrained models; we validate on Llama 3.1-8B and 3.2-1B without retraining. We pretrain language models up to 1B parameters at 64k context on code and scientific documents, confirming that MuSe preserves quality and long-context utilization during training.

Paper Structure

This paper contains 80 sections, 6 equations, 6 figures, 26 tables, 1 algorithm.

Figures (6)

  • Figure 1: MuSe method overview, depicting the far-field approximation. Bars: Queries (top) and keys/values (bottom) are partitioned by semantic cluster (A/B for queries, X/Y for keys) within each spatial block (0, 1, 2); segments vary in size because cluster membership is data-dependent. Lower cube: Per-block summaries indexed by (query cluster, key cluster, spatial block), constructed in Steps 1--2. Upper cube: Causally accumulated summaries (Step 3), accumulated along the spatial axis. Steps 4--6: A query in segment B2 (blue line) attends to the upper cube to select key cluster Y (Step 4, semantic selection), then to the lower cube to select spatial block 1 (Step 5, spatial selection), and finally attends exactly to keys in (Y, 1) (Step 6). Greyed-out segments do not participate: the first query block has no prior context, and the final key/value block is only accessed via local attention (computed separately). See Algorithm 1 for pseudocode.
  • Figure 2: Geometric interpretation of MuSe. Green circles represent key/value clusters; the blue dashed circle represents a query cluster. Hollow dots mark the exponentially-tilted centroids---cluster summaries shifted toward the query cluster. Solid dots mark individual key/value pairs selected for exact retrieval. The chosen query ($\times$) connects to tilted centroids via dotted lines (approximate attention) and to retrieved tokens via solid lines (exact attention).
  • Figure 3: Training loss versus wall-clock time for 1B models on code (left) and scientific PDF (right) domains. MuSe (blue) and CUDNN (black) follow the same loss trajectory, shifted horizontally---the gap is the 36% throughput advantage. Plots start after warmup (4000 steps). Trained on a single node of 8 NVIDIA A100 GPUs.
  • Figure 4: Scaling behavior from 96M to 1B parameters. Power law fits (lines) span 185M--1B, where scaling is cleanest; the 96M point falls slightly above the fit. MuSe-trained models evaluated with CUDNN attention (orange triangles, $\dagger$) track the exact attention baseline, confirming that the approximation preserves scaling properties. $\dagger$The 1B point uses the fine-tuned value after 0.1% additional CUDNN training.
  • Figure 5: Cumulative mean loss versus context length on Project Gutenberg text for Llama 3.2-1B (left) and Llama 3.1-8B (right). MuSe with various cluster counts (QK64--QK512) tracks exact attention closely, with the gap reducing by ${\sim}1.8\times$ per doubling for 1B and ${\sim}2\times$ for 8B. Without exponential tilting ("No tilt K128", dotted), loss explodes immediately. Query clustering provides additional benefit: QK64 outperforms Q1 K128 despite using half the key clusters, demonstrating that clustering queries more than doubles practical quality. Y-axes show reducible loss (cumulative mean minus fitted irreducible loss: 2.30 nats for 1B, 1.96 nats for 8B).
  • ...and 1 more figure
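
The exponentially tilted centroids in the Figure 2 caption can be illustrated with a small sketch. It is not the authors' implementation: the exact formula is not given in this summary, so the sketch assumes one plausible form, in which each key in a cluster is weighted by exp(c · k) for a query-cluster centroid c, and a query's offset from c is handled by a first-order correction. All function names are hypothetical, and causal masking, score scaling, spatial blocking, and exact retrieval are omitted.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def exact_attention(q, keys, values):
    """Plain softmax attention for a single query (no scaling, no mask)."""
    weights = [math.exp(dot(q, k)) for k in keys]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(dim)]

def tilted_summary(q_centroid, keys, values):
    """Summarize one key cluster as seen from one query-cluster centroid.

    Assumed form (not taken from the paper): every key is weighted by
    exp(q_centroid . k), which shifts the key and value averages toward
    the keys that this query cluster attends to most.
    """
    weights = [math.exp(dot(q_centroid, k)) for k in keys]
    mass = sum(weights)
    k_tilt = [sum(w * k[i] for w, k in zip(weights, keys)) / mass
              for i in range(len(keys[0]))]
    v_tilt = [sum(w * v[i] for w, v in zip(weights, values)) / mass
              for i in range(len(values[0]))]
    return mass, k_tilt, v_tilt

def approx_attention(q, q_centroid, summaries):
    """Approximate attention for q from per-cluster tilted summaries."""
    delta = [a - b for a, b in zip(q, q_centroid)]
    # Each cluster contributes its tilted mass, corrected to first order
    # for the query's offset from the query-cluster centroid.
    weights = [mass * math.exp(dot(delta, k_tilt))
               for mass, k_tilt, _ in summaries]
    total = sum(weights)
    dim = len(summaries[0][2])
    return [sum(w * s[2][i] for w, s in zip(weights, summaries)) / total
            for i in range(dim)]

# Two toy key clusters, each given as (keys, values); values are 1-D.
cluster_x = ([[1.0, 0.0], [0.9, 0.1]], [[1.0], [2.0]])
cluster_y = ([[-1.0, 0.5], [-0.8, 0.4]], [[-1.0], [-3.0]])
q_centroid = [0.5, 0.5]
summaries = [tilted_summary(q_centroid, ks, vs)
             for ks, vs in (cluster_x, cluster_y)]

all_keys = cluster_x[0] + cluster_y[0]
all_values = cluster_x[1] + cluster_y[1]

# A query sitting exactly on the centroid, and one slightly off it.
exact_at_centroid = exact_attention(q_centroid, all_keys, all_values)
approx_at_centroid = approx_attention(q_centroid, q_centroid, summaries)
q_near = [0.6, 0.4]
exact_near = exact_attention(q_near, all_keys, all_values)
approx_near = approx_attention(q_near, q_centroid, summaries)
```

Under this assumed form, a query that coincides with its cluster centroid recovers exact attention from the summaries alone, and the error grows with the query's distance from the centroid. That is one way to read why clustering queries as well as keys helps.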