
Rethinking Losses for Diffusion Bridge Samplers

Sebastian Sanokowski, Lukas Gruber, Christoph Bartmann, Sepp Hochreiter, Sebastian Lehner

TL;DR

This work reevaluates loss functions for diffusion-bridge samplers targeting unnormalized distributions and shows that the log-variance (LV) loss often leads to unstable training, especially when the forward diffusion is learned and the LV loss can no longer be motivated by the data processing inequality (DPI). The authors advocate training with the reverse Kullback-Leibler (rKL) divergence using the log-derivative trick (rKL-LD), combined with learnable diffusion coefficients, and demonstrate better and more stable performance across diverse diffusion-bridge architectures and benchmarks. The findings offer practical guidance for training diffusion bridges and point to future work on reducing mode collapse in KL-based objectives. Overall, rKL-LD with learnable diffusion coefficients provides a robust and effective framework for diffusion-bridge sampling in challenging, high-dimensional tasks.

Abstract

Diffusion bridges are a promising class of deep-learning methods for sampling from unnormalized distributions. Recent works show that the Log Variance (LV) loss consistently outperforms the reverse Kullback-Leibler (rKL) loss when the reparametrization trick is used to compute rKL gradients. While the on-policy LV loss yields gradients identical to those of the rKL loss combined with the log-derivative trick for diffusion samplers with non-learnable forward processes, this equivalence does not hold for diffusion bridges or when diffusion coefficients are learned. Based on this insight, we argue that for diffusion bridges the LV loss does not represent an optimization objective that can be motivated, like the rKL loss, via the data processing inequality. Our analysis shows that employing the rKL loss with the log-derivative trick (rKL-LD) not only avoids these conceptual problems but also consistently outperforms the LV loss. Experimental results with different types of diffusion bridges on challenging benchmarks show that samplers trained with the rKL-LD loss achieve better performance. From a practical perspective, we find that rKL-LD requires significantly less hyperparameter optimization and yields more stable training behavior.
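The gradient (non-)equivalence described in the abstract can be checked on a toy problem. The sketch below is a heavily simplified illustration, not the paper's implementation: two 1D Gaussians stand in for the path measures, `q_theta` for the learned sampling process and a hypothetical `p_phi` for the forward (target-side) process of a bridge. All names (`log_q`, `log_p`, `gradient_estimates`) are assumptions made for this example. Samples are treated as constants (no reparametrization), gradients are written out analytically instead of via autodiff, and the rKL-LD estimator uses the batch mean of the log-weights as its baseline, which rescales it by (n-1)/n in expectation.

```python
import math
import random


def log_q(x, theta):
    # Sampler density q_theta(x) = N(x; theta, 1), standing in for the sampling-path measure.
    return -0.5 * (x - theta) ** 2 - 0.5 * math.log(2.0 * math.pi)


def log_p(x, phi):
    # Hypothetical learnable target-side density, unnormalized N(x; phi, 0.5),
    # standing in for the (learned) forward-process path measure of a bridge.
    return -((x - phi) ** 2)


def gradient_estimates(xs, theta, phi):
    """Gradient estimates of the on-policy LV loss and the rKL-LD loss.

    xs are samples from q_theta treated as constants (no reparametrization).
    Returns ((lv_theta, lv_phi), (rkl_theta, rkl_phi)).
    """
    n = len(xs)
    f = [log_q(x, theta) - log_p(x, phi) for x in xs]  # log-weights log q - log p
    f_bar = sum(f) / n  # batch mean: LV centering and rKL-LD baseline
    score_q = [x - theta for x in xs]        # d/dtheta log q_theta(x)
    score_p = [2.0 * (x - phi) for x in xs]  # d/dphi log p_phi(x)

    # LV loss (1/n) sum (f_i - f_bar)^2 has gradient (2/n) sum (f_i - f_bar) df_i.
    lv_theta = 2.0 / n * sum((fi - f_bar) * s for fi, s in zip(f, score_q))
    lv_phi = -2.0 / n * sum((fi - f_bar) * s for fi, s in zip(f, score_p))

    # rKL-LD: grad E_q[f] = E_q[(f - b) d log q] + E_q[d f]; the zero-mean
    # term E_q[d log q] is dropped, leaving -E_q[d log p] for phi.
    rkl_theta = 1.0 / n * sum((fi - f_bar) * s for fi, s in zip(f, score_q))
    rkl_phi = -1.0 / n * sum(score_p)
    return (lv_theta, lv_phi), (rkl_theta, rkl_phi)


rng = random.Random(0)
theta, phi = 0.0, 1.0
xs = [rng.gauss(theta, 1.0) for _ in range(20000)]
(lv_t, lv_p), (rkl_t, rkl_p) = gradient_estimates(xs, theta, phi)
print(lv_t, 2 * rkl_t)  # identical: p does not depend on theta
print(lv_p, 2 * rkl_p)  # differ (roughly 8 vs 4 here): p depends on phi
```

For the parameter that only enters `q`, the on-policy LV gradient is exactly twice the rKL-LD gradient, matching the equivalence for non-learnable forward processes. Once `p` carries its own parameter, LV updates it with the covariance term -2 Cov(f, d log p), whereas rKL-LD uses -E_q[d log p], so the two objectives no longer produce the same updates.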

Paper Structure

This paper contains 54 sections, 64 equations, 1 figure, 5 tables, and 3 algorithms.

Figures (1)

  • Figure 1: Models with fixed $\sigma_{\mathrm{diff}}$ are marked with $\star$. Left: training curves of CMCD trained with the LV loss and with the rKL-LD loss on the Brownian task. Middle: the learned $\sigma_{\mathrm{diff}}$ values of the CMCD-rKL-LD run from the left panel, sorted in ascending order. Right: training curves on the Seeds task comparing CMCD-rKL-LD with learned $\sigma_{\mathrm{diff}}$ (initialized at $\sigma_{\mathrm{diff, init}}$) to CMCD-rKL-LD $\star$ with $\sigma_{\mathrm{diff}}$ fixed at $\sigma_{\mathrm{diff, init}}$, for different values of $\sigma_{\mathrm{diff, init}}$.
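The caption contrasts runs with learned $\sigma_{\mathrm{diff}}$ against runs where it is fixed. The sketch below illustrates why learning the diffusion coefficient matters for the loss analysis. It assumes a generic 1D Euler-Maruyama discretization with hypothetical linear drifts (not CMCD or any sampler from the paper), a per-step coefficient $\sigma_k = \mathrm{softplus}(\rho_k)$, and, as an assumption, that the sampling kernels and the backward kernels share the same $\sigma_k$. All function names and the toy target are hypothetical. Because $\rho$ then enters `log_p_path` as well as `log_q_path`, the backward path measure depends on learned parameters.

```python
import math
import random


def softplus(r):
    # Stable log(1 + exp(r)); keeps each diffusion coefficient positive.
    return max(r, 0.0) + math.log1p(math.exp(-abs(r)))


def log_normal(x, mean, var):
    return -0.5 * ((x - mean) ** 2 / var + math.log(2.0 * math.pi * var))


def log_target(x):
    # Hypothetical unnormalized target: N(2, 0.5^2) without its normalizer.
    return -0.5 * ((x - 2.0) / 0.5) ** 2


def simulate(rho, a, rng):
    """Euler-Maruyama path x_0..x_K of the sampling process, with x_0 ~ N(0, 1)."""
    dt = 1.0 / len(rho)
    path = [rng.gauss(0.0, 1.0)]
    for r in rho:
        x = path[-1]
        sigma = softplus(r)  # learnable diffusion coefficient for this step
        path.append(x + dt * a * (2.0 - x) + sigma * math.sqrt(dt) * rng.gauss(0.0, 1.0))
    return path


def log_q_path(path, rho, a):
    # Sampling process: N(0, 1) prior at x_0 and Gaussian kernels forward in time.
    dt = 1.0 / len(rho)
    lq = log_normal(path[0], 0.0, 1.0)
    for k, r in enumerate(rho):
        x, y = path[k], path[k + 1]
        lq += log_normal(y, x + dt * a * (2.0 - x), softplus(r) ** 2 * dt)
    return lq


def log_p_path(path, rho, b):
    # Backward process: unnormalized target at x_K and Gaussian kernels
    # backward in time that reuse sigma_k, so log p also depends on rho.
    dt = 1.0 / len(rho)
    lp = log_target(path[-1])
    for k, r in enumerate(rho):
        x, y = path[k], path[k + 1]
        lp += log_normal(x, y - dt * b * y, softplus(r) ** 2 * dt)
    return lp


rng = random.Random(0)
rho = [0.0] * 16  # sigma_k = softplus(0) = log 2 at every step
paths = [simulate(rho, 1.0, rng) for _ in range(2000)]
log_w = [log_q_path(p, rho, 1.0) - log_p_path(p, rho, 1.0) for p in paths]
elbo = -sum(log_w) / len(log_w)  # Monte Carlo lower bound on log Z, since rKL >= 0
print(elbo)
```

Averaging `log_w` over sampled paths gives the rKL objective up to the constant log Z. In this setup, changing `rho` changes both path log-densities, which is the situation where, per the abstract, on-policy LV and rKL-LD gradients stop coinciding.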