A Mathematical Theory of Top-$k$ Sparse Attention via Total Variation Distance
Georgios Tzachristas, Lei Deng, Ioannis Tzachristas, Gong Zhang, Renhai Chen
TL;DR
This work develops a rigorous theory of Top-$k$ attention truncation at both the distribution and output levels, tying the truncation error to the discarded softmax mass through an exact TV–KL identity. It derives deterministic and probabilistic bounds based on score gaps, blocks, and a head–tail decomposition, and turns them into certified selection algorithms (Delta-$k$-Search and MC-Search). Under a Gaussian score model, it gives closed-form design rules for the minimal $k$ that achieves a target error, which can guide adaptive sparsification in long-context transformers. Experiments on BERT and synthetic data confirm the predicted scaling and show that certified Top-$k$ can reduce the number of scored keys by 2-4$\times$ on average while meeting the prescribed total-variation budget. The framework thus offers provable sparsity guarantees with concrete guidelines for Transformer inference and training.
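The TV–KL identity mentioned above can be checked by hand. The sketch below assumes notation that this summary does not fix: $P_i$ are the softmax weights, $H$ is the index set of the $k$ largest scores, $T$ is its complement, $\tau=\sum_{i\in T}P_i$ is the discarded mass, and $\hat P$ is taken to be the restriction of $P$ to $H$, renormalized by $1-\tau$.

```latex
\begin{align*}
\hat P_i &= \begin{cases} P_i/(1-\tau), & i\in H,\\ 0, & i\in T,\end{cases}
\qquad \tau=\sum_{i\in T}P_i,\\
\mathrm{TV}(P,\hat P)
 &= \tfrac12\sum_{i\in H}\Bigl(\tfrac{P_i}{1-\tau}-P_i\Bigr)+\tfrac12\sum_{i\in T}P_i
  = \tfrac12\cdot\tfrac{\tau}{1-\tau}\,(1-\tau)+\tfrac12\,\tau=\tau,\\
\mathrm{KL}(\hat P\Vert P)
 &= \sum_{i\in H}\hat P_i\log\frac{\hat P_i}{P_i}
  = \log\frac{1}{1-\tau}\sum_{i\in H}\hat P_i=-\log(1-\tau),\\
1-e^{-\mathrm{KL}(\hat P\Vert P)} &= 1-(1-\tau)=\tau=\mathrm{TV}(P,\hat P).
\end{align*}
```

Under this assumed definition of $\hat P$, the total-variation distance, the discarded tail mass, and $1-e^{-\mathrm{KL}}$ are the same number, which is why bounding the tail mass is enough to certify the truncation.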
Abstract
We develop a unified mathematical framework for certified Top-$k$ attention truncation that quantifies approximation error at both the distribution and output levels. For a single attention distribution $P$ and its Top-$k$ truncation $\hat P$, we show that the total-variation distance coincides with the discarded softmax tail mass and satisfies $\mathrm{TV}(P,\hat P)=1-e^{-\mathrm{KL}(\hat P\Vert P)}$, yielding sharp Top-$k$-specific bounds in place of generic inequalities. From this we derive non-asymptotic deterministic bounds -- from a single boundary gap through multi-gap and blockwise variants -- that control $\mathrm{TV}(P,\hat P)$ using only the ordered logits. Using an exact head-tail decomposition, we prove that the output error factorizes as $\|\mathrm{Attn}(q,K,V)-\mathrm{Attn}_k(q,K,V)\|_2=\tau\,\|\mu_{\mathrm{tail}}-\mu_{\mathrm{head}}\|_2$ with $\tau=\mathrm{TV}(P,\hat P)$, yielding a new head-tail diameter bound $\|\mathrm{Attn}(q,K,V)-\mathrm{Attn}_k(q,K,V)\|_2\le\tau\,\mathrm{diam}_{H,T}$ and refinements linking the error to $\mathrm{Var}_P(V)$. Under an i.i.d. Gaussian score model $s_i\sim\mathcal N(\mu,\sigma^2)$ we derive closed-form tail masses and an asymptotic rule for the minimal $k_\varepsilon$ ensuring $\mathrm{TV}(P,\hat P)\le\varepsilon$, namely $k_\varepsilon/n\approx\Phi_c(\sigma+\Phi^{-1}(\varepsilon))$. Experiments on bert-base-uncased and synthetic logits confirm the predicted scaling of $k_\varepsilon/n$ and show that certified Top-$k$ can reduce scored keys by 2-4$\times$ on average while meeting the prescribed total-variation budget.
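The identities stated in the abstract can be checked numerically on synthetic Gaussian scores. The following is a minimal sketch, not the authors' implementation: the helper names `topk_truncation` and `gaussian_k_fraction`, the problem sizes, and the random values are all illustrative assumptions. It also assumes that $\hat P$ is the renormalized restriction of $P$ to the Top-$k$ set, and that $\Phi_c$ denotes the complementary standard normal CDF, $\Phi_c(x)=1-\Phi(x)$.

```python
import numpy as np
from statistics import NormalDist


def topk_truncation(scores, k):
    """Softmax P, its renormalized Top-k truncation P_hat, and head/tail index sets."""
    scores = np.asarray(scores, dtype=float)
    w = np.exp(scores - scores.max())      # numerically stable softmax weights
    P = w / w.sum()
    order = np.argsort(-scores)            # indices sorted by descending score
    head, tail = order[:k], order[k:]
    P_hat = np.zeros_like(P)
    P_hat[head] = P[head] / P[head].sum()  # renormalize over the kept keys
    return P, P_hat, head, tail


def gaussian_k_fraction(sigma, eps):
    """Asymptotic rule k_eps / n ~ Phi_c(sigma + Phi^{-1}(eps)); Phi_c = 1 - Phi is assumed."""
    nd = NormalDist()
    return 1.0 - nd.cdf(sigma + nd.inv_cdf(eps))


# Illustrative sizes only; none of these values come from the paper.
rng = np.random.default_rng(0)
n, d, k, sigma, eps = 512, 16, 64, 1.5, 0.05
scores = rng.normal(0.0, sigma, size=n)    # i.i.d. Gaussian score model
V = rng.normal(size=(n, d))                # synthetic value vectors

P, P_hat, head, tail = topk_truncation(scores, k)

# Distribution level: TV equals the discarded tail mass and 1 - exp(-KL).
tau = P[tail].sum()
tv = 0.5 * np.abs(P - P_hat).sum()
kl = np.sum(P_hat[head] * np.log(P_hat[head] / P[head]))

# Output level: error = tau * || mu_tail - mu_head ||_2.
mu_head = P[head] @ V[head] / P[head].sum()
mu_tail = P[tail] @ V[tail] / P[tail].sum()
err = np.linalg.norm(P @ V - P_hat @ V)

# Smallest k whose discarded mass is at most eps, for comparison with the rule.
tail_after_k = 1.0 - np.cumsum(np.sort(P)[::-1])
k_eps = int(np.argmax(tail_after_k <= eps)) + 1

print("tail mass, TV, 1-exp(-KL):", tau, tv, 1.0 - np.exp(-kl))
print("output error vs. tau*||mu_tail-mu_head||:",
      err, tau * np.linalg.norm(mu_tail - mu_head))
print("empirical k_eps/n vs. Gaussian rule:",
      k_eps / n, gaussian_k_fraction(sigma, eps))
```

The first two printed lines should agree to floating-point precision, since those identities are exact. The last line compares a finite-$n$ quantity with an asymptotic rule, so only approximate agreement should be expected.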
