
FlashSampling: Fast and Memory-Efficient Exact Sampling

Tomas Ruiz, Zhen Qin, Yifan Zhang, Xuyang Shen, Yiran Zhong, Mengdi Wang

Abstract

Sampling from a categorical distribution is mathematically simple, but in large-vocabulary decoding, it often triggers extra memory traffic and extra kernels after the LM head. We present FlashSampling, an exact sampling primitive that fuses sampling into the LM-head matmul and never materializes the logits tensor in HBM. The method is simple: compute logits tile-by-tile on chip, add Gumbel noise, keep only one maximizer per row and per vocabulary tile, and finish with a small reduction over tiles. The fused tiled kernel is exact because $\arg\max$ decomposes over a partition; grouped variants for online and tensor-parallel settings are exact by hierarchical factorization of the categorical distribution. Across H100, H200, B200, and B300 GPUs, FlashSampling speeds up kernel-level decode workloads, and in end-to-end vLLM experiments, it reduces time per output token by up to $19\%$ on the models we test. These results show that exact sampling, with no approximation, can be integrated into the matmul itself, turning a bandwidth-bound postprocessing step into a lightweight epilogue. Project Page: https://github.com/FlashSampling/FlashSampling.
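The tile-by-tile procedure described above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the FlashSampling kernel: the names `tiled_gumbel_max_sample` and `reference_sample`, the `[V, D]` weight layout, and the tile sizes are hypothetical. So that the tiled path can be compared with an unfused argmax, the sketch takes the Gumbel noise as a precomputed array (a testing assumption only; a fused kernel would draw each tile's noise on chip). Only one tile of logits exists at a time.

```python
import numpy as np


def tiled_gumbel_max_sample(h, W, noise, tile_size):
    """Sketch of tile-wise Gumbel-max sampling over an LM head (hypothetical helper).

    h:     [B, D] hidden states.
    W:     [V, D] LM-head weights (layout is an assumption).
    noise: [B, V] i.i.d. Gumbel(0, 1) noise, precomputed here only so the
           result can be checked against the unfused reference.
    """
    B, V = h.shape[0], W.shape[0]
    n_tiles = -(-V // tile_size)                       # ceiling division
    tile_val = np.empty((B, n_tiles))
    tile_idx = np.empty((B, n_tiles), dtype=np.int64)
    rows = np.arange(B)
    for t in range(n_tiles):
        lo, hi = t * tile_size, min((t + 1) * tile_size, V)
        perturbed = h @ W[lo:hi].T + noise[:, lo:hi]   # logits + Gumbel for this tile only
        local = np.argmax(perturbed, axis=1)           # one maximizer per row and tile
        tile_val[:, t] = perturbed[rows, local]
        tile_idx[:, t] = lo + local                    # convert to a global vocabulary index
    best = np.argmax(tile_val, axis=1)                 # small reduction over tiles
    return tile_idx[rows, best]


def reference_sample(h, W, noise):
    """Unfused baseline: materialize all [B, V] logits, then take the argmax."""
    return np.argmax(h @ W.T + noise, axis=1)


rng = np.random.default_rng(0)
B, D, V = 4, 16, 1000
h = rng.standard_normal((B, D))
W = rng.standard_normal((V, D))
noise = rng.gumbel(size=(B, V))
for tile in (1, 7, 128, 1000):                         # tile sizes need not divide V
    assert np.array_equal(tiled_gumbel_max_sample(h, W, noise, tile),
                          reference_sample(h, W, noise))
```

Because the argmax over a whole row equals the argmax over per-tile maxima, every tile size returns the same index as the unfused path for the same noise, while each row keeps only one (value, index) pair per tile.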

Paper Structure

This paper contains 66 sections, 6 theorems, 25 equations, 6 figures, 6 tables, and 6 algorithms.

Key Result

Theorem 2.1

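Let $\widetilde{\bm{\ell}}\in(\mathbb{R}\cup\{-\infty\})^V$ have at least one finite entry, and let $\{g_i\}_{i=1}^V$ be i.i.d. $\mathrm{Gumbel}(0,1)$. Then, for every $k\in\{1,\dots,V\}$,

$$\Pr\Big[\arg\max_{1\le i\le V}\big(\widetilde{\ell}_i+g_i\big)=k\Big]=\frac{\exp(\widetilde{\ell}_k)}{\sum_{j=1}^{V}\exp(\widetilde{\ell}_j)},$$

with the convention $\exp(-\infty)=0$; that is, the perturbed argmax is an exact sample from $\mathrm{softmax}(\widetilde{\bm{\ell}})$.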
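The Gumbel-Max statement of Theorem 2.1 can be sanity-checked numerically. The NumPy snippet below is an illustrative sketch, not code from the paper; `gumbel_max_sample` and the example logits are hypothetical. It draws many Gumbel-perturbed argmaxes, including one masked $-\infty$ entry, and compares the empirical frequencies with the softmax probabilities.

```python
import numpy as np


def gumbel_max_sample(logits, n, rng):
    """Draw n indices as argmax(logits + g), g ~ i.i.d. Gumbel(0, 1); -inf entries are never chosen."""
    g = rng.gumbel(size=(n, logits.size))
    return np.argmax(logits + g, axis=1)


rng = np.random.default_rng(0)
logits = np.array([1.0, 0.0, -np.inf, 2.0])    # hypothetical logits; index 2 is masked out
samples = gumbel_max_sample(logits, 200_000, rng)
empirical = np.bincount(samples, minlength=logits.size) / samples.size

weights = np.exp(logits - logits.max())         # exp(-inf) evaluates to 0.0
softmax = weights / weights.sum()
print("empirical:", np.round(empirical, 3))
print("softmax:  ", np.round(softmax, 3))
```

The masked entry receives zero probability because $-\infty + g_i = -\infty$ can never be the maximum while a finite entry exists, which is why the theorem requires at least one finite logit.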

Figures (6)

  • Figure 1: Conventional multinomial sampling (left) materializes the full $[B,V]$ logits tensor in HBM between the matmul and the sampler. FlashSampling (right) fuses sampling into the matmul epilogue, followed by a lightweight reduction over vocabulary tiles. Logits are computed tile-by-tile in on-chip memory, perturbed with Gumbel noise, and reduced without ever writing the full logits tensor to HBM. Red arrows denote HBM traffic; green arrows denote on-chip data movement.
  • Figure 2: Relative performance on B300. Left: FlashSampling vs. Multinomial Sampling (baseline $=1$). Right: FlashSampling vs. FlashInfer FI1 and FI2 (baseline $=1$). FlashSampling is faster than Multinomial Sampling at every batch size shown, faster than FI1 throughout, and faster than FI2 in the decode regime.
  • Figure 3: Sampling runtime (left) and matmul runtime (right) in $\mu$s vs. batch size. Lower is better.
  • Figure 4: Roofline (left) and HBM bandwidth utilization (right) on H100. Left: all methods track the memory-bound slope for $B \le 64$; FlashSampling sits slightly above the baselines because it avoids the logits round-trip. Close to the ridge point ($\mathrm{AI} \approx 295$), performance flattens below the compute ceiling, where cuBLAS outperforms Triton. Right: FlashSampling achieves higher bandwidth utilization than all baselines in the decode regime, confirming that fusion removes overhead rather than shifting it. A corresponding appendix figure shows the same pattern on B200.
  • Figure 5: TPOT vs. concurrency on B200 for all four models. Top row: Qwen3-1.7B (up to $19\%$ reduction) and Qwen3-8B (roughly $3$--$7\%$). Bottom row: Qwen3-32B and gpt-oss-120b, where gains are smaller because attention and FFN dominate decode time.
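The memory-bound behavior in Figures 1 and 4 follows from a back-of-the-envelope arithmetic-intensity estimate for the LM-head matmul. The sketch below is illustrative only: the helper name, the hidden and vocabulary sizes, BF16 weights, FP32 logits, and the assumption that an unfused baseline writes the logits once and reads them once are all hypothetical. The H100 peaks (about 989 TFLOP/s dense BF16 and 3.35 TB/s HBM3, SXM variant) are assumed spec values, not numbers from the paper.

```python
def lm_head_arithmetic_intensity(B, D, V, weight_bytes=2, logit_bytes=4, materialize_logits=True):
    """Approximate FLOP/byte of a [B, D] x [D, V] LM-head matmul (hypothetical helper)."""
    flops = 2 * B * D * V                          # one multiply-add per (row, hidden, vocab) triple
    hbm_bytes = (V * D + B * D) * weight_bytes     # read weights and hidden states (same dtype assumed)
    if materialize_logits:
        hbm_bytes += 2 * B * V * logit_bytes       # assumed: write logits once, read them once in the sampler
    # The fused case ignores the small [B, n_tiles] partial results it writes.
    return flops / hbm_bytes


# Assumed H100 SXM peaks: ~989 TFLOP/s dense BF16 and ~3.35 TB/s HBM3.
ridge_point = 989e12 / 3.35e12                     # FLOP/byte where compute and bandwidth limits meet

D, V = 4096, 128_000                               # hypothetical hidden size and vocabulary size
for B in (1, 16, 64, 256):
    unfused = lm_head_arithmetic_intensity(B, D, V)
    fused = lm_head_arithmetic_intensity(B, D, V, materialize_logits=False)
    print(f"B={B:4d}  unfused AI={unfused:6.1f}  fused AI={fused:6.1f}  ridge={ridge_point:.0f}")
```

For $B \ll D$ the weight read dominates traffic, so the intensity is roughly $B$ FLOP/byte, far below the ridge in the decode regime. Removing the logits round-trip raises intensity by cutting bytes rather than FLOPs, which is consistent with FlashSampling sitting slightly above the baselines on the memory-bound slope.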

Theorems & Definitions (6)

  • Theorem 2.1: Gumbel-Max
  • Lemma 4.1: Gumbel max-stability under grouping
  • Lemma 4.2: Exact group factorization
  • Lemma 4.3: Binary merge rule
  • Theorem 4.4: Exactness of hierarchical FlashSampling
  • Lemma 4.5: Max over vocabulary tiles
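Lemmas 4.2 and 4.3 suggest how partial results from separate vocabulary groups (for example, tensor-parallel shards) can be combined without gathering all logits. The NumPy sketch below shows one way to realize a group factorization with a pairwise merge; it is an illustration under our own assumptions, not the paper's algorithm. `local_state`, `merge`, and `hierarchical_sample` are hypothetical names, logits are assumed finite, and `n_groups` is assumed not to exceed the vocabulary size. Each group reports its log-sum-exp and one exact sample from its own softmax; a merge keeps either side's sample with probability proportional to that side's total mass.

```python
import numpy as np


def local_state(logits, offset, rng):
    """Summarize one group: (log-sum-exp of its logits, one exact sample from its softmax)."""
    m = logits.max()
    lse = m + np.log(np.exp(logits - m).sum())               # numerically stable log-sum-exp
    idx = np.argmax(logits + rng.gumbel(size=logits.shape))  # Gumbel-max sample within the group
    return lse, int(offset + idx)                            # global vocabulary index


def merge(a, b, rng):
    """Combine two states; keep a's index with probability exp(lse_a) / (exp(lse_a) + exp(lse_b))."""
    lse = np.logaddexp(a[0], b[0])
    keep_a = rng.random() < np.exp(a[0] - lse)
    return lse, (a[1] if keep_a else b[1])


def hierarchical_sample(logits, n_groups, rng):
    """Split the vocabulary into n_groups (<= len(logits)) groups, summarize each, merge left to right."""
    groups = np.array_split(logits, n_groups)
    offsets = np.cumsum([0] + [len(g) for g in groups[:-1]])
    state = None
    for g, off in zip(groups, offsets):
        s = local_state(g, off, rng)
        state = s if state is None else merge(state, s, rng)
    return state[1]


rng = np.random.default_rng(0)
logits = np.array([0.5, -1.0, 2.0, 0.0, 1.0])   # hypothetical logits split into 3 groups
draws = np.array([hierarchical_sample(logits, 3, rng) for _ in range(20_000)])
print(np.bincount(draws, minlength=logits.size) / draws.size)
```

Each merge preserves the invariant that the kept index is distributed as the softmax over the union of the merged groups, so groups can be combined sequentially or as a tree without changing the result's distribution.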