Approximate Top-$k$ for Increased Parallelism

Oscar Key, Luka Ribar, Alberto Cattaneo, Luke Hudlass-Galley, Douglas Orr

TL;DR

This work tackles the bottleneck of exact top-$k$ by introducing bucketed approximate top-$k$ to unlock parallelism on accelerators. It analyzes a two-stage design with per-bucket top-$k_b$ and an optional final merge, explores parameter regimes ($k\ll n$ vs $k\propto n$), and provides both theoretical recall-cost bounds and extensive empirical evaluation. The results show substantial speed-ups (often $2$–$4\times$) across sparse-LM and KG tasks with minimal loss in downstream performance, and demonstrate meaningful end-to-end gains in SparQ-attention-based long-sequence generation. A CUDA/C++ PyTorch implementation is released to facilitate adoption, with guidance on strategy depending on the $k/n$ regime and potential extensions to distributed settings.

Abstract

We present an evaluation of bucketed approximate top-$k$ algorithms. Computing top-$k$ exactly suffers from limited parallelism, because the $k$ largest values must be aggregated along the vector, and is therefore poorly suited to highly parallel machine learning accelerators. By relaxing the requirement that the top-$k$ be exact, bucketed algorithms can dramatically increase the available parallelism by independently computing many smaller top-$k$ operations. We explore the design choices of this class of algorithms using both theoretical analysis and empirical evaluation on downstream tasks. Our motivating examples are sparsity algorithms for language models, which often use top-$k$ to select the most important parameters or activations. We also release a fast bucketed top-$k$ implementation for PyTorch.
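To make the two-stage idea concrete, the following is a minimal pure-Python sketch, not the released CUDA/C++ implementation. It assumes contiguous, near-equal buckets (the paper may assign elements to buckets differently), and the names `bucketed_topk` and `recall` are hypothetical helpers introduced only for illustration.

```python
import heapq


def bucketed_topk(values, k, b, k_b):
    """Approximate top-k indices via b independent per-bucket top-k_b.

    Assumption: buckets are contiguous slices of near-equal size.
    """
    n = len(values)
    if b * k_b < k:
        raise ValueError("need b * k_b >= k to return k elements")
    if b > n:
        raise ValueError("cannot have more buckets than elements")

    # Stage 1: each bucket is reduced independently (parallel on real hardware).
    candidates = []
    for i in range(b):
        start, end = i * n // b, (i + 1) * n // b
        candidates.extend(
            heapq.nlargest(k_b, range(start, end), key=values.__getitem__)
        )

    # Stage 2 (optional in the paper): exact top-k over the b * k_b candidates.
    return heapq.nlargest(k, candidates, key=values.__getitem__)


def recall(approx_indices, values, k):
    """Fraction of the exact top-k indices recovered by the approximation."""
    exact = set(heapq.nlargest(k, range(len(values)), key=values.__getitem__))
    return len(exact & set(approx_indices)) / k


# Same shape as the paper's illustration: n=11, k=4, b=3, k_b=2.
# Three of the four largest values sit in the first bucket, so one is missed.
values = [10, 9, 8, 1, 2, 3, 4, 5, 6, 7, 0]
approx = bucketed_topk(values, k=4, b=3, k_b=2)
print(approx, recall(approx, values, k=4))  # [0, 1, 9, 8] 0.75
```

In this sketch, raising $k_b$ to 3 lets the first bucket contribute all three of its large values, so the recall returns to 1.0 at the cost of a larger Stage 2.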

Paper Structure

This paper contains 24 sections, 5 equations, 19 figures, 1 table.

Figures (19)

  • Figure 1: Our approximate top-$k$ implementation, compared with exact top-$k$ implementations from PyTorch and RAFT, and a bucketed top-$k$ using torch.argmax, tested in float32 on an H100 PCIe GPU with batch size $m=128$. Total bandwidth is the minimum number of bytes transferred by top-$k$, divided by runtime. Left: Small fixed $k=64$; it is faster to retrieve $k_b=1$ element per bucket, varying the total number $b\cdot k_b$ of elements retrieved, where $b$ is the number of buckets. Right: Large $k=n/4$; it is best to set $b\cdot k_b/k=1$ and increase $k_b$.
  • Figure 2: Left: An example of a bucketed top-$k$, with $n=11$, $k=4$, $b=3$ and $k_b=2$. In Stage 1, $n$ elements are reduced to $b \cdot k_b$ elements via $b$ independent top-$k_b$ operations. An optional Stage 2 takes the final top-$k$. Right: The trade-off between top-$k$ runtime and downstream task accuracy for SparQ Attention on SQuAD and a sequence repetition task (see \ref{app:downstream-task-details}), when using different bucketed top-$k$ settings with batch size $m=1$. Good accuracy and speed-ups above $4\times$ are achieved with $k_b\in\{2,4\}$.
  • Figure 3: Theoretical trade-off curves, using the serial cost model, which counts all operations executed in an abstract execution model (\ref{app:cost-models}), for $n=1048576$ ($=2^{20}$), with $k \ll n$ (left, $k=256$) and $k \propto n$ (right, $k=n/8=131072$). Points along the curves (from bottom to top) indicate an increasing $b \cdot k_b/k$ ratio, leading to a decreasing recall error. See \ref{appendix:tradeoff-curves} for the full set of trade-off curves with various cost models.
  • Figure 4: Bucketed top-$k$ trade-off for LLM vocabulary sampling (left, $n=128256$, $k=256$, $m=64$) and for knowledge graph link prediction (right, $n=2653751$, $k=100$, $m=128$). In both regimes, $k_b=1$ gives peak performance, but $k_b=2$ sacrifices some speed for the sake of a lower error and is Pareto optimal when increasing $b$.
  • Figure 5: End-to-end speed-ups achieved when generating text from Llama 2 7B using SparQ sparse attention, which relies on the top-$k$ operation. We use $k=n/8$ and batch size $1$, and set $b$ such that $b \cdot k_b = k$. The plot also shows the theoretical maximum speed-up that SparQ could achieve if the computation were fully memory-bandwidth-bound (see \ref{app:e2e-speedup-details}). We plot the mean over four repeats, observing little variance.
  • ...and 14 more figures