
Diffusion Alignment as Variational Expectation-Maximization

Jaewoo Lee, Minsu Kim, Sanghyeok Choi, Inhyuck Song, Sujin Yun, Hyeongyu Kang, Woocheol Shin, Taeyoung Yun, Kiyoung Om, Jinkyoo Park

TL;DR

This work introduces Diffusion Alignment as Variational EM (DAV), a principled framework that alternates between an E-step of test-time search for reward-aligned, diverse trajectories and an M-step that distills these trajectories into the diffusion model via forward KL minimization. By modeling alignment as a variational inference problem with a discount factor, DAV achieves multi-modal reward alignment without succumbing to mode collapse or reward over-optimization. It is demonstrated on both continuous diffusion for text-to-image synthesis and discrete diffusion for DNA sequence design, showing improved reward metrics while preserving alignment, naturalness, and diversity. The approach is modular, extends to non-differentiable rewards, and offers a general pathway for robust downstream optimization of diffusion models in diverse domains.
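The E-step/M-step alternation described above can be illustrated on a toy problem. The sketch below is not the authors' implementation, and it makes several substitutions. It uses a categorical "model" over four outcomes in place of a diffusion model. It replaces test-time search with self-normalized importance resampling. It targets the KL-regularized posterior p_ref(x)·exp(r(x)/β), where p_ref, r, and β are hypothetical names introduced only for this illustration. For a categorical model, minimizing the forward KL to the weighted samples reduces to weighted maximum likelihood, which means normalizing the weighted counts.

```python
import math
import random


def posterior_target(p_ref, reward, beta):
    """Exact tilted distribution p*(x) ∝ p_ref(x) * exp(r(x) / beta), used only for checking."""
    w = [p * math.exp(r / beta) for p, r in zip(p_ref, reward)]
    z = sum(w)
    return [wi / z for wi in w]


def e_step(p_theta, p_ref, reward, beta, n_samples, rng):
    """Toy E-step: sample from the current model, then reweight toward p_ref * exp(r / beta).

    Self-normalized importance weights stand in for the paper's test-time search
    (an assumption of this sketch, not DAV's actual search procedure).
    """
    k = len(p_theta)
    xs = rng.choices(range(k), weights=p_theta, k=n_samples)
    ws = [p_ref[x] * math.exp(reward[x] / beta) / p_theta[x] for x in xs]
    total = sum(ws)
    return xs, [w / total for w in ws]


def m_step(xs, ws, k, eps=1e-3):
    """Toy M-step: forward-KL fit of a categorical model to the weighted samples.

    Minimizing KL(q || p_theta) over a categorical p_theta gives p_theta = q,
    i.e. normalized weighted counts; eps smoothing keeps every outcome reachable.
    """
    counts = [eps] * k
    for x, w in zip(xs, ws):
        counts[x] += w
    z = sum(counts)
    return [c / z for c in counts]


def variational_em(p_ref, reward, beta=0.5, iters=20, n_samples=5000, seed=0):
    """Alternate the toy E-step and M-step, starting from the reference model."""
    rng = random.Random(seed)
    p_theta = list(p_ref)
    for _ in range(iters):
        xs, ws = e_step(p_theta, p_ref, reward, beta, n_samples, rng)
        p_theta = m_step(xs, ws, len(p_ref))
    return p_theta
```

Consider a reward that is equally high on two outcomes, for example `variational_em([0.25] * 4, [1.0, 0.0, 1.0, 0.0])`. In that case the fitted distribution should approach the two-mode target returned by `posterior_target` rather than collapse onto a single outcome. This is the kind of multi-modal behavior the TL;DR describes.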

Abstract

Diffusion alignment aims to optimize diffusion models for a downstream objective. While existing methods based on reinforcement learning or direct backpropagation achieve considerable success in maximizing rewards, they often suffer from reward over-optimization and mode collapse. We introduce Diffusion Alignment as Variational Expectation-Maximization (DAV), a framework that formulates diffusion alignment as an iterative process alternating between two complementary phases: the E-step and the M-step. In the E-step, we employ test-time search to generate diverse and reward-aligned samples. In the M-step, we refine the diffusion model using the samples discovered in the E-step. We demonstrate that DAV can optimize reward while preserving diversity on both continuous and discrete tasks: text-to-image synthesis and DNA sequence design.

Paper Structure

This paper contains 63 sections, 2 theorems, 63 equations, 11 figures, 2 tables, 1 algorithm.

Key Result

Proposition 1

Let $\gamma\in(0,1]$ be the discount factor. The likelihood of the optimality variable $\mathcal{O}$ admits the following lower bound:

Figures (11)

  • Figure 1: Conceptual illustration of DAV. DAV alternates between E-step, where trajectories are obtained via test-time search, and M-step, where the diffusion model parameters $\theta$ are updated by amortizing the posterior into the policy. By iterating these two steps, DAV progressively refines the diffusion model toward a multi-modal aligned distribution.
  • Figure 2: Training dynamics of our methods and baseline models, with performance marked every 10 epochs. All methods were trained for 100 epochs, except DDPO, which was trained for 500 epochs. Our approaches preserve alignment score and diversity better than the baselines.
  • Figure 3: Qualitative comparison of our methods with DAS, DRaFT-1, DDPO, and the pretrained model. Results for our methods are reported after 100 epochs of fine-tuning. For DDPO and DRaFT, we report the last checkpoint prior to significant collapse. The optimization target is the aesthetic score (Schuhmann et al., 2022), with scores reported below each method name.
  • Figure 4: Comparison of ELBO and aesthetic score trends for DAV and its ablated baselines.
  • Figure 5: Performance comparison of DAV and baseline models. The x-axis represents the reward (Pred-Activity), while the y-axis shows (a) Diversity (Levenshtein Diversity), (b) Naturalness (3-mer Corr), and (c) Validity (ATAC-Acc).
  • ...and 6 more figures
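Figure 5 reports Levenshtein Diversity and 3-mer Corr, but the metrics are not defined here. The following is a minimal sketch under common definitions, which are assumptions rather than the paper's exact formulas. Diversity is taken as the mean pairwise Levenshtein distance among generated sequences. Naturalness is taken as the Pearson correlation between the 3-mer count vectors of generated and reference sequences. All function names are hypothetical.

```python
import math
from itertools import combinations, product


def levenshtein(a, b):
    """Edit distance (insertions, deletions, substitutions) via a rolling DP row."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (ca != cb))) # substitution or match
        prev = cur
    return prev[-1]


def levenshtein_diversity(seqs):
    """Mean pairwise Levenshtein distance over all unordered pairs (needs >= 2 sequences)."""
    pairs = list(combinations(seqs, 2))
    return sum(levenshtein(a, b) for a, b in pairs) / len(pairs)


def kmer_counts(seqs, k=3, alphabet="ACGT"):
    """Count every k-mer over the alphabet across all sequences, in a fixed order."""
    index = {"".join(p): i for i, p in enumerate(product(alphabet, repeat=k))}
    counts = [0] * len(index)
    for s in seqs:
        for i in range(len(s) - k + 1):
            kmer = s[i:i + k]
            if kmer in index:  # skip k-mers containing symbols outside the alphabet
                counts[index[kmer]] += 1
    return counts


def pearson(x, y):
    """Pearson correlation coefficient of two equal-length numeric lists."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def kmer_corr(generated, reference, k=3):
    """Correlation between the k-mer spectra of generated and reference sequences."""
    return pearson(kmer_counts(generated, k), kmer_counts(reference, k))
```

Under these definitions, a higher diversity value means the generated sequences differ more from one another. A 3-mer correlation near 1 means their local composition resembles that of the reference set.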

Theorems & Definitions (3)

  • Proposition 1: Lower bound on likelihood of $\mathcal O$ with discount factor $\gamma$
  • Proposition 1: Lower bound on likelihood of $\mathcal O$ with discount factor $\gamma$
  • proof
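For orientation, a bound of this kind can be derived in the undiscounted case γ = 1 from Jensen's inequality. The block below is a generic control-as-inference sketch, not the paper's Proposition 1. It assumes an optimality likelihood p(O = 1 | x_0) ∝ exp(r(x_0)/α), where the temperature α is introduced here. It also assumes an arbitrary variational distribution q over denoising trajectories x_{0:T}.

```latex
\log p_\theta(\mathcal{O}=1)
  = \log \mathbb{E}_{q(x_{0:T})}\!\left[\frac{p_\theta(x_{0:T})\, p(\mathcal{O}=1 \mid x_0)}{q(x_{0:T})}\right]
  \;\ge\; \mathbb{E}_{q(x_{0:T})}\!\left[\frac{r(x_0)}{\alpha}\right]
  - \mathrm{KL}\!\left(q(x_{0:T}) \,\|\, p_\theta(x_{0:T})\right) + \text{const}
```

The two steps of the method correspond to the two arguments of this bound. Maximizing over q with θ fixed plays the role of an E-step, and the bound is tight when q equals the posterior p_θ(x_{0:T} | O = 1). Maximizing over θ with q fixed affects only the KL term, which amounts to minimizing the forward KL from q to p_θ. This matches the M-step described in the TL;DR.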