Steering Masked Discrete Diffusion Models via Discrete Denoising Posterior Prediction
Jarrid Rector-Brooks, Mohsin Hasan, Zhangzhi Peng, Zachary Quinn, Chenghao Liu, Sarthak Mittal, Nouha Dziri, Michael Bronstein, Yoshua Bengio, Pranam Chatterjee, Alexander Tong, Avishek Joey Bose
TL;DR
This work tackles steering discrete diffusion models by reframing the problem as sampling from a reward-augmented Bayesian posterior. The authors introduce DDPP, a simulation-free framework that finetunes a pre-trained Masked Diffusion Model to approximate the posterior $\pi_0(\mathbf{x}_0) \propto p_0^{\text{pre}}(\mathbf{x}_0) R(\mathbf{x}_0)$, using forward masking and three learning objectives: DDPP-IS, DDPP-LB, and DDPP-KL. Across synthetic tasks, pixel-level image modeling, protein sequence design, and text generation, DDPP variants demonstrate strong steering performance with competitive sample quality, and even enable wet-lab validation of designed proteins. The framework unifies RLHF-style objectives with discrete diffusion priors in a scalable, simulation-free manner, offering a practical path to controllable generation in diverse discrete domains.
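To make the target distribution $\pi_0(\mathbf{x}_0) \propto p_0^{\text{pre}}(\mathbf{x}_0) R(\mathbf{x}_0)$ concrete, the sketch below enumerates a tiny discrete space and computes the reward-tilted posterior exactly. It is a toy illustration only and is not DDPP. The prior, the reward, and all function names (`prior_prob`, `reward`, `exact_posterior`, `snis_posterior`) are hypothetical. The self-normalized importance sampler, which uses the prior as its proposal, is a generic baseline. It shows why sampling from $\pi_0$ is possible in principle but impractical in high dimensions. It is not the paper's DDPP-IS objective.

```python
import itertools
import math
import random

# Hypothetical toy setup: length-3 sequences over a 2-token vocabulary.
VOCAB = ["A", "B"]
LENGTH = 3


def prior_prob(x):
    # Hypothetical stand-in for a pre-trained model: i.i.d. tokens with P("A") = 0.7.
    p = 1.0
    for tok in x:
        p *= 0.7 if tok == "A" else 0.3
    return p


def reward(x):
    # Hypothetical non-negative, non-differentiable reward favoring "B" tokens.
    return math.exp(x.count("B"))


def exact_posterior():
    # pi_0(x) ∝ p_pre(x) * R(x); normalize by enumerating all |V|^L sequences.
    seqs = ["".join(s) for s in itertools.product(VOCAB, repeat=LENGTH)]
    unnorm = {x: prior_prob(x) * reward(x) for x in seqs}
    z = sum(unnorm.values())
    return {x: w / z for x, w in unnorm.items()}


def sample_prior(rng):
    # Draw one sequence from the toy prior.
    return "".join("A" if rng.random() < 0.7 else "B" for _ in range(LENGTH))


def snis_posterior(num_samples, seed=0):
    # Self-normalized importance sampling with the prior as proposal:
    # the importance weight pi_0(x) / p_pre(x) is proportional to R(x).
    rng = random.Random(seed)
    totals = {}
    for _ in range(num_samples):
        x = sample_prior(rng)
        totals[x] = totals.get(x, 0.0) + reward(x)
    z = sum(totals.values())
    return {x: w / z for x, w in totals.items()}


post = exact_posterior()
print(max(post, key=post.get))  # → BBB
```

Under the prior, "AAA" is the most likely sequence. Tilting by the reward moves the mode to "BBB". Exact enumeration costs $|V|^L$ evaluations, so it does not scale to realistic sequence lengths. This is the gap that a learned, simulation-free sampler is meant to close.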
Abstract
Generative modeling of discrete data underlies important applications ranging from text-based agents like ChatGPT to the design of the very building blocks of life in protein sequences. However, application domains need to exert control over the generated data by steering the generative process, typically via RLHF, to satisfy a specified property, reward, or affinity metric. In this paper, we study the problem of steering Masked Diffusion Models (MDMs), a recent class of discrete diffusion models that offer a compelling alternative to traditional autoregressive models. We introduce Discrete Denoising Posterior Prediction (DDPP), a novel framework that casts the task of steering pre-trained MDMs as a problem of probabilistic inference by learning to sample from a target Bayesian posterior. Our DDPP framework leads to a family of three novel objectives that are all simulation-free, and thus scalable, and that apply to general non-differentiable reward functions. Empirically, we instantiate DDPP by steering MDMs to perform class-conditional pixel-level image modeling, RLHF-based alignment of MDMs using text-based rewards, and finetuning protein language models to generate more diverse secondary structures and shorter proteins. We substantiate our designs via wet-lab validation, where we observe transient expression of reward-optimized protein sequences.
