Parallel Sampling via Counting
Nima Anari, Ruiquan Gao, Aviad Rubinstein
TL;DR
This work addresses the problem of sampling from an arbitrary distribution $\mu$ on $[q]^n$ given a counting oracle that returns the probability of any partial assignment (pinning). The authors introduce a parallel sampling algorithm with expected round complexity $O\big(n^{2/3}\cdot\min\{\log^{2/3}n\log q,\; q^{1/3}\log^{1/3}q\}\big)$, which is sublinear in $n$; it relies on pinning lemmas and a robust universal coupler that coordinates parallel marginal queries. They also prove a lower bound of $\tilde{\Omega}(n^{1/3})$ rounds for any parallel sampler making at most $\operatorname{poly}(n)$ queries, which leaves a gap between the $n^{2/3}$ upper bound and the $n^{1/3}$ lower bound. They discuss implications for autoregressive models and for planar graph problems (notably planar perfect matchings), where parallel speedups are achievable. The approach combines a random coordinate permutation, parallel guessing of marginals, and verification via coupling.
Abstract
We show how to use parallelization to speed up sampling from an arbitrary distribution $\mu$ on a product space $[q]^n$, given oracle access to counting queries: $\mathbb{P}_{X\sim \mu}[X_S=\sigma_S]$ for any $S\subseteq [n]$ and $\sigma_S \in [q]^S$. Our algorithm takes $O(n^{2/3}\cdot \operatorname{polylog}(n,q))$ parallel time, which is, to the best of our knowledge, the first runtime sublinear in $n$ for arbitrary distributions. Our results have implications for sampling in autoregressive models. Our algorithm works directly with an equivalent oracle that answers conditional marginal queries $\mathbb{P}_{X\sim \mu}[X_i=\sigma_i\;\vert\; X_S=\sigma_S]$, whose role is played by a trained neural network in autoregressive models. This suggests that a roughly $n^{1/3}$-factor speedup is possible for sampling in any-order autoregressive models. We complement our positive result by showing a lower bound of $\widetilde{\Omega}(n^{1/3})$ for the runtime of any parallel sampling algorithm making at most $\operatorname{poly}(n)$ queries to the counting oracle, even for $q=2$.
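To make the two oracles concrete, the following is a minimal sketch of the sequential baseline that the parallel algorithm is meant to improve on, not the authors' parallel method. It derives a conditional marginal from two counting queries via $\mathbb{P}[X_i=\sigma_i \mid X_S=\sigma_S] = \mathbb{P}[X_{S\cup\{i\}}=\sigma_{S\cup\{i\}}] / \mathbb{P}[X_S=\sigma_S]$, then pins one coordinate per round, so it uses $n$ adaptive rounds. The function names, the dictionary representation of $\mu$, and the brute-force oracle are all assumptions made for illustration; values are 0-indexed, so $[q]$ is represented as $\{0,\dots,q-1\}$.

```python
import random


def make_counting_oracle(mu):
    """Build a brute-force counting oracle for a toy distribution.

    mu maps each tuple in {0..q-1}^n to its probability (hypothetical
    representation; a real oracle would not enumerate the support).
    The returned function takes a pinning {index: value} and returns
    P[X_S = sigma_S].
    """
    def count(pinning):
        return sum(
            p for x, p in mu.items()
            if all(x[i] == v for i, v in pinning.items())
        )
    return count


def conditional_marginal(count, i, value, pinning):
    """P[X_i = value | X_S = sigma_S], computed from two counting queries."""
    denom = count(pinning)
    if denom == 0:
        raise ValueError("pinning has probability zero")
    return count({**pinning, i: value}) / denom


def sequential_sample(count, n, q, rng):
    """Sequential baseline: pin one coordinate per round, n rounds total."""
    order = list(range(n))
    rng.shuffle(order)  # any coordinate order yields an exact sample
    pinning = {}
    for i in order:
        # The q marginal queries for coordinate i are non-adaptive and
        # could be issued together within a single round.
        weights = [conditional_marginal(count, i, v, pinning) for v in range(q)]
        pinning[i] = rng.choices(range(q), weights=weights)[0]
    return tuple(pinning[i] for i in range(n))


if __name__ == "__main__":
    # Toy distribution on {0,1}^3: all coordinates are perfectly correlated.
    mu = {(0, 0, 0): 0.5, (1, 1, 1): 0.5}
    count = make_counting_oracle(mu)
    print(sequential_sample(count, n=3, q=2, rng=random.Random(0)))
```

In this baseline each round depends on the value pinned in the previous one, which is the sequential dependency the abstract's $O(n^{2/3}\cdot \operatorname{polylog}(n,q))$ bound is about breaking.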
