Direct Distributional Optimization for Provable Alignment of Diffusion Models

Ryotaro Kawata, Kazusato Oko, Atsushi Nitanda, Taiji Suzuki

TL;DR

This work tackles the challenge of aligning diffusion models when output densities are inaccessible and multimodal, by formulating alignment as a nonlinear distributional optimization problem and solving it with a framework that combines Dual Averaging (DA) with Doob's $h$-transform. The key contributions are convergence guarantees for both convex and nonconvex objectives, an end-to-end bound on sampling error, and a practical sampling strategy that relies only on samples from the pretrained reference model and a learned density-ratio potential. The method generalizes to RLHF, Direct Preference Optimization (DPO), and Kahneman-Tversky Optimization (KTO), and is validated through synthetic and image experiments where it achieves true objective improvement over diffusion-based baselines. By avoiding isoperimetric conditions and enabling sampling from multimodal targets, this approach offers a scalable, theory-backed route to distributional alignment with potential wide-ranging applications beyond diffusion models.
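A short worked identity may help show why a density-ratio potential is enough for sampling. The notation here ($f$, $h_t$, $p_t$, $q_t$) is assumed for illustration and may differ from the paper's. Suppose the aligned target is an exponential tilt of the reference, and let $p_t$ and $q_t$ be the time-$t$ marginals of the same forward noising process started from $p_{\rm{ref}}$ and $q_*$ respectively, with transition kernel $p_{t|0}$:

$$
\begin{aligned}
q_*(x_0) &= \frac{1}{Z}\, p_{\rm{ref}}(x_0)\, e^{-f(x_0)}, \qquad h_t(x) := \mathbb{E}_{p_{\rm{ref}}}\!\left[ e^{-f(X_0)} \,\middle|\, X_t = x \right], \\
q_t(x) &= \int q_*(x_0)\, p_{t|0}(x \mid x_0)\, \mathrm{d}x_0 = \frac{1}{Z}\, p_t(x) \int p_{0|t}(x_0 \mid x)\, e^{-f(x_0)}\, \mathrm{d}x_0 = \frac{1}{Z}\, p_t(x)\, h_t(x), \\
\nabla \log q_t(x) &= \nabla \log p_t(x) + \nabla \log h_t(x).
\end{aligned}
$$

Under these assumptions, the score of the tilted model is the pretrained score plus a correction term $\nabla \log h_t$, and $h_t$ is an expectation that can be estimated from reference-model samples alone.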

Abstract

We introduce a novel alignment method for diffusion models from a distribution optimization perspective while providing rigorous convergence guarantees. We first formulate the problem as a generic regularized loss minimization over probability distributions and directly optimize the distribution using the Dual Averaging method. Next, we enable sampling from the learned distribution by approximating its score function via Doob's $h$-transform technique. The proposed framework is supported by rigorous convergence guarantees and an end-to-end bound on the sampling error, which imply that when the original distribution's score is known accurately, the complexity of sampling from shifted distributions is independent of isoperimetric conditions. This framework is broadly applicable to general distribution optimization problems, including alignment tasks in Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO), and Kahneman-Tversky Optimization (KTO). We empirically validate its performance on synthetic and image datasets using the DPO objective.
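The regularized distribution optimization step can be illustrated on a toy problem. The sketch below is not the paper's algorithm: it assumes a finite state space, a hypothetical convex loss $F(q) = \tfrac{1}{2}(\mathbb{E}_q[c] - \text{target})^2$, a KL regularizer toward $p_{\rm{ref}}$, and averaging weights proportional to the iteration index. All function and variable names are made up for illustration.

```python
import numpy as np

def objective(q, p_ref, c, target, beta):
    """Toy regularized objective: F(q) + beta * KL(q || p_ref)."""
    loss = 0.5 * (q @ c - target) ** 2
    kl = np.sum(q * (np.log(q) - np.log(p_ref)))
    return float(loss + beta * kl)

def first_variation(q, c, target):
    """First variation of F(q) = 0.5 * (E_q[c] - target)^2, one value per state."""
    return (q @ c - target) * c

def dual_averaging(p_ref, c, target, beta, num_iters=200):
    """Dual averaging sketch (assumed weights w_k = k) on a finite state space."""
    q = p_ref.copy()
    g_bar = np.zeros_like(p_ref)  # weighted average of past first variations
    weight_sum = 0.0
    for k in range(1, num_iters + 1):
        g = first_variation(q, c, target)
        weight_sum += k
        g_bar += (k / weight_sum) * (g - g_bar)  # running weighted mean
        # Gibbs-type update: q is proportional to p_ref * exp(-g_bar / beta)
        log_q = np.log(p_ref) - g_bar / beta
        log_q -= log_q.max()  # shift for numerical stability
        q = np.exp(log_q)
        q /= q.sum()
    return q, g_bar

c = np.arange(5.0)           # hypothetical per-state feature values 0..4
p_ref = np.full(5, 0.2)      # uniform reference distribution (mean of c is 2)
q, g_bar = dual_averaging(p_ref, c, target=3.0, beta=1.0)
print("mean under q:", q @ c)
print("objective at p_ref:", objective(p_ref, p_ref, c, 3.0, 1.0))
print("objective at q:", objective(q, p_ref, c, 3.0, 1.0))
```

In this toy setting the iterate always has the form $p_{\rm{ref}} \cdot e^{-\bar{g}/\beta}$ up to normalization, so the averaged first variation $\bar{g}$ acts as the potential of a density ratio against the reference. That is the kind of object a learned potential would have to approximate when the density itself is inaccessible.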

Paper Structure

This paper contains 33 sections, 25 theorems, 196 equations, 12 figures, 1 table, 3 algorithms.

Key Result

Theorem 1

Suppose that $\beta' \geq \beta$ and that we train the potential $f_{k+1}$ so that it is sufficiently close to $\bar{g}^{(k)}$, in the sense that $\mathrm{TV}(\hat{q}^{(k)},q^{(k)}) \leq \epsilon_\mathrm{TV}$ for all $k$. Then, under Assumption ass:ConvexF, Option 1 satisfies the following convergence guarantee:

Figures (12)

  • Figure 1: Overview of the proposed method integrating Dual Averaging and Doob’s h-transform.
  • Figure 2: Left and Middle. The smoothed loss during optimization for a Gaussian mixture model with Diffusion-DPO with/without regularization (left) and with our method (middle). "True Objective": the true DPO loss (Rafailov et al., 2023), whose target point was $\mu_w = [2.5, 0]$. "Upperbound": an approximate upper bound of "Objective" optimized by Diffusion-DPO (Wallace et al., 2024). Right. Aligned samples by Doob's h-transform. "Ref." represents the empirical density of $p_{\rm{ref}}$.
  • Figure 3: Left. Examples of aligned image generation. Our goal was to generate light-colored butterflies. "iter=2": ours with $k=2$ DA iterations, "iter=1": ours with $k=1$ DA iteration. "Ref.": samples from $p_{\rm{ref}}$. Right. Tilt-corrected head CT image generation. "iter=3": ours with $k=3$ DA iterations, "iter=1": ours with $k=1$ DA iteration. "Ref.": samples from $p_{\rm{ref}}$.
  • Figure 4: Left. Pre-train MSE loss of denoising score matching. The minimum losses until the current epoch were plotted. Right. The histogram of 20000 samples from the pre-trained DDPM.
  • Figure 5: (For reference) Losses without smoothing, corresponding to Figure 2. "D-DPO": Diffusion-DPO. "Ours": proposed DA.
  • ...and 7 more figures

Theorems & Definitions (47)

  • Example 1: Reinforcement Learning
  • Example 2: DPO
  • Example 3: KTO
  • Theorem 1: Convergence of the objective in Option 1
  • Theorem 2
  • Corollary 1: Convergence in Option 2
  • Theorem 3
  • Theorem 4
  • Theorem : restated - Theorem \ref{['thm-da-nishikawa']}
  • Lemma 1: NEURIPS2021_a34e1ddb
  • ...and 37 more