Iterative Importance Fine-tuning of Diffusion Models

Alexander Denker, Shreyas Padhy, Francisco Vargas, Johannes Hertrich

TL;DR

This work treats downstream tasks for diffusion models as sampling from a tilted distribution $p_\text{tilted}(\mathbf{x}) \propto p_\text{data}(\mathbf{x}) \exp(r(\mathbf{x})/\lambda)$ and leverages Doob's $h$-transform to enable conditional sampling. It proposes a self-supervised, amortised fine-tuning framework (SIFT) that repeats three steps: sample trajectories with the current control, resample them using path-based importance weights so that they approximate the tilted distribution, and update the control via a score-matching objective. The iteration comes with a proof that the stochastic control free energy descends. The method is validated on MNIST class-conditional sampling, inverse-problem posterior sampling (e.g., super-resolution), and text-to-image reward fine-tuning, showing competitive performance with efficient memory use and applicability to large models, since it does not backpropagate through the generation process. The approach provides a principled, scalable alternative to online RL-based fine-tuning that balances reward fidelity against sample diversity and computational cost, which is useful for personalised or task-specific deployment of diffusion models.
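
As a brief aside, the tilted distribution above can be read as the solution of a KL-regularised reward maximisation problem. This is a standard identity rather than a result specific to this work, and the symbol $q$ for a candidate density is introduced here purely for illustration:

$$
p_\text{tilted} = \arg\max_{q} \; \mathbb{E}_{\mathbf{x} \sim q}\big[r(\mathbf{x})\big] - \lambda \, \mathrm{KL}\big(q \,\|\, p_\text{data}\big),
\qquad
p_\text{tilted}(\mathbf{x}) = \frac{p_\text{data}(\mathbf{x}) \exp(r(\mathbf{x})/\lambda)}{\mathbb{E}_{\mathbf{x}' \sim p_\text{data}}\big[\exp(r(\mathbf{x}')/\lambda)\big]}.
$$

Under this reading, a small $\lambda$ concentrates mass on high-reward samples, while a large $\lambda$ keeps $p_\text{tilted}$ close to $p_\text{data}$.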

Abstract

Diffusion models are an important tool for generative modelling, serving as effective priors in applications such as imaging and protein design. A key challenge in applying diffusion models for downstream tasks is efficiently sampling from resulting posterior distributions, which can be addressed using Doob's $h$-transform. This work introduces a self-supervised algorithm for fine-tuning diffusion models by learning the optimal control, enabling amortised conditional sampling. Our method iteratively refines the control using a synthetic dataset resampled with path-based importance weights. We demonstrate the effectiveness of this framework on class-conditional sampling, inverse problems and reward fine-tuning for text-to-image diffusion models.
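
The resampling idea can be illustrated with a one-dimensional toy problem. The sketch below is not the authors' implementation: it assumes a standard normal prior in place of a diffusion model, a hypothetical quadratic reward, and terminal-sample weights $\exp(r(x)/\lambda)$, so it leaves out both the path-dependent part of the importance weights and the score-matching update of the control. All function and variable names are made up for illustration.

```python
import math
import random


def reward(x: float) -> float:
    # Hypothetical reward that prefers samples near x = 2 (an assumption).
    return -0.5 * (x - 2.0) ** 2


def importance_resample(samples, lam, rng):
    """Resample `samples` with self-normalised weights proportional to exp(r(x) / lam)."""
    log_w = [reward(x) / lam for x in samples]
    shift = max(log_w)  # subtract the maximum for numerical stability
    weights = [math.exp(lw - shift) for lw in log_w]
    # random.Random.choices draws with replacement according to the weights.
    return rng.choices(samples, weights=weights, k=len(samples))


rng = random.Random(0)
# Stand-in for samples from the pretrained model: a standard normal prior.
prior_samples = [rng.gauss(0.0, 1.0) for _ in range(20000)]
tilted_samples = importance_resample(prior_samples, lam=1.0, rng=rng)

mean = sum(tilted_samples) / len(tilted_samples)
var = sum((x - mean) ** 2 for x in tilted_samples) / len(tilted_samples)
print(f"resampled mean = {mean:.3f}, variance = {var:.3f}")
```

For this choice of prior and reward with $\lambda = 1$, the tilted density is proportional to $\exp(-x^2/2 - (x-2)^2/2)$, which is a Gaussian with mean 1 and variance 1/2, so the printed statistics should land close to those values. In the framework described above, the resampled set would then serve as the synthetic dataset for the next update of the control.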

Paper Structure

This paper contains 37 sections, 8 theorems, 54 equations, 13 figures, 4 tables.

Key Result

Lemma 2.1

Assume that ${\mathbb{P}}$ and ${\mathbb{Q}}$ are the path measures of the solution of the same SDE ${\mathrm{d}} {\bm{Y}}_t = f_t({\bm{Y}}_t) \,{\mathrm{d}} t + \sigma_t \,{\mathrm{d}} {\bm{W}}_t$ …

Figures (13)

  • Figure 1: GMM prior and several posterior configurations: Diagonal, Corners, Cross, Ring and Checkerboard.
  • Figure 1: Additional results for different images from the test set of the Flowers dataset.
  • Figure 2: Evolution of the fine-tuned model across outer iterations. Left: energy distance as a function of iteration.
  • Figure 2: Reward and diversity for text-to-image fine-tuning on "A green colored rabbit." during training. As is known for LoRA fine-tuning, the diversity decreases over iterations.
  • Figure 3: Class-conditional sampling for MNIST using the smooth reward function: prior, zero, two, four, six, eight.
  • ...and 8 more figures

Theorems & Definitions (15)

  • Lemma 2.1
  • Proof 1
  • Theorem 2.2: Variational Principle for Tilted Path Measures
  • Proof 2
  • Proposition 2.3: Score Decomposition
  • Proof 3
  • Theorem 2.4: Loss functions for Computing $u_t^*$
  • Corollary 3.1
  • Proof 4
  • Lemma 3.2
  • ...and 5 more