Principled and Tractable RL for Reasoning with Diffusion Language Models

Anthony Zhan

TL;DR

This paper addresses reinforcement learning (RL) post-training for diffusion language models (dLLMs), where standard policy-gradient methods are not directly compatible with the diffusion denoising process. It introduces Amortized GRPO (AGRPO), a principled on-policy RL algorithm that computes unbiased policy gradients for dLLMs via Monte Carlo sampling over timesteps, avoiding biased one-step approximations. AGRPO also uses low-discrepancy sampling and practical design choices (caching, k-tuning) to make online RL feasible for diffusion models. Empirically, AGRPO improves accuracy on reasoning benchmarks (GSM8K, MATH, Countdown) relative to both the base diffusion model and prior RL methods for dLLMs such as diffu-GRPO. It also expands the compute/quality frontier, retaining higher accuracy with fewer sampling steps at inference time.
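To make "low-discrepancy sampling over timesteps" concrete, here is a minimal sketch of one common low-discrepancy scheme: stratified sampling with a single shared offset. It is an illustration under assumptions, not the paper's implementation. The function names, the continuous timestep range [0, 1), and the toy integrand in the test are all hypothetical.

```python
import random


def stratified_timesteps(k, rng):
    """Draw k timesteps in [0, 1), one per equal-width stratum.

    Assumption (not from the paper): one offset u ~ U[0, 1) is shared by all
    strata, so t_i = (i + u) / k is uniform on its own stratum and the k
    points cover the interval evenly.
    """
    u = rng.random()
    return [(i + u) / k for i in range(k)]


def iid_timesteps(k, rng):
    """Baseline: k independent uniform timesteps in [0, 1)."""
    return [rng.random() for _ in range(k)]


def mc_mean(f, timesteps):
    """Monte Carlo estimate of the average of f over t in [0, 1)."""
    return sum(f(t) for t in timesteps) / len(timesteps)
```

Both samplers give an unbiased estimate of the average of `f` over timesteps. The stratified version typically has lower variance for smooth `f` because no region of the timestep range is left unsampled.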

Abstract

Diffusion large language models (dLLMs) are a new paradigm of non-autoregressive language models that are trained to predict multiple tokens in parallel and generate text via iterative unmasking. Recent works have successfully pretrained dLLMs to parity with autoregressive LLMs at the 8B scale, but dLLMs have yet to benefit from modern post-training techniques, e.g. reinforcement learning (RL), that have proven effective for autoregressive models. Crucially, algorithms designed for traditional LLMs aren't directly compatible with diffusion frameworks due to inherent differences in modeling assumptions. Moreover, existing attempts at dLLM post-training with RL rely on heuristic-based objectives with no theoretical grounding. In this work, we present Amortized Group Relative Policy Optimization (AGRPO), a principled on-policy RL algorithm designed specifically for dLLMs. AGRPO uses Monte Carlo sampling to compute an unbiased policy gradient estimate, making it the first tractable, faithful adaptation of policy gradient methods for dLLMs. We demonstrate AGRPO's effectiveness on different math/reasoning tasks, a common setting for RL with LLMs, achieving up to +7.6% absolute gain on GSM8K and 3.8x performance on the Countdown task over the baseline LLaDA-8B-Instruct model and 1.3x performance gains over comparable RL methods such as diffu-GRPO. Furthermore, these gains persist across different numbers of sampling steps at inference time, achieving better tradeoffs between compute and performance. Our results demonstrate that online RL algorithms can be extended to diffusion LLMs in principled ways, maintaining both theoretical soundness and practical effectiveness.
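The abstract's claim that Monte Carlo sampling yields an unbiased yet tractable estimate can be illustrated with a generic sub-sampling estimator. The sketch below is not AGRPO itself. It assumes a hypothetical list `terms` of per-step contributions, each of which would cost one forward pass to compute, and shows that evaluating only `k` randomly chosen terms and rescaling gives the exact total in expectation.

```python
import itertools
import random


def exact_total(terms):
    """Exact sum over all steps: one (hypothetical) forward pass per term."""
    return sum(terms)


def subsampled_total(terms, k, rng):
    """Estimate the total from k steps chosen uniformly without replacement.

    Rescaling by n / k makes the estimate unbiased, since each step is
    included with probability k / n.
    """
    n = len(terms)
    chosen = rng.sample(range(n), k)
    return (n / k) * sum(terms[i] for i in chosen)


def expected_subsampled_total(terms, k):
    """Average the estimator over every possible size-k subset (exhaustive)."""
    n = len(terms)
    subsets = list(itertools.combinations(range(n), k))
    estimates = [(n / k) * sum(terms[i] for i in s) for s in subsets]
    return sum(estimates) / len(estimates)
```

For a small list, `expected_subsampled_total(terms, k)` matches `exact_total(terms)` up to floating-point error for any `k`, which is the unbiasedness property. The cost per estimate drops from `len(terms)` evaluations to `k`.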

Paper Structure

This paper contains 28 sections, 6 equations, 3 figures, 2 tables, and 1 algorithm.

Figures (3)

  • Figure 1: A comparison of different RL post-training algorithms for dLLMs. Existing algorithms designed for traditional LLMs such as GRPO require token-level probabilities, which would entail $O(\text{response length})$ forward passes. Current tractable techniques for dLLMs involve heuristic approximations, resulting in biased policy gradients. Our proposed algorithm takes a different approach that remains tractable and faithful to the original GRPO objective.
  • Figure 2: Models under the autoregressive and diffusion paradigms are trained on different objectives. Next-token prediction (left) is a narrower, easily parallelizable task, whereas masked token prediction (right) is harder and less conducive to parallelism since it involves predicting multiple tokens. Diffusion models must also optimize for a lower bound rather than the exact likelihood.
  • Figure 3: The inference compute/quality frontier for GSM8K across different configurations. Lines connect comparable points, i.e. same model and response length, showing the possible tradeoffs at inference time. Models trained with AGRPO consistently outperform baselines and additionally retain quality with fewer sampling steps.
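Figure 1 measures the proposed method against the original GRPO objective. For orientation, the "group relative" part of GRPO is usually described as normalizing each sampled response's reward against the other responses for the same prompt. The sketch below follows that common formulation. The function name, the use of the population standard deviation, and the `eps` stabilizer are assumptions for illustration and are not taken from this paper.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-8):
    """Normalize rewards within one group of responses to the same prompt.

    Each advantage is (reward - group mean) / (group std + eps). The eps term
    (an assumption here) avoids division by zero when all rewards are equal.
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std over the group
    return [(r - mean) / (std + eps) for r in rewards]
```

With binary correctness rewards, responses that solve the problem receive positive advantages and failed ones receive negative advantages. A group in which every response gets the same reward contributes zero advantage everywhere.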

Theorems & Definitions (3)

  • Remark
  • Remark
  • Remark