Path Planning for Masked Diffusion Model Sampling
Fred Zhangzhi Peng, Zachary Bezemek, Sawan Patel, Jarrid Rector-Brooks, Sherwood Yao, Avishek Joey Bose, Alexander Tong, Pranam Chatterjee
TL;DR
This work addresses the suboptimal inference of masked discrete diffusion models (MDMs) by introducing Path Planning (P2), a planning-based inference framework that decomposes each generation step into planning and denoising. By expanding the ELBO to include a planner that selects which tokens to update and whether to remask them, P2 enables refined sampling trajectories and revision of already-unmasked tokens, outperforming existing MDM sampling strategies. The authors instantiate P2 with self-, BERT-, and trained planners and report state-of-the-art results for MDMs across protein design, RNA design, language tasks, and code generation, often with smaller models than the autoregressive baselines. The gains hold across all of these domains, and P2 provides a flexible, scalable framework for inference-time optimization in discrete diffusion models.
Abstract
Any-order generation of discrete data using masked diffusion models (MDMs) offers a compelling alternative to traditional autoregressive models, especially in domains that lack a natural causal ordering. However, popular MDMs depart from their successful continuous diffusion counterparts by using a simplified masked inference procedure in which unmasked tokens cannot be iteratively refined, even if they contain a mistake. In this paper, we extract the full power of MDMs by introducing a novel sampling strategy termed Path Planning (P2) that decomposes each generation step into two sub-stages: planning and denoising. Under P2, the planner at every step selects which tokens to update, and the selected tokens are then sampled by the denoiser. We demonstrate that P2 generalizes all existing sampling strategies for MDMs and substantially improves generative quality through the new capability of refining and updating already-unmasked tokens. We prove that P2 establishes a new, expanded evidence lower bound (ELBO) on the log marginal likelihood of the data. We instantiate P2 with a family of planners: 1) Self-Planning, 2) BERT-Planning, and 3) Trained-Planning, which uses a learned planner. Together they yield state-of-the-art (SOTA) generative performance for MDMs on a suite of domains. Specifically, using P2 inference alone, we observe relative improvements of 22% in protein sequence foldability, 8% in RNA sequence pLDDT, 4% in math reasoning, 68% in story generation (ROUGE score), and 33% in code generation on the challenging pass@1 metric.
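The planning-plus-denoising loop described above can be illustrated with a small sketch. This is a minimal illustration under stated assumptions, not the authors' implementation: `make_toy_denoiser`, `p2_step`, `p2_sample`, the `MASK = -1` id, the linear unmasking schedule, and the confidence-based scoring rule are all hypothetical choices. The planner loosely follows the self-planning idea of reusing the denoiser's own probabilities. A masked position is scored by the probability of its sampled proposal, and an unmasked position is scored by the probability of its current token. Only the top-k positions stay unmasked, so a low-confidence token committed earlier can be remasked and resampled later.

```python
import math
import random

MASK = -1  # hypothetical mask-token id used only in this sketch


def make_toy_denoiser(target, vocab_size, peak):
    """Hypothetical stand-in for a trained MDM denoiser.

    For every position it returns a distribution putting `peak` mass on
    target[i] and spreading the rest uniformly; it ignores its input.
    """
    rest = (1.0 - peak) / (vocab_size - 1)

    def denoise(seq):
        return [[peak if v == target[i] else rest for v in range(vocab_size)]
                for i in range(len(seq))]

    return denoise


def p2_step(seq, probs, k, rng):
    """One planning + denoising step (sketch, assumed scoring rule)."""
    proposal, scores = [], []
    for i, tok in enumerate(seq):
        if tok == MASK:
            # Denoiser proposes a token for a masked position.
            new = rng.choices(range(len(probs[i])), weights=probs[i])[0]
        else:
            # Unmasked positions keep their current token as the candidate.
            new = tok
        proposal.append(new)
        # Self-planning-style score: denoiser's probability of the candidate.
        scores.append(probs[i][new])
    # Planner keeps the k highest-scoring positions (stable sort breaks ties
    # by position) and masks or remasks every other position.
    ranked = sorted(range(len(seq)), key=lambda i: scores[i], reverse=True)
    keep = set(ranked[:k])
    return [proposal[i] if i in keep else MASK for i in range(len(seq))]


def p2_sample(denoise, length, num_steps, rng):
    """Run the loop from an all-masked sequence with a linear schedule."""
    seq = [MASK] * length
    for t in range(1, num_steps + 1):
        k = math.ceil(length * t / num_steps)  # positions unmasked after step t
        seq = p2_step(seq, denoise(seq), k, rng)
    return seq


if __name__ == "__main__":
    target = [0, 1, 2, 3]
    denoise = make_toy_denoiser(target, vocab_size=4, peak=1.0)
    rng = random.Random(0)
    # Position 1 holds a wrong token (2 instead of 1); with k=3 it is remasked.
    seq = [0, 2, MASK, MASK]
    print(p2_step(seq, denoise(seq), 3, rng))  # prints [0, -1, 2, 3]
```

Because the toy denoiser ignores its input, it only exercises the bookkeeping of planning, unmasking, and remasking. A real MDM denoiser would condition on the partially unmasked sequence. The sketch also uses one scoring rule for masked and unmasked positions, whereas a BERT- or trained planner would supply those scores from a separate model.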
