Enhancing Reasoning for Diffusion LLMs via Distribution Matching Policy Optimization

Yuchen Zhu, Wei Guo, Jaemoo Choi, Petr Molodyk, Bo Yuan, Molei Tao, Yongxin Chen

TL;DR

DMPO reframes RL fine-tuning for diffusion LLMs from reward maximization to distribution matching, leveraging a forward KL objective and importance sampling with weighted denoising cross-entropy to align the dLLM with a reward-tilted target distribution. It introduces weight-baseline techniques to stabilize training under small batch sizes and proposes Weighted Direct Discriminative Optimization as an additional distribution-matching option. Empirically, DMPO applied to a native dLLM (LLaDA-8B-Instruct) yields substantial gains on reasoning benchmarks without supervised fine-tuning, including accuracy improvements of up to 42.9% over prior SOTA baselines and up to 55.8% over the base model on planning tasks. The approach emphasizes off-policy, forward-only training that can leverage fast inference techniques to boost throughput while maintaining or improving reasoning quality.

Abstract

Diffusion large language models (dLLMs) are promising alternatives to autoregressive large language models (AR-LLMs), as they potentially allow higher inference throughput. Reinforcement learning (RL) is a crucial component for dLLMs to achieve comparable performance with AR-LLMs on important tasks, such as reasoning. However, RL algorithms that are well-suited for dLLMs' unique characteristics have yet to be developed. This paper proposes Distribution Matching Policy Optimization (DMPO), a principled and theoretically grounded RL fine-tuning method specifically designed to enhance the reasoning capabilities of dLLMs by matching the dLLM policy distribution to the optimal, reward-tilted one through cross-entropy optimization. We identify a key challenge in the implementation with a small training batch size and propose several effective solutions through a novel weight baseline subtraction technique. DMPO exhibits superior performance on multiple reasoning benchmarks without supervised fine-tuning, with an accuracy improvement of up to $42.9\%$ over previously SOTA baselines and $55.8\%$ over the base model, underscoring the effectiveness of the distribution matching framework. Our code is available at https://github.com/yuchen-zhu-zyc/DMPO.
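
As a rough illustration of the weighting idea described in the abstract, the sketch below computes self-normalized weights proportional to exp(r / alpha) for a batch of sampled responses, subtracts a baseline, and forms a weighted cross-entropy loss. The function names, the temperature `alpha`, and the choice of the batch-mean weight as the baseline are assumptions made for illustration only; the paper proposes its own baseline variants, and the repository linked above remains the reference implementation. The per-response log-likelihoods are passed in as plain numbers here, standing in for the denoising cross-entropy terms a dLLM would actually provide.

```python
import numpy as np

def tilted_weights(rewards, alpha=1.0):
    """Self-normalized weights proportional to exp(reward / alpha).

    `alpha` is a hypothetical temperature controlling how sharply the
    target distribution is tilted toward high-reward responses.
    """
    r = np.asarray(rewards, dtype=float)
    z = (r - r.max()) / alpha          # shift by the max for numerical stability
    w = np.exp(z)
    return w / w.sum()                 # weights are positive and sum to 1

def subtract_baseline(weights):
    """Subtract the batch-mean weight (an assumed baseline choice).

    After subtraction the weights sum to zero, so responses with
    below-average weight are pushed down instead of being reinforced.
    """
    w = np.asarray(weights, dtype=float)
    return w - w.mean()

def weighted_ce_loss(weights, log_probs):
    """Weighted cross-entropy: minus the weighted sum of log-likelihoods."""
    w = np.asarray(weights, dtype=float)
    lp = np.asarray(log_probs, dtype=float)
    return -float(np.sum(w * lp))

# Toy batch: two incorrect responses (reward 0) and one correct one (reward 1).
raw = tilted_weights([0.0, 0.0, 1.0], alpha=1.0)
centered = subtract_baseline(raw)
```

In this toy batch every raw weight is positive, so even the two incorrect responses would have their likelihood increased; after centering, their weights become negative and only the correct response is reinforced.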

Paper Structure

This paper contains 29 sections, 45 equations, 5 figures, 1 table, 1 algorithm.

Figures (5)

  • Figure 1: Performance on reasoning benchmarks evaluated with generation length $256$. DMPO consistently achieves the best performance among bidirectional dLLMs, outperforming d1.
  • Figure 2: Illustration of relative entropy (mode-seeking) and cross-entropy (mass-covering) for fitting a target $p_*$ ($\mathcal{G}$ is the set of Gaussian distributions).
  • Figure 3: Demonstration of the effect of the weight baseline. The orange and blue curves represent the probability $p_\theta({\bm{o}}|{\bm{q}})$ before and after the update, and the magenta arrows represent the weights. (a) When the batch size is large, distribution mode coverage is good. Though bad responses have positive weights, the correct ones have larger weights, which push the distribution update in the right direction. (b) When the batch size is small, some modes (e.g., the good one in the middle) may not be sampled. Without weight baseline subtraction, the dominant positive weights of the bad responses lead to wrong update directions. (c) With weight baseline subtraction, the bad responses are appropriately penalized, leading to the desired update direction.
  • Figure 4: Reward dynamics during training. DMPO consistently produces higher rewards than d1.
  • Figure 5: Effects of negative gradient insertion for DMPO.
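
The mode-seeking versus mass-covering contrast named in the Figure 2 caption can be reproduced numerically with a small sketch. Here "relative entropy" is read as the reverse KL divergence KL(q || p_*) and "cross-entropy" as the forward direction KL(p_* || q), which differs from the cross-entropy only by a constant in q. The bimodal target, the integration grid, and the brute-force grid search are all assumptions chosen for illustration and are not taken from the paper's figure.

```python
import numpy as np

# Integration grid for numerically evaluating the divergences.
x = np.linspace(-12.0, 12.0, 2401)
dx = x[1] - x[0]

def log_gauss(x, mu, sigma):
    """Log-density of a one-dimensional Gaussian N(mu, sigma^2)."""
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)

# Assumed target p_*: an equal mixture of N(-3, 1) and N(3, 1).
log_p = np.logaddexp(log_gauss(x, -3.0, 1.0), log_gauss(x, 3.0, 1.0)) + np.log(0.5)
p = np.exp(log_p)

def forward_kl(mu, sigma):
    """KL(p_* || q): expectation under the target (mass-covering)."""
    log_q = log_gauss(x, mu, sigma)
    return float(np.sum(p * (log_p - log_q)) * dx)

def reverse_kl(mu, sigma):
    """KL(q || p_*): expectation under the model (mode-seeking)."""
    log_q = log_gauss(x, mu, sigma)
    q = np.exp(log_q)
    return float(np.sum(q * (log_q - log_p)) * dx)

def grid_argmin(objective):
    """Brute-force search over a coarse (mu, sigma) grid."""
    mus = np.linspace(-5.0, 5.0, 41)
    sigmas = np.linspace(0.5, 5.0, 46)
    best = min((objective(m, s), m, s) for m in mus for s in sigmas)
    return best[1], best[2]

mu_fwd, sigma_fwd = grid_argmin(forward_kl)   # wide Gaussian spanning both modes
mu_rev, sigma_rev = grid_argmin(reverse_kl)   # narrow Gaussian locked onto one mode
```

Under these assumptions the forward-KL fit centers between the two modes with a large standard deviation, while the reverse-KL fit settles on one mode with a standard deviation close to that of a single component.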