
Learning Diffusion Models with Flexible Representation Guidance

Chenyu Wang, Cai Zhou, Sharut Gupta, Zongyu Lin, Stefanie Jegelka, Stephen Bates, Tommi Jaakkola

TL;DR

This work provides a principled variational framework for guiding diffusion models with pretrained representations, introducing a tunable, multi-latent approach that unifies and extends prior methods. The key contributions are a variational bound that incorporates representation variables, a multi-latent extension, and the REED framework, which combines multimodal data synthesis with curriculum-based training. Empirically, REED delivers substantial training speedups and quality gains across image, protein, and molecule generation. These include a $23.3\times$ training speedup over SiT-XL on class-conditional ImageNet $256\times 256$, as well as faster training for protein inverse folding and 3D molecule generation. The methods improve the data efficiency and cross-domain applicability of diffusion models, with practical implications for fast, high-quality generative modeling in complex domains.

Abstract

Diffusion models can be improved with additional guidance toward more effective representations of the input. Indeed, prior empirical work has already shown that aligning the internal representations of a diffusion model with those of pretrained models improves generation quality. In this paper, we present a systematic framework for incorporating representation guidance into diffusion models. We provide alternative decompositions of denoising models along with their associated training criteria, where the decomposition determines when and how the auxiliary representations are incorporated. Guided by these theoretical insights, we introduce two new strategies for enhancing representation alignment in diffusion models. First, we pair examples with target representations that are either derived from the examples themselves or arise from different synthetic modalities, and then learn a joint model over the resulting multimodal pairs. Second, we design an optimal training curriculum that balances representation learning and data generation. Our experiments across image, protein sequence, and molecule generation tasks demonstrate superior performance and accelerated training. In particular, on the class-conditional ImageNet $256\times 256$ benchmark, our guidance yields $23.3$ times faster training than the original SiT-XL and a four-times speedup over the state-of-the-art method REPA. The code is available at https://github.com/ChenyuWang-Monica/REED.
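The abstract combines two ingredients: aligning the diffusion model's internal representations with those of a pretrained model, and a curriculum that balances this alignment against the denoising objective. The sketch below shows one way these could fit together. It is not REED's actual implementation. It assumes a REPA-style alignment term (negative cosine similarity between projected hidden states and frozen pretrained features) and a hypothetical schedule that holds the alignment weight fixed and then decays it linearly. The names `alignment_loss`, `curriculum_weight`, `total_loss`, `proj`, `lam_max`, and `hold_frac` are all illustrative, and the real curriculum in the paper may have a different shape.

```python
import numpy as np


def alignment_loss(hidden, target, proj):
    """REPA-style alignment: negative mean cosine similarity between
    projected diffusion hidden states and pretrained target representations.

    hidden: (N, D_h) intermediate features of the diffusion model
    target: (N, D_r) representations from a frozen pretrained encoder
    proj:   (D_h, D_r) hypothetical linear projection head
    """
    pred = hidden @ proj
    pred = pred / np.linalg.norm(pred, axis=-1, keepdims=True)
    tgt = target / np.linalg.norm(target, axis=-1, keepdims=True)
    return -np.mean(np.sum(pred * tgt, axis=-1))


def curriculum_weight(step, total_steps, lam_max=0.5, hold_frac=0.5):
    """Hypothetical curriculum (an assumption, not the paper's schedule):
    keep the alignment weight at lam_max for the first hold_frac of training,
    then decay it linearly to zero by total_steps."""
    hold = hold_frac * total_steps
    if step <= hold:
        return lam_max
    frac = (step - hold) / (total_steps - hold)
    return lam_max * max(0.0, 1.0 - frac)


def total_loss(pred_noise, true_noise, hidden, target, proj, step, total_steps):
    """Denoising MSE plus the curriculum-weighted alignment term."""
    denoise = np.mean((pred_noise - true_noise) ** 2)
    weight = curriculum_weight(step, total_steps)
    return denoise + weight * alignment_loss(hidden, target, proj)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    h = rng.normal(size=(4, 8))
    # Targets that are a positive rescaling of the features are perfectly aligned,
    # so the alignment loss is approximately -1.0.
    print(alignment_loss(h, 2.0 * h, np.eye(8)))
    print([curriculum_weight(s, 100) for s in (0, 50, 75, 100)])
```

Under this design, the alignment term shapes the model's representations early in training, and the denoising loss dominates later. The value 0.5 for `lam_max` is only a placeholder.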

Paper Structure

This paper contains 66 sections, 5 theorems, 59 equations, 9 figures, 14 tables.

Key Result

Proposition 1

Let $\{\alpha_t\geq 0\}_{t=1}^T$ be a set of weights summing to one, and define $A_t := \sum_{i=t}^T \alpha_i \in [0,1]$. Then the variational bound $\mathcal{L}_{\text{VB}}^z$ can be expressed in terms of the distributions $\tilde{p}_{\theta}(x_{t-1}\mid x_t,z;A_t)$ and their normalizations $Z_t(x_t,z;A_t)$. Both are defined explicitly in the paper.
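The weights $A_t$ in Proposition 1 are reverse cumulative sums of the $\alpha_t$. Because the $\alpha_t$ are nonnegative and sum to one, $A_1 = 1$, the sequence is nonincreasing in $t$, and every $A_t$ lies in $[0,1]$. The following minimal sketch computes them. The function name `cumulative_weights` is hypothetical, and index 0 of the returned array corresponds to $t = 1$.

```python
import numpy as np


def cumulative_weights(alpha):
    """Return A_t = sum_{i=t}^T alpha_i for t = 1..T (0-indexed array)."""
    alpha = np.asarray(alpha, dtype=float)
    if np.any(alpha < 0) or not np.isclose(alpha.sum(), 1.0):
        raise ValueError("alpha must be nonnegative and sum to one")
    # Reverse, take the running sum, then reverse back to get suffix sums.
    return np.cumsum(alpha[::-1])[::-1]


if __name__ == "__main__":
    print(cumulative_weights([0.1, 0.2, 0.3, 0.4]))  # approximately [1.0, 0.9, 0.7, 0.4]
    print(cumulative_weights(np.full(4, 0.25)))      # uniform alpha gives A_t = (T - t + 1) / T
```

With uniform weights $\alpha_t = 1/T$, this gives $A_t = (T - t + 1)/T$. The weight therefore decreases linearly from $1$ at $t=1$ to $1/T$ at $t=T$.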

Figures (9)

  • Figure 1: REED achieves superior performance and accelerated training on image generation and protein inverse folding.
  • Figure 2: We use synthetic auxiliary data modalities and multimodal representations to enhance representation alignment in diffusion model training. Shown are examples from class-conditional image generation (left), protein inverse folding (top right), and molecule generation (bottom right).
  • Figure 3: Examples of images with their ground-truth class labels and the generated captions.
  • Figure 4: Selected samples on ImageNet $256\times256$ generated by the SiT-XL/2+REED model after 1M training iterations. We use classifier-free guidance with $w=1.275$.
  • Figure 5: Selected samples on ImageNet $256\times256$ generated by the SiT-XL/2+REED model after 1M training iterations. We use classifier-free guidance with $w=4.0$. Each row corresponds to the same class label. From top to bottom, the class labels are: "macaw" (88), "flamingo" (130), "borzoi, Russian wolfhound" (169), "Samoyed" (258), "Egyptian cat" (285), "otter" (360), and "dogsled" (537), respectively.
  • ...and 4 more figures
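Figures 4 and 5 quote classifier-free guidance scales of $w=1.275$ and $w=4.0$. To show what $w$ controls, the sketch below assumes the convention used by common DiT/SiT-style samplers: the guided output equals the unconditional prediction plus $w$ times the difference between the conditional and unconditional predictions. Under this convention, $w=1$ recovers the conditional model, and larger $w$ pushes samples harder toward the class label. This summary does not say whether REED's sampler adds further details, such as applying guidance only over part of the trajectory. The name `cfg_combine` is hypothetical, and the toy arrays stand in for real model outputs.

```python
import numpy as np


def cfg_combine(pred_cond, pred_uncond, w):
    """Classifier-free guidance (assumed convention):
    uncond + w * (cond - uncond).
    w = 1 gives the conditional prediction; w > 1 extrapolates past it."""
    pred_cond = np.asarray(pred_cond, dtype=float)
    pred_uncond = np.asarray(pred_uncond, dtype=float)
    return pred_uncond + w * (pred_cond - pred_uncond)


if __name__ == "__main__":
    # Toy vectors standing in for conditional / unconditional model outputs.
    cond, uncond = np.array([2.0, 0.5]), np.array([1.0, 0.5])
    for w in (1.275, 4.0):  # the two scales quoted in the figure captions
        print(w, cfg_combine(cond, uncond, w))
```

With the same toy inputs, the higher scale ($w=4.0$) moves the output much farther from the unconditional prediction than $w=1.275$ does. This is consistent with using the larger scale for the visually curated per-class samples in Figure 5.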

Theorems & Definitions (11)

  • Remark 1
  • Remark 2
  • Proposition 1: Decomposition Structure of the Variational Bound
  • Proposition 1: Decomposition Structure of the Variational Bound
  • Proof
  • Proposition 2: Bounds on the Log-Normalization Term
  • Proof
  • Proposition 3: Score Function of the Hybrid Distribution
  • Proof
  • Theorem 1
  • ...and 1 more