
Efficient and Stable Reinforcement Learning for Diffusion Language Models

Jiawei Liu, Xiting Wang, Yuanyuan Zhong, Defu Lian, Yu Yang

TL;DR

This work tackles the dual challenges of efficiency and stability in reinforcement learning for diffusion-based language models by introducing Spatio-Temporal Pruning (STP). STP combines spatial pruning, which fixes a fraction of tokens using static priors to constrain exploration, with temporal pruning, which omits late-stage denoising steps to reduce computation. The authors provide theoretical guarantees showing that STP reduces ELBO variance and stabilizes GRPO-based training. They validate these claims with extensive experiments on math and logic benchmarks, achieving relative improvements of up to 81.7% on logic tasks along with notable training speedups. STP is orthogonal to other RL advances and is shown to be compatible with alternative RL algorithms, acting as a versatile plug-in that accelerates and stabilizes diffusion-based RL for first-pass reasoning tasks.
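To make the two pruning axes concrete, here is a minimal toy sketch of a masked-diffusion sampler with both kinds of pruning. It is not the authors' implementation, and every specific in it is an assumption made for illustration. The stub `toy_denoiser`, the `MASK` id, and the confidence-based unmasking schedule are all hypothetical. The "static prior" for spatial pruning is likewise a hypothetical choice: it fixes the trailing `gamma` fraction of positions to an assumed `eos_id`. Temporal pruning is modeled as stopping once the noise level reaches `t_cutoff` and committing all remaining masked positions in one shot. Here `num_steps` plays the role of N.

```python
import math
import random
from typing import Callable, List, Tuple

MASK = -1  # hypothetical mask-token id (illustrative only)

# A denoiser maps a partially masked sequence to (predicted token, confidence) per position.
Denoiser = Callable[[List[int]], Tuple[List[int], List[float]]]


def toy_denoiser(seq: List[int]) -> Tuple[List[int], List[float]]:
    """Hypothetical stand-in for a dLLM forward pass (random but deterministic per state)."""
    rng = random.Random(sum(1 for tok in seq if tok == MASK))
    preds = [rng.randrange(1, 100) for _ in seq]
    conf = [rng.random() for _ in seq]
    return preds, conf


def stp_sample(length: int, num_steps: int, gamma: float, t_cutoff: float,
               denoiser: Denoiser, eos_id: int = 0) -> Tuple[List[int], int]:
    """Toy masked-diffusion sampler with spatial and temporal pruning.

    Returns the final sequence and the number of denoiser calls.
    """
    if not 0.0 <= gamma <= 1.0:
        raise ValueError("gamma must lie in [0, 1]")
    seq = [MASK] * length

    # Spatial pruning (assumed static prior): fix the trailing gamma fraction
    # of positions to eos_id before sampling, so they are never explored.
    n_fixed = int(gamma * length)
    for i in range(length - n_fixed, length):
        seq[i] = eos_id
    free = length - n_fixed

    calls = 0
    for n in range(num_steps):
        masked = [i for i, tok in enumerate(seq) if tok == MASK]
        if not masked:
            break
        preds, conf = denoiser(seq)
        calls += 1
        # Noise level after this step; t runs from 1 (all masked) toward 0 (clean).
        t_next = 1.0 - (n + 1) / num_steps
        if t_next <= t_cutoff:
            # Temporal pruning: skip the remaining late-stage refinement steps and
            # commit every still-masked position from this forward pass.
            # (With t_cutoff = 0 this is simply the ordinary final step.)
            for i in masked:
                seq[i] = preds[i]
            break
        # Otherwise keep about t_next * free positions masked and unmask the
        # most confident predictions among the rest.
        keep_masked = math.ceil(t_next * free)
        k = max(len(masked) - keep_masked, 0)
        for i in sorted(masked, key=lambda j: conf[j], reverse=True)[:k]:
            seq[i] = preds[i]
    return seq, calls


if __name__ == "__main__":
    _, std_calls = stp_sample(256, 64, gamma=0.0, t_cutoff=0.0, denoiser=toy_denoiser)
    _, stp_calls = stp_sample(256, 64, gamma=0.25, t_cutoff=0.25, denoiser=toy_denoiser)
    print(std_calls, stp_calls)  # 64 48
```

In this toy run, temporal pruning with `t_cutoff = 0.25` reduces denoiser calls from N = 64 to 48, which is about N(1 − t_cutoff). Spatial pruning with `gamma = 0.25` shrinks the set of sampled (explored) positions from 256 to 192 but leaves the call count unchanged. Whether fixed positions also lower the per-step cost depends on the model and is not modeled here.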

Abstract

Reinforcement Learning (RL) is crucial for unlocking the complex reasoning capabilities of Diffusion-based Large Language Models (dLLMs). However, applying RL to dLLMs faces unique challenges in efficiency and stability. To address these challenges, we propose Spatio-Temporal Pruning (STP), a framework designed to simultaneously improve the efficiency and stability of RL for dLLMs. STP compresses the redundancy in the generative process through: (1) spatial pruning, which constrains the exploration space using static priors; and (2) temporal pruning, which bypasses redundant late-stage refinement steps. Our theoretical analysis demonstrates that STP strictly reduces the variance of the log-likelihood estimation, thereby ensuring more stable policy updates. Extensive experiments demonstrate that STP surpasses state-of-the-art baselines in both efficiency and accuracy. Our code is available at https://github.com/Lolo1222/STP.

Paper Structure

This paper contains 36 sections, 5 theorems, 28 equations, 3 figures, and 8 tables.

Key Result

Theorem 4.1

(Computational Complexity Reduction) Let $N$ be the total number of diffusion steps, $\gamma$ be the spatial pruning ratio, and $t_{\text{cutoff}} \in (0, 1)$ be the temporal pruning cutoff. The standard sampling cost is $\mathcal{C}_{\text{std}} = N \cdot \mathcal{C}_{\text{step}}$. Under STP, the

Figures (3)

  • Figure 1: Comparison of (a) standard sequence generation via iterative denoising and (b) our generation process with spatio-temporal pruning, which theoretically reduces computational cost and ELBO variance while improving empirical accuracy.
  • Figure 2: Replacing standard sampling with STP accelerates trajectory generation for exploration. Furthermore, by constraining the sampling space, STP yields lower-variance ELBO estimates, facilitating stable policy updates.
  • Figure 3: Empirical Verification of Variance Reduction. (a) The ELBO estimation variance recorded over the course of training. STP (orange) consistently exhibits lower variance than the standard method, GRPO w/ ELBO (blue). (b) The aggregate distribution of variance shows that STP significantly lowers the median variance and suppresses extreme outliers, validating our theoretical bounds in the variance-reduction theorem.

Theorems & Definitions (9)

  • Theorem 4.1
  • Theorem 4.3
  • Theorem 4.4
  • proof
  • Lemma 1.1
  • proof
  • Lemma 1.2
  • proof
  • proof