Approximate Top-$k$ for Increased Parallelism
Oscar Key, Luka Ribar, Alberto Cattaneo, Luke Hudlass-Galley, Douglas Orr
TL;DR
This work tackles the parallelism bottleneck of exact top-$k$ by studying bucketed approximate top-$k$, which unlocks parallelism on accelerators. It analyzes a two-stage design, consisting of a per-bucket top-$k_b$ followed by an optional final merge, and explores two parameter regimes ($k\ll n$ vs $k\propto n$). It provides both theoretical bounds on the trade-off between recall and cost and an extensive empirical evaluation. The results show substantial speed-ups (often $2$–$4\times$) across sparse language model and knowledge-graph tasks with minimal loss in downstream performance. They also show meaningful end-to-end gains in long-sequence generation with SparQ Attention. The authors release a CUDA/C++ implementation for PyTorch to facilitate adoption, give guidance on which strategy to use in each $k/n$ regime, and discuss potential extensions to distributed settings.
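The two-stage design can be illustrated with a minimal pure-Python sketch. It assumes contiguous, equal-sized buckets. The helper names `exact_topk`, `bucketed_topk` and `recall` are hypothetical and are not the released CUDA/C++ kernel or its API. Stage one takes an independent top-$k_b$ inside every bucket. Stage two optionally runs an exact top-$k$ over the much smaller candidate set. Here the buckets are processed in a loop. On an accelerator they would typically be a reshaped `(num_buckets, n / num_buckets)` tensor, with one batched top-$k$ over the last axis.

```python
import heapq
from typing import List, Sequence


def exact_topk(values: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest values (exact reference)."""
    return heapq.nlargest(k, range(len(values)), key=values.__getitem__)


def bucketed_topk(values: Sequence[float], k: int, num_buckets: int,
                  k_b: int = 1, merge: bool = True) -> List[int]:
    """Approximate top-k (illustrative sketch, contiguous buckets assumed).

    Stage 1: top-k_b inside each bucket, independently of the other buckets.
    Stage 2 (optional): exact top-k over the num_buckets * k_b candidates.
    Without the merge, num_buckets * k_b must equal k.
    """
    n = len(values)
    if n % num_buckets != 0:
        raise ValueError("n must be divisible by num_buckets")
    size = n // num_buckets
    if k_b > size:
        raise ValueError("k_b cannot exceed the bucket size")

    candidates: List[int] = []
    for b in range(num_buckets):
        bucket = range(b * size, (b + 1) * size)
        # Stage 1: this bucket's top-k_b does not depend on any other bucket,
        # which is where the extra parallelism comes from.
        candidates.extend(heapq.nlargest(k_b, bucket, key=values.__getitem__))

    if not merge:
        if len(candidates) != k:
            raise ValueError("without a merge, num_buckets * k_b must equal k")
        return candidates
    if len(candidates) < k:
        raise ValueError("num_buckets * k_b must be at least k")
    # Stage 2: exact top-k, but over num_buckets * k_b values instead of n.
    return heapq.nlargest(k, candidates, key=values.__getitem__)


def recall(approx: Sequence[int], exact: Sequence[int]) -> float:
    """Fraction of the true top-k indices recovered by the approximation."""
    return len(set(approx) & set(exact)) / len(exact)


values = [0.9, 0.8, 0.1, 0.2, 0.7, 0.3, 0.6, 0.5]
approx = bucketed_topk(values, k=4, num_buckets=4, k_b=1, merge=False)
print(approx, recall(approx, exact_topk(values, 4)))
```

In this example, 0.9 and 0.8 share a bucket, so with $k_b = 1$ the value 0.8 is lost and recall is 0.75. Raising $k_b$ to 2 and merging recovers the exact result, at the cost of a larger second stage.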
Abstract
We present an evaluation of bucketed approximate top-$k$ algorithms. Computing top-$k$ exactly offers limited parallelism, because the $k$ largest values must be aggregated along the whole vector. Exact top-$k$ is therefore not well suited to highly parallel machine learning accelerators. By relaxing the requirement that the top-$k$ be exact, bucketed algorithms can dramatically increase the available parallelism by independently computing many smaller top-$k$ operations. We explore the design choices of this class of algorithms using both theoretical analysis and empirical evaluation on downstream tasks. Our motivating examples are sparsity algorithms for language models, which often use top-$k$ to select the most important parameters or activations. We also release a fast bucketed top-$k$ implementation for PyTorch.
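The price of relaxing exactness can be measured as recall, the expected fraction of the true top-$k$ that the bucketed algorithm returns. The following sketch rests on a simplifying assumption of ours, which is not necessarily the model used in the paper's analysis. It assumes that $n$ elements are assigned uniformly at random to $b$ equal buckets of size $m = n/b$, with $b\,k_b \ge k$. Under this assumption, the rank-$i$ element survives stage one exactly when fewer than $k_b$ of the $i-1$ larger elements share its bucket, which is a hypergeometric tail. Every candidate outside the true top-$k$ is smaller than all true top-$k$ elements, so the optional merge never discards a surviving top-$k$ element. Hence $\mathbb{E}[\text{recall}] = \frac{1}{k}\sum_{i=1}^{k}\sum_{j=0}^{k_b-1}\binom{i-1}{j}\binom{n-i}{m-1-j}\Big/\binom{n-1}{m-1}$. The function name `expected_recall` is hypothetical.

```python
from fractions import Fraction
from math import comb


def expected_recall(n: int, k: int, num_buckets: int, k_b: int = 1) -> Fraction:
    """Exact expected recall of bucketed top-k, assuming elements are placed
    into equal-sized buckets uniformly at random (an illustrative model)."""
    if n % num_buckets != 0:
        raise ValueError("n must be divisible by num_buckets")
    if num_buckets * k_b < k or k > n:
        raise ValueError("need num_buckets * k_b >= k and k <= n")
    m = n // num_buckets          # bucket size
    denom = comb(n - 1, m - 1)    # ways to fill the other m - 1 slots
    total = Fraction(0)
    for i in range(1, k + 1):
        # The rank-i element is kept iff at most k_b - 1 of the i - 1 larger
        # elements land among the m - 1 other slots of its bucket.
        kept = sum(comb(i - 1, j) * comb(n - i, m - 1 - j)
                   for j in range(min(k_b, m)))
        total += Fraction(kept, denom)
    return total / k


print(expected_recall(n=8, k=4, num_buckets=4, k_b=1))
print(expected_recall(n=8, k=4, num_buckets=4, k_b=2))
```

With $k_b = 1$, the $j=0$ term reduces to $\binom{n-i}{m-1}/\binom{n-1}{m-1}$. This expression shrinks as the rank $i$ grows, so the largest elements are the most likely to be kept. Setting $k_b = m$ keeps every element and gives a recall of exactly 1.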
