Foundations of Top-$k$ Decoding For Language Models
Georgy Noarov, Soham Mallick, Tao Wang, Sunay Joshi, Yan Sun, Yangxinyu Xie, Mengxin Yu, Edgar Dobriban
TL;DR
This work provides a theoretical foundation for top-$k$ decoding in language models by modeling decoding as the recovery of a sparse distribution through sparsity-regularized, separable Bregman divergences. It introduces primal and dual Bregman decoding frameworks with $\ell_0$ regularization, proving that optimal supports are greedy top-$k$ sets and that the cost is discretely convex in $k$, which enables efficient adaptive-$k$ search. The authors develop renormalization maps for fixed sparsity patterns and establish conditions under which both primal and dual decoding are tractable, including an $\alpha$-entropy family of decoders that generalize top-$k$ (recovered at $\alpha=1$) and exhibit varied mass-shifting behavior. Experiments on open-ended generation and math reasoning show performance competitive with standard top-$k$ decoding and demonstrate the practical viability of adaptive sparsity and $\alpha$-based strategies for decoding in large language models.
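As a concrete reference point for the greedy supports and renormalization maps mentioned above, the following minimal pure-Python sketch implements plain top-$k$ decoding. It selects the greedy support (the indices of the $k$ largest next-token probabilities), rescales the retained probabilities to sum to one, and samples a token index. It assumes the next-token distribution is given as a list of probabilities; the function names `top_k_renormalize` and `sample_top_k` are hypothetical and are not taken from the paper or from any library.

```python
import heapq
import random


def top_k_renormalize(probs, k):
    """Return {token index: probability} for the k most likely tokens, rescaled to sum to 1.

    Hypothetical helper: `probs` is a list of next-token probabilities.
    """
    if not 1 <= k <= len(probs):
        raise ValueError("k must satisfy 1 <= k <= len(probs)")
    # Greedy support: indices of the k largest probabilities.
    support = heapq.nlargest(k, range(len(probs)), key=lambda i: probs[i])
    kept_mass = sum(probs[i] for i in support)
    if kept_mass <= 0:
        raise ValueError("the retained probability mass must be positive")
    # Renormalize: divide every retained probability by the same total mass.
    return {i: probs[i] / kept_mass for i in support}


def sample_top_k(probs, k, rng=random):
    """Sample one token index from the top-k renormalized distribution."""
    dist = top_k_renormalize(probs, k)
    tokens = list(dist)
    return rng.choices(tokens, weights=[dist[t] for t in tokens], k=1)[0]


if __name__ == "__main__":
    probs = [0.05, 0.5, 0.15, 0.3]
    print(top_k_renormalize(probs, 2))  # keeps indices 1 and 3
    print(sample_top_k(probs, 2, rng=random.Random(0)))
```

Because every kept probability is divided by the same constant, this map rescales linearly and leaves the ratios among the kept tokens unchanged; the new decoders described in the abstract instead re-weight the kept probabilities non-linearly.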
Abstract
Top-$k$ decoding is a widely used method for sampling from LLMs: at each token, only the largest $k$ next-token probabilities are kept, and the next token is sampled after re-normalizing them to sum to unity. Top-$k$ and other sampling methods are motivated by the intuition that true next-token distributions are sparse, so the noisy LLM probabilities need to be truncated. However, to our knowledge, a precise theoretical motivation for the use of top-$k$ decoding is missing. In this work, we develop a theoretical framework that both explains and generalizes top-$k$ decoding. We view decoding at a fixed token as the recovery of a sparse probability distribution. We consider \emph{Bregman decoders} obtained by minimizing a separable Bregman divergence (for both the \emph{primal} and \emph{dual} cases) with a sparsity-inducing $\ell_0$ regularization. Despite the combinatorial nature of the objective, we show how to optimize it efficiently for a large class of divergences. We show that the optimal decoding strategies are greedy, and further that the loss function is discretely convex in $k$, so that binary search provably and efficiently finds the optimal $k$. We show that top-$k$ decoding arises as a special case for the KL divergence, and identify new decoding strategies that have distinct behaviors (e.g., non-linearly up-weighting larger probabilities after re-normalization).
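The binary-search claim can be illustrated on the KL special case. The sketch below assumes a primal objective of the form $\mathrm{KL}(q \,\|\, p) + \lambda \|q\|_0$ with penalty weight $\lambda > 0$; this notation is our assumption, not necessarily the paper's. For a fixed support size $k$, the best $q$ is $p$ restricted to the $k$ largest probabilities and renormalized, whose KL term equals $-\log P_k$, where $P_k$ is the retained top-$k$ mass. The cost is therefore $C(k) = -\log P_k + \lambda k$. Its forward differences are $C(k+1) - C(k) = \lambda - \log(1 + p_{(k+1)}/P_k)$, where $p_{(k+1)}$ is the $(k+1)$-th largest probability. The ratio $p_{(k+1)}/P_k$ shrinks as $k$ grows, so the differences are nondecreasing. Consequently, the first $k$ with a non-negative difference is a minimizer, and binary search finds it with $O(\log n)$ cost evaluations. The function names `kl_l0_cost` and `best_k_binary_search` are hypothetical.

```python
import math
from itertools import accumulate


def kl_l0_cost(prefix_mass, k, lam):
    """Assumed cost C(k) = -log(P_k) + lam * k, with P_k = prefix_mass[k - 1]."""
    return -math.log(prefix_mass[k - 1]) + lam * k


def best_k_binary_search(probs, lam):
    """Return a minimizer of the discretely convex cost C(k) over k = 1..n.

    Convexity makes the forward differences C(k+1) - C(k) nondecreasing, so the
    smallest k whose forward difference is >= 0 is a minimizer.
    """
    # Sort descending so prefix_mass[k - 1] is the mass of the greedy top-k set.
    prefix_mass = list(accumulate(sorted(probs, reverse=True)))
    lo, hi = 1, len(prefix_mass)  # invariant: a minimizer lies in [lo, hi]
    while lo < hi:
        mid = (lo + hi) // 2
        step = kl_l0_cost(prefix_mass, mid + 1, lam) - kl_l0_cost(prefix_mass, mid, lam)
        if step >= 0:
            hi = mid  # cost no longer decreases past mid
        else:
            lo = mid + 1  # cost still decreases past mid
    return lo


if __name__ == "__main__":
    probs = [0.05, 0.5, 0.15, 0.3]
    for lam in (0.01, 0.2, 1.0):
        print(lam, best_k_binary_search(probs, lam))
```

Under this assumed objective, a smaller $\lambda$ keeps more tokens and a larger $\lambda$ keeps fewer. In this sense, the penalty weight controls an adaptive $k$ rather than a fixed one.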
