dTRPO: Trajectory Reduction in Policy Optimization of Diffusion Large Language Models

Wenxuan Zhang, Lemeng Wu, Changsheng Zhao, Ernie Chang, Mingchen Zhuge, Zechun Liu, Andy Su, Hanxian Huang, Jun Chen, Chong Zhou, Raghuraman Krishnamoorthi, Vikas Chandra, Mohamed Elhoseiny, Wei Wen

Abstract

Diffusion Large Language Models (dLLMs) introduce a new paradigm for language generation, which in turn presents new challenges for aligning them with human preferences. In this work, we aim to improve the policy optimization for dLLMs by reducing the cost of the trajectory probability calculation, thereby enabling scaled-up offline policy training. We prove that: (i) under reference policy regularization, the probability ratio of the newly unmasked tokens is an unbiased estimate of that of intermediate diffusion states, and (ii) the probability of the full trajectory can be effectively estimated with a single forward pass of a re-masked final state. By integrating these two trajectory reduction strategies into a policy optimization objective, we propose Trajectory Reduction Policy Optimization (dTRPO). We evaluate dTRPO on 7B dLLMs across instruction-following and reasoning benchmarks. Results show that it substantially improves the core performance of state-of-the-art dLLMs, achieving gains of up to 9.6% on STEM tasks, up to 4.3% on coding tasks, and up to 3.0% on instruction-following tasks. Moreover, dTRPO exhibits strong training efficiency due to its offline, single-forward nature, and achieves improved generation efficiency through high-quality outputs.
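Claim (i) can be illustrated with a minimal sketch. It assumes a DPO-style pairwise objective (the abstract does not state the exact loss) and that per-token log-probabilities under the policy and the reference model are already available. The function names, the `beta` default, and the toy numbers are hypothetical and not taken from the paper.

```python
import math

def log_ratio_newly_unmasked(logp_policy, logp_ref, newly_unmasked):
    """Sum log pi_theta - log pi_ref over the newly unmasked positions only.

    Assumption: logp_policy and logp_ref are per-position log-probabilities
    from one forward pass of each model on the same re-masked input.
    """
    return sum(logp_policy[i] - logp_ref[i] for i in newly_unmasked)

def dpo_style_loss(ratio_chosen, ratio_rejected, beta=0.1):
    """Compute -log sigmoid(beta * (ratio_chosen - ratio_rejected)).

    Assumption: a standard DPO-style preference loss; beta=0.1 is illustrative.
    """
    margin = beta * (ratio_chosen - ratio_rejected)
    # Numerically stable softplus(-margin) = log(1 + exp(-margin)).
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))

# Toy example: positions 0 and 2 were unmasked at this step.
chosen = log_ratio_newly_unmasked([-1.0, -2.0, -3.0], [-1.5, -2.0, -2.0], [0, 2])
rejected = log_ratio_newly_unmasked([-2.0, -1.0, -4.0], [-1.0, -1.0, -3.0], [0, 2])
loss = dpo_style_loss(chosen, rejected)
```

Only the newly unmasked positions enter the sum, so the cost per training example does not grow with the number of intermediate diffusion states.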

Paper Structure

This paper contains 43 sections, 4 theorems, 55 equations, 6 figures, 4 tables, 1 algorithm.

Key Result

Theorem 3.1

The probability of the MDP process in dLLMs can be reduced to where $\bm{\tau}_{s,t}=\bm{\tau}_{sT_B+t}$ denotes the state at block $s$ and within-block step $t$.
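The index convention $\bm{\tau}_{s,t}=\bm{\tau}_{sT_B+t}$ can be made concrete with a small sketch. It assumes zero-based indices with $0 \le t < T_B$, where $T_B$ is read as the number of steps per block. That reading and the function names are assumptions, not definitions from the paper.

```python
def flat_index(s, t, steps_per_block):
    """Map (block s, within-block step t) to the flat trajectory index s*T_B + t."""
    if not 0 <= t < steps_per_block:
        raise ValueError("t must satisfy 0 <= t < steps_per_block")
    return s * steps_per_block + t

def block_and_step(k, steps_per_block):
    """Inverse mapping: recover (s, t) from the flat trajectory index k."""
    return divmod(k, steps_per_block)

# With T_B = 4, block 2 and step 1 correspond to flat index 9.
k = flat_index(2, 1, 4)
s, t = block_and_step(k, 4)
```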

Figures (6)

  • Figure 1: Performance gains on the MATH dataset vs. normalized online (left) and offline (right) training cost. Online training requires hundreds of times more ($N\times$) computation for the rollout stage, while our offline method requires only 4 forward passes per training example and achieves comparable performance.
  • Figure 2: (a): Generation processes in ARMs and dLLMs. ARMs generate tokens via causal conditioning, whereas dLLMs generate sequences via a multi-step diffusion process. (b): dTRPO samples masked tokens for each block and estimates trajectory probability ratios using only the probabilities of newly unmasked tokens under $\pi_\theta$.
  • Figure 3: Ablation study of algorithm design and implementation choices; inference speed comparison between our model and baseline models.
  • Figure 4: Causal attention and block attention at training time. In (b), each row represents the tokens that an input token attends to. A masked token attends to the masked input of the same block and to clean tokens in the previous blocks.
  • Figure 5: Ablation on parameter-efficient fine-tuning and DPO hyperparameters.
  • ...and 1 more figure
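The block-attention rule in the Figure 4 caption can be sketched as follows. The sketch covers only the rows for masked tokens, since the caption does not specify what clean tokens attend to, and it assumes fixed-size contiguous blocks. The function name and return layout are hypothetical.

```python
def masked_query_visibility(i, seq_len, block_size):
    """Visibility for the masked token at position i, following the caption rule.

    Returns two boolean lists of length seq_len:
      clean_visible[j]  is True if the clean token j lies in an earlier block,
      masked_visible[j] is True if the masked token j lies in the same block.
    Assumption: blocks are contiguous and have a fixed size block_size.
    """
    block = i // block_size
    clean_visible = [j // block_size < block for j in range(seq_len)]
    masked_visible = [j // block_size == block for j in range(seq_len)]
    return clean_visible, masked_visible

# Four tokens, two per block: position 2 sits in the second block.
clean_row, masked_row = masked_query_visibility(2, seq_len=4, block_size=2)
```

In this toy case the masked token at position 2 sees the clean tokens at positions 0 and 1 and the masked inputs at positions 2 and 3.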

Theorems & Definitions (8)

  • Theorem 3.1: State Reduction
  • Theorem 3.2: Ratio Reduction
  • proof
  • proof
  • Proposition A.1: Unbiasedness
  • proof
  • Proposition A.2: Variance
  • proof