
Bringing Stability to Diffusion: Decomposing and Reducing Variance of Training Masked Diffusion Models

Mengni Jia, Mengyu Zhou, Yihao Liu, Xiaoxi Jiang, Guanjun Jiang

TL;DR

This work addresses the pronounced training variance of masked diffusion models (MDMs) by deriving the first systematic variance decomposition for MDM training. The decomposition isolates three sources: masking pattern noise (A), masking rate noise (B), and data noise (C); autoregressive models (ARMs) are affected only by C. Building on this, the authors propose six variance-reduction techniques centered on two core methods. P-POTS samples the masking rate t from a Pareto-optimal nonuniform distribution and pairs it with an unbiased estimator that minimizes A+B+C. MIRROR introduces negatively correlated masked samples to reduce A. ISAD, SyRM, StraTS, and EMA serve as complementary strategies. Empirical results across text and image tasks show 7–8% accuracy gains on complex reasoning benchmarks and substantial reductions in run-to-run variability, narrowing the training gap with strong ARM baselines. Together, the theoretical variance control and the empirical improvements indicate that stabilized, data-efficient MDM training is achievable, enabling MDMs to approach or surpass ARM performance in several settings. Overall, the paper provides a principled framework and a practical toolkit for variance management in masked diffusion models.
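The idea of sampling the masking rate t nonuniformly while keeping the loss estimator unbiased can be illustrated with ordinary importance sampling. The sketch below is not the paper's P-POTS implementation: the bin scores, the helper names `make_bin_probs` and `sample_t_with_weight`, and the toy loss `t * t` are all assumptions made for illustration. It draws t from a piecewise-constant proposal over equal-width bins and reweights each draw by the inverse proposal density, so the weighted average still estimates the uniform-t expectation.

```python
import random

def make_bin_probs(scores):
    """Normalize hypothetical per-bin difficulty scores into sampling probabilities."""
    total = sum(scores)
    return [s / total for s in scores]

def sample_t_with_weight(bin_probs, rng):
    """Draw t from a piecewise-constant proposal over equal-width bins on [0, 1).

    Returns (t, weight). The weight is 1 / (proposal density at t), so that
    weight * loss(t) is an unbiased estimate of the uniform-t expected loss.
    """
    b = len(bin_probs)
    j = rng.choices(range(b), weights=bin_probs, k=1)[0]  # pick a bin
    t = (j + rng.random()) / b                            # uniform inside the bin
    weight = 1.0 / (b * bin_probs[j])                     # density in bin j is b * q_j
    return t, weight

if __name__ == "__main__":
    rng = random.Random(0)
    # Assumed scores: later bins (higher masking rates) are treated as harder.
    probs = make_bin_probs([0.2, 0.6, 1.0, 1.4, 1.8])
    draws = [sample_t_with_weight(probs, rng) for _ in range(100_000)]
    # Toy loss t^2, whose uniform-t expectation is 1/3.
    estimate = sum(w * t * t for t, w in draws) / len(draws)
    print("importance-weighted estimate:", estimate)
```

Bins sampled more often receive proportionally smaller weights, which matches the abstract's description of sampling harder t values more often with smaller update steps.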

Abstract

Masked diffusion models (MDMs) are a promising alternative to autoregressive models (ARMs), but they suffer from inherently much higher training variance. High variance leads to noisier gradient estimates and unstable optimization, so even equally strong pretrained MDMs and ARMs that are competitive at initialization often diverge after task-specific training, with MDMs falling far behind. There has been no theoretical explanation or systematic solution. We derive the first decomposition of MDM training variance into three sources: (A) masking pattern noise, (B) masking rate noise, and (C) data noise, while ARMs are only affected by (C). This explains the fundamental training gap. Building on this foundation, we design six variance-reduction methods, including two core methods: (1) P-POTS, a Pareto-optimal t sampler that minimizes training variance by sampling harder t values more often with appropriately smaller update steps, and (2) MIRROR, which uses negatively correlated samples to reduce (A). Experiments show that compared to standard MDM training, our methods improve accuracy by 7-8% on complex reasoning tasks, while simultaneously reducing run-to-run variability to near ARM levels, substantially narrowing the gap with strong ARM baselines; in most settings, even the best baseline runs remain below the worst run of our method.
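To make the notion of negatively correlated masked samples concrete, the following sketch uses classical antithetic variates. It is an assumption-laden illustration, not the paper's MIRROR construction: the pairing rule (reusing the same uniform draws `u` and `1 - u`), the function names, and the toy loss are all hypothetical. Each mask in a mirrored pair keeps the marginal masking rate t, but a token masked in one copy is less likely to be masked in the other, so averaging the two losses reduces masking pattern noise.

```python
import random
import statistics

def mirrored_masks(num_tokens, t, rng):
    """Return two masks with rate t that are negatively correlated per token."""
    u = [rng.random() for _ in range(num_tokens)]
    mask = [x < t for x in u]
    mirror = [x > 1.0 - t for x in u]  # same marginal rate t, built from 1 - u
    return mask, mirror

def independent_masks(num_tokens, t, rng):
    """Baseline: two masks with rate t drawn independently."""
    first = [rng.random() < t for _ in range(num_tokens)]
    second = [rng.random() < t for _ in range(num_tokens)]
    return first, second

def toy_masked_loss(mask, difficulty, t):
    """Toy stand-in for a 1/t-weighted masked-token loss (assumed, not a model loss)."""
    return sum(d for d, m in zip(difficulty, mask) if m) / (t * len(mask))

def pair_variance(make_pair, difficulty, t, trials, rng):
    """Variance of the two-mask averaged loss over many random mask pairs."""
    values = []
    for _ in range(trials):
        a, b = make_pair(len(difficulty), t, rng)
        values.append(0.5 * (toy_masked_loss(a, difficulty, t)
                             + toy_masked_loss(b, difficulty, t)))
    return statistics.pvariance(values)

if __name__ == "__main__":
    rng = random.Random(0)
    difficulty = [1.0, 2.0, 0.5, 3.0, 1.5, 0.2, 2.5, 1.0]  # hypothetical per-token losses
    print("independent pair variance:",
          pair_variance(independent_masks, difficulty, 0.3, 20_000, rng))
    print("mirrored pair variance:   ",
          pair_variance(mirrored_masks, difficulty, 0.3, 20_000, rng))
```

With this pairing rule the two masks are disjoint whenever t is at most 0.5, which is the source of the negative correlation in the toy setting.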

Paper Structure

This paper contains 83 sections, 112 equations, 8 figures, 6 tables, 1 algorithm.

Figures (8)

  • Figure 1: Graphic illustration of three sources of training variance in MDMs. The left panel illustrates standard MDM training, compared to our core methods (P-POTS and MIRROR) on the right.
  • Figure 2: Empirical $\{p_j\}_{j=1}^b$ (scatter) and the fitted curve $p^*(t)=\sqrt{g(t)^2+v(t)}$ (line) on three datasets: (a) OpenScience, (b) GSM8K, and (c) HiTab. The equations in each subplot show the fitted forms of $g(t)$ and $v(t)$, which together characterize the sampling distribution across masking rates $t$.
  • Figure 3: Images generated by MMaDA-8B-MixCoT trained with P-POTS+MIRROR (top row) and the standard method (bottom row) under the same seed. The columns correspond to the prompts: 1) "A woman with blonde hair is wearing a colorful crochet headband and a black top with a floral patterned shawl. She is holding a toy sword with a green handle and a yellow ball on top"; 2) "A rustic wooden farmhouse with a weathered roof and a small porch stands in a field of dry grass. The sky is filled with fluffy white clouds, and a single bird is flying in the distance. The landscape is devoid of any other structures or signs of life, giving the scene a serene and isolated feel"; 3) "A car is on fire under a bridge with smoke billowing out. The fire is intense, with flames visible on the car's underside. The bridge has a metal railing and is located near a forested area. There is a street lamp above the bridge, and the scene appears to be at night or during dusk"; 4) "A close-up of a grey fringed purse with a tassel detail on the front. The purse is placed on a surface with a soft, blurred background". More case studies can be found in the paper's case-study appendix.
  • Figure 4: Training loss comparison between P-POTS+MIRROR (green) and Standard (red) on OpenScience (left) and text-to-image-2M (right). The annotations indicate the average loss at the beginning and end of training over the first and last $5$ steps. Overall, P-POTS+MIRROR achieves more stable loss trajectories with consistently lower end losses compared to the Standard baseline.
  • Figure 5: The left panel shows the heteroscedasticity of the $t$-loss, and the right panel illustrates the shrinking difference between IID and stratified sampling as batch size $B$ increases, where we use the KS statistic to measure their maximum deviation.
  • ...and 3 more figures
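The Figure 5 caption compares IID and stratified sampling of $t$ using the KS statistic. A minimal sketch of that comparison follows, under stated assumptions: stratified sampling is taken to mean one uniform draw per equal-width stratum of $[0, 1)$, and the KS statistic is computed for each batch against the uniform target distribution, which may differ from the exact quantity plotted in the paper. The function names are hypothetical.

```python
import random

def iid_t(batch_size, rng):
    """Draw masking rates independently and uniformly on [0, 1)."""
    return [rng.random() for _ in range(batch_size)]

def stratified_t(batch_size, rng):
    """Draw one masking rate from each of batch_size equal-width strata."""
    return [(i + rng.random()) / batch_size for i in range(batch_size)]

def ks_uniform(samples):
    """One-sample KS statistic against the Uniform(0, 1) CDF."""
    xs = sorted(samples)
    n = len(xs)
    # Largest gap between the empirical CDF and F(x) = x, checked on
    # both sides of each jump of the empirical CDF.
    return max(max((i + 1) / n - x, x - i / n) for i, x in enumerate(xs))

if __name__ == "__main__":
    rng = random.Random(0)
    for batch_size in (8, 64, 512):
        print(batch_size,
              "iid KS:", ks_uniform(iid_t(batch_size, rng)),
              "stratified KS:", ks_uniform(stratified_t(batch_size, rng)))
```

In this setup the stratified batch can deviate from the uniform CDF by at most $1/B$, since exactly one sample falls in each stratum, whereas the deviation of an IID batch shrinks only at roughly a $1/\sqrt{B}$ rate.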