Demystifying Diffusion Objectives: Reweighted Losses are Better Variational Bounds

Jiaxin Shi, Michalis K. Titsias

TL;DR

The paper reframes diffusion-model training by deriving a cascade of time-dependent variational lower bounds on the data log-likelihood, showing that these bounds can be tighter than the standard ELBO and reduce data-model KL divergences. It introduces optimal decoders to obtain improved ELBOs and proves that incorporating more optimal steps tightens the bound, while noting a tradeoff with sampling tractability. By showing that common reweighted objectives are equivalent to weighted sums of these improved bounds, the work provides a general theoretical justification for reweighted losses that covers both continuous Gaussian and masked diffusion models. The authors adapt these ideas to masked diffusion, deriving weighting schemes by matching the log-SNR weightings $w(\lambda)$ used in Gaussian diffusion, and report substantial gains in pixel-space ImageNet 64×64 generation (e.g., an FID of 1.92 with a 324M-parameter model). Overall, the work clarifies the theoretical basis for reweighted losses, demonstrates their applicability to discrete diffusion, and reports strong empirical gains in sample quality, suggesting practical impact for diffusion-based generative modeling.
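
To make the notion of an optimal decoder concrete, here is a toy sketch under our own assumption (not the paper's definition) that the optimal decoder for a latent $\mathbf{z}$ is the data-induced posterior $q(\mathbf{x} \mid \mathbf{z}) \propto q(\mathbf{x})\,q(\mathbf{z} \mid \mathbf{x})$. With the encoder $q(\mathbf{z} \mid \mathbf{x})$ and prior $p(\mathbf{z})$ held fixed, this decoder maximizes the expected ELBO, but sampling from it requires knowing $q(\mathbf{x})$, which illustrates the tractability tradeoff mentioned above. The distributions and the helper `expected_elbo` are hypothetical.

```python
import math

# Hypothetical one-step latent-variable model (illustration only).
q_x = [0.5, 0.3, 0.2]            # data distribution q(x) over 3 symbols
q_z_given_x = [                  # fixed "forward" channel q(z | x) over 2 latents
    [0.8, 0.2],                  # q(z | x=0)
    [0.3, 0.7],                  # q(z | x=1)
    [0.5, 0.5],                  # q(z | x=2)
]
p_z = [0.5, 0.5]                 # fixed prior p(z)

def expected_elbo(decoder):
    """E_{q(x)} E_{q(z|x)}[log p(x|z) + log p(z) - log q(z|x)], with decoder[z][x] = p(x|z)."""
    total = 0.0
    for x, qx in enumerate(q_x):
        for z, qzx in enumerate(q_z_given_x[x]):
            total += qx * qzx * (math.log(decoder[z][x]) + math.log(p_z[z]) - math.log(qzx))
    return total

# Assumed optimal decoder: q(x | z) proportional to q(x) q(z | x), normalized over x.
joint = [[q_x[x] * q_z_given_x[x][z] for x in range(3)] for z in range(2)]
optimal_decoder = [[v / sum(row) for v in row] for row in joint]

# A generic comparison decoder: uniform over x for every z.
uniform_decoder = [[1.0 / 3.0] * 3 for _ in range(2)]

elbo_opt = expected_elbo(optimal_decoder)
elbo_unif = expected_elbo(uniform_decoder)
print(f"optimal decoder ELBO: {elbo_opt:.4f}, uniform decoder ELBO: {elbo_unif:.4f}")
```

The gap between the two values equals $\mathbb{E}_{q(\mathbf{z})}\,\mathrm{KL}(q(\mathbf{x} \mid \mathbf{z})\,\|\,p(\mathbf{x} \mid \mathbf{z}))$ for the comparison decoder, so it is nonnegative for any decoder one might choose.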

Abstract

We derive a new theoretical interpretation of the reweighted losses that are widely used for training diffusion models. Our method is based on constructing a cascade of time-dependent variational lower bounds on the data log-likelihood, which provably improve upon the standard evidence lower bound and result in reduced data-model KL divergences. Combining such bounds gives rise to reweighted objectives that can be applied to any generative diffusion model, including both continuous Gaussian diffusion and masked (discrete) diffusion models. We then showcase this framework in masked diffusion and report significant improvements over previous training losses in pixel-space image modeling, with sample quality approaching that of continuous diffusion models. Our results also provide a theoretical justification for the simple weighting scheme widely used in masked image models.

Paper Structure

This paper contains 11 sections, 2 theorems, 24 equations, 6 figures, 4 tables.

Key Result

Theorem 1

For $\mathbf{x} \sim q(\mathbf{x})$, $\mathcal{L}^{(i + 1)}(\mathbf{x})$ is on average a better lower bound¹ … Since $\mathrm{KL}(q(\mathbf{x})\,\|\,p_\theta(\mathbf{x})) = -\mathbb{E}_{q(\mathbf{x})}[\log p_\theta(\mathbf{x})] + \mathbb{E}_{q(\mathbf{x})}[\log q(\mathbf{x})]$, …

¹ Note that we are comparing the lower bounds for slightly different model distributions (the generative model used in $\mathcal{L}^{(i)}$ has one more reverse transition parameterized by the den…
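
The KL identity in Theorem 1 is what links a tighter bound to a smaller data-model KL: the term $\mathbb{E}_{q(\mathbf{x})}[\log q(\mathbf{x})]$ does not depend on the model, so any bound with $\mathbb{E}_{q}[\mathcal{L}(\mathbf{x})] \le \mathbb{E}_{q}[\log p_\theta(\mathbf{x})]$ gives $\mathrm{KL}(q\,\|\,p_\theta) \le -\mathbb{E}_{q}[\mathcal{L}(\mathbf{x})] + \mathbb{E}_{q}[\log q(\mathbf{x})]$, and raising the bound lowers this upper bound. The sketch below only checks the identity numerically on two hypothetical discrete models; it does not implement the paper's bounds.

```python
import math

# Hypothetical toy distributions over four outcomes (illustration only).
q = [0.4, 0.3, 0.2, 0.1]            # data distribution q(x)
p_a = [0.25, 0.25, 0.25, 0.25]      # model A: uniform p_theta(x)
p_b = [0.35, 0.3, 0.2, 0.15]        # model B: closer to q

def kl(a, b):
    """KL(a || b) for discrete distributions; assumes b > 0 wherever a > 0."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

def expected_log_lik(model):
    """E_{q(x)}[log p_theta(x)]."""
    return sum(qi * math.log(mi) for qi, mi in zip(q, model))

# E_{q(x)}[log q(x)] is the same for every model.
neg_entropy = sum(qi * math.log(qi) for qi in q)

for name, model in [("A", p_a), ("B", p_b)]:
    # KL(q || p_theta) = -E_q[log p_theta(x)] + E_q[log q(x)]
    via_identity = -expected_log_lik(model) + neg_entropy
    print(name, round(expected_log_lik(model), 4), round(kl(q, model), 4), round(via_identity, 4))
```

Model B has the higher expected log-likelihood and, correspondingly, the smaller KL divergence to $q$.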

Figures (6)

  • Figure 1: Diffusion objectives viewed as a weighted sum of the ELBOs of a sequence of models with optimal decoders (defined in the section on optimal decoders). For continuous Gaussian diffusion models: $L_{\text{denoise}}(\mathbf{z}_t, \mathbf{x}, t) = \frac{1}{2} \lambda'(t) \|{\bm{\epsilon} - \bm{\epsilon}_\theta(\mathbf{z}_t, t)}\|_2^2$. For masked diffusion models: $L_{\text{denoise}}(\mathbf{z}_t, \mathbf{x}, t) = -\frac{\alpha_t'}{1 - \alpha_t} \delta_{\mathbf{z}_t, m}\cdot \mathbf{x}^\top \log \mu_\theta(\mathbf{z}_t)$.
  • Figure 2: Left: Weighting functions used in Gaussian diffusion models; their formulas are listed in the table of existing weighting schemes. Right: Weighting functions for masked diffusion models; all except the simple weighting are obtained by matching the corresponding Gaussian-diffusion weightings $w(\lambda)$ under the cosine schedule $\alpha_t$. All functions are plotted over $t \in [0, 0.999]$ and normalized by their maximum value on this interval (the flow matching and simple weightings approach infinity as $t \to 1$).
  • Figure 3: Total cross-entropy loss weight under cosine schedule $\alpha_t = 1 - \cos(\frac{\pi}{2}(1 - t))$.
  • Figure 4: Class-conditional samples generated in 256 steps by the masked diffusion model (324M) trained with the simple weighting on ImageNet 64$\times$64 (FID: 1.92). Each row shows samples conditioned on a unique class. Samples within each class are highly diverse, suggesting good coverage of the data distribution.
  • Figure 5: Class-conditional generation from masked diffusion models with 204M parameters trained with monotonic weighting functions (ELBO, Sigmoid, FM, Simple) on ImageNet 64$\times$64. Each image is conditioned on a unique class.
  • ...and 1 more figure
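
As a hedged illustration of the masked-diffusion expression in Figure 1 under the cosine schedule of Figure 3, the sketch below evaluates the coefficient $-\alpha_t'/(1 - \alpha_t)$ analytically (it simplifies to $\frac{\pi}{2}\tan\big(\frac{\pi}{2}(1 - t)\big)$) and applies it to the masked positions of a toy sequence. The function names, the token layout, and the `mask_id` convention are hypothetical; `log_probs` stands in for $\log \mu_\theta(\mathbf{z}_t)$ from some denoising network.

```python
import math

def alpha(t):
    """Cosine schedule from Figure 3: alpha_t = 1 - cos(pi/2 * (1 - t)), decreasing from 1 to 0."""
    return 1.0 - math.cos(0.5 * math.pi * (1.0 - t))

def alpha_prime(t):
    """Analytic derivative: alpha_t' = -(pi/2) * sin(pi/2 * (1 - t))."""
    return -0.5 * math.pi * math.sin(0.5 * math.pi * (1.0 - t))

def elbo_weight(t):
    """Coefficient -alpha_t' / (1 - alpha_t) from Figure 1; intended for 0 < t <= 1."""
    return -alpha_prime(t) / (1.0 - alpha(t))

def denoise_term(z_t, x, log_probs, t, mask_id):
    """Figure 1 masked-diffusion expression, summed over positions (hypothetical helper).

    z_t:       noisy token ids; positions equal to mask_id are masked
    x:         clean token ids (x^T picks one entry of a one-hot vector)
    log_probs: per-position log-probabilities standing in for log mu_theta(z_t)
    """
    total = 0.0
    for z_i, x_i, lp_i in zip(z_t, x, log_probs):
        if z_i == mask_id:        # delta_{z_t, m}: only masked positions contribute
            total += lp_i[x_i]    # x^T log mu_theta(z_t) for a one-hot x
    return elbo_weight(t) * total

# Toy example: vocabulary {0, 1, 2} plus a mask token with id 3 (all values hypothetical).
MASK = 3
x = [0, 1, 2]
z_t = [MASK, 1, MASK]
log_probs = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],
    [math.log(0.1), math.log(0.8), math.log(0.1)],
    [math.log(0.1), math.log(0.3), math.log(0.6)],
]
term = denoise_term(z_t, x, log_probs, t=0.5, mask_id=MASK)
print(elbo_weight(0.5), term, -term)   # -term is the weighted cross-entropy
```

With the decreasing cosine schedule the coefficient is positive and $\mathbf{x}^\top \log \mu_\theta \le 0$, so the expression as written is non-positive; we read it as a contribution to the log-likelihood bound, whose negation is the weighted cross-entropy one would minimize. At $t = 0.5$ the coefficient equals $\pi/2$.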

Theorems & Definitions (2)

  • Theorem 1: Improved lower bounds
  • Theorem 2: Reweighted objectives as improved variational bounds
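
Theorem 2's title says reweighted objectives are improved variational bounds, and Figure 1 describes them as weighted sums of ELBOs. One elementary identity with that shape, given here as our own illustration rather than the paper's proof, holds for any weighting $w$ that is differentiable and non-decreasing on $[0, 1]$ and any per-time loss $\ell(t)$ (a hypothetical symbol):

$$\int_0^1 w(t)\,\ell(t)\,dt = w(0)\int_0^1 \ell(t)\,dt + \int_0^1 w'(s)\left(\int_s^1 \ell(t)\,dt\right)ds,$$

which follows from writing $w(t) = w(0) + \int_0^t w'(s)\,ds$ and exchanging the order of integration. The reweighted loss is thus $w(0)$ times the full integral plus a mixture, with nonnegative weights $w'(s)$, of truncated integrals $\int_s^1 \ell(t)\,dt$; reading each truncated integral as a bound whose steps below $s$ use an optimal decoder is an assumption on our part. The sketch checks the identity numerically for the hypothetical choices $w(t) = t^2 + \tfrac{1}{2}$ and $\ell(t) = e^{-t}$.

```python
import math

def reweighted_loss(w, ell, n=4000):
    """Midpoint-rule approximation of int_0^1 w(t) * ell(t) dt."""
    h = 1.0 / n
    return h * sum(w((k + 0.5) * h) * ell((k + 0.5) * h) for k in range(n))

def as_mixture_of_tails(w, w_prime, ell, n=4000):
    """Same quantity via w(0) * int_0^1 ell + int_0^1 w'(s) * (int_s^1 ell(t) dt) ds."""
    h = 1.0 / n
    cells = [h * ell((k + 0.5) * h) for k in range(n)]   # midpoint rule on each cell
    tail = [0.0] * (n + 1)                                 # tail[k] ~ int_{k h}^1 ell(t) dt
    for k in range(n - 1, -1, -1):
        tail[k] = tail[k + 1] + cells[k]
    # Outer midpoint rule in s; the tail at s = (k + 0.5) h is averaged from the cell endpoints.
    outer = h * sum(w_prime((k + 0.5) * h) * 0.5 * (tail[k] + tail[k + 1]) for k in range(n))
    return w(0.0) * tail[0] + outer

# Hypothetical non-decreasing weighting and per-time loss, chosen only for the check.
def w(t):
    return t * t + 0.5

def w_prime(t):
    return 2.0 * t

def ell(t):
    return math.exp(-t)

lhs = reweighted_loss(w, ell)
rhs = as_mixture_of_tails(w, w_prime, ell)
print(lhs, rhs, 2.5 - 5.5 / math.e)   # closed form of int_0^1 (t^2 + 1/2) e^{-t} dt
```

Because $w'(s) \ge 0$ for a non-decreasing weighting, every truncated integral enters with a nonnegative coefficient, which is the structural property a "weighted sum of bounds" reading relies on.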