Path Planning for Masked Diffusion Model Sampling

Fred Zhangzhi Peng, Zachary Bezemek, Sawan Patel, Jarrid Rector-Brooks, Sherwood Yao, Avishek Joey Bose, Alexander Tong, Pranam Chatterjee

TL;DR

This work addresses the suboptimal inference of masked discrete diffusion models by introducing Path Planning (P2), a planning-based inference framework that decomposes each generation step into planning and denoising. By expanding the ELBO to include a planner that selects which tokens to update and whether to remask, P2 enables refined trajectories and token revisions, surpassing existing MDM sampling strategies. The authors instantiate P2 with self-, BERT-, and trained-planners, achieving state-of-the-art results across protein design, RNA design, language tasks, and code generation, often with smaller model sizes than autoregressive baselines. The approach demonstrates robust improvements across multiple domains and provides a flexible, scalable framework for inference-time optimization in discrete diffusion models.

Abstract

Any-order generation of discrete data using masked diffusion models (MDMs) offers a compelling alternative to traditional autoregressive models, especially in domains that lack a natural causal ordering of data. However, current popular MDMs depart from their successful continuous diffusion model counterparts with simplified masked inference wherein unmasked tokens cannot be iteratively refined, even if there is a mistake. In this paper, we extract the full power of MDMs by introducing a novel inference sampling strategy termed Path Planning (P2) that decomposes each generation step into two sub-stages: planning and denoising. Under P2, the planner at every step selects the tokens that are marked to be updated, which can then be sampled using the denoiser. We demonstrate that P2 generalizes all existing sampling strategies for MDMs and critically enhances generative quality through the new capability of refining and updating existing unmasked tokens. We theoretically prove that P2 establishes a (new) expanded evidence lower bound (ELBO) on the log marginal likelihood of data. We instantiate P2 with a family of planners including (1) Self-Planning, (2) BERT-Planning, and (3) Trained-Planning with a learned planner, leading to SOTA generative performance for MDMs on a suite of domains. Specifically, solely using P2 inference, we observe relative improvements of 22% in protein sequence foldability, 8% in RNA sequence pLDDT, 4% in math reasoning, 68% in story generation (ROUGE score), and 33% in code generation for the challenging pass@1 metric.
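The plan-then-denoise decomposition described in the abstract can be illustrated with a small sketch. This is not the authors' algorithm: the toy denoiser, the `MASK` id, the linear unmasking schedule, and the top-k selection rule are all assumptions made for illustration. The planner here is a self-planner in the loose sense that it scores each position by the denoiser's own probability for the candidate token. Positions that fall outside the top-k at a given step are masked again, which is how a previously unmasked token can be revised.

```python
import numpy as np

MASK = -1  # hypothetical id for the mask token (assumption)

def toy_denoiser(x, target, vocab_size, sharpness=4.0):
    """Stand-in for a denoiser: per-position token probabilities.
    It ignores x and simply favours a fixed `target`; a real one is a neural net."""
    logits = np.zeros((len(x), vocab_size))
    logits[np.arange(len(x)), target] = sharpness
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)

def p2_style_sample(denoise, length, num_steps, rng):
    """Plan-then-denoise loop with a self-planner (illustrative sketch only)."""
    x = np.full(length, MASK)
    for step in range(1, num_steps + 1):
        probs = denoise(x)  # shape: (length, vocab_size)
        # Denoising: propose a token for every position.
        z = np.array([rng.choice(probs.shape[1], p=p) for p in probs])
        # Unmasked positions keep their current token as the candidate.
        candidate = np.where(x == MASK, z, x)
        # Planning: score each candidate by the denoiser's own probability.
        scores = probs[np.arange(length), candidate]
        # Assumed linear schedule: number of positions left unmasked after this step.
        num_keep = (length * step + num_steps - 1) // num_steps
        keep = np.argsort(-scores)[:num_keep]
        # Everything outside the top-k is (re)masked, including old tokens.
        x = np.full(length, MASK)
        x[keep] = candidate[keep]
    return x

rng = np.random.default_rng(0)
target = np.array([2, 0, 1, 3, 2, 1])
sample = p2_style_sample(lambda x: toy_denoiser(x, target, 4), 6, 3, rng)
print(sample)
```

At the final step the schedule keeps every position, so the returned sequence contains no mask tokens. Replacing the `scores` line with the output of a separate model would correspond, loosely, to using an external planner rather than the denoiser itself.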

Paper Structure

This paper contains 72 sections, 1 theorem, 54 equations, 17 figures, 13 tables, 5 algorithms.

Key Result

Proposition 1

Define $P^{\theta,\phi}_0\in \Delta^{Ld}$ by $P^{\theta,\phi}_0(\mathbf{x})=\mathbb{P}(X^{\theta,\phi}_0=\mathbf{x})$, where $X^{\theta,\phi}$ is the continuous-time Markov chain resulting from sending $T\to \infty$ in the discrete-time P2 formulation of subsec:MathematicalFormulations. Then we have a [...] Here $\mathbf{p}_t$ is defined per Eq. eqn:forward_transition_kernel, and $\mathbf{z}^{-i}$ denotes [...]

Figures (17)

  • Figure 1: Illustration of P2 sampling (Algorithm alg:OURpracticalsampling). At each step, the denoiser $D_\theta$ predicts $z$, and the planner $G_\phi$ selects positions to unmask (green) and remask (red).
  • Figure 2: Visualizing the predicted structures of generated protein (top) and RNA (bottom) sequences. Additional structures are depicted in Figure fig:protein_structures_group1.
  • Figure 3: Inference-time Scaling: Foldability vs. Sampling steps.
  • Figure 4: Runtime (bar) and throughput (line) for different planner sizes (150M denoiser on an A100).
  • Figure S1: Protein Sequence Generation Benchmark: Performance across length categories (200–800).
  • ...and 12 more figures

Theorems & Definitions (1)

  • Proposition 1