
Adding Conditional Control to Diffusion Models with Reinforcement Learning

Yulai Zhao, Masatoshi Uehara, Gabriele Scalia, Sunyuan Kung, Tommaso Biancalani, Sergey Levine, Ehsan Hajiramezanali

TL;DR

This work tackles the problem of adding new conditional controls to pre-trained diffusion models using offline data. It reframes conditioning as a reinforcement-learning task and introduces CTRL, an augmented-diffusion approach that learns a soft-optimal drift to sample from the target distribution p(x|c,y) while penalizing divergence from the pre-trained model. CTRL achieves greater sample efficiency by leveraging conditional independence to simplify offline-data requirements and by training a calibrated reward model p(y|x,c) rather than directly modeling p(x|c,y). The authors provide theoretical connections to classifier guidance, detail an implementable three-step algorithm, analyze potential error sources, and validate performance on both single-task compressibility conditioning and multi-task conditioning, demonstrating notable improvements over baselines.
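The objective described in the TL;DR is a classifier log-likelihood reward minus a penalty for diverging from the pre-trained model. It can be estimated by simulating the fine-tuned SDE. Below is a minimal 1-D sketch under stated assumptions:

- Euler-Maruyama discretization.
- A constant noise scale `sigma` shared by both models.
- An extra drift `u` added on top of the pre-trained drift `f_pre`.

For two such SDEs, Girsanov's theorem gives the path-space KL as $\mathbb{E}[\int_0^T u^2/(2\sigma^2)\,dt]$. In the discretized chain, each step contributes exactly $u^2\,\Delta t/(2\sigma^2)$. The function name, the toy drifts, and the stand-in reward are hypothetical. CTRL's actual drift also takes $(c, y)$ as input and is optimized with RL, which this sketch does not attempt.

```python
import numpy as np


def kl_regularized_objective(f_pre, u, log_reward, sigma, x0, T=1.0,
                             n_steps=100, n_samples=1000, alpha=1.0, seed=0):
    """Monte Carlo estimate of E[log_reward(x_T)] - alpha * KL(P^u || P^pre).

    Toy 1-D setting (assumption): the pre-trained SDE is
    dx = f_pre(t, x) dt + sigma dW, and the fine-tuned SDE adds an extra
    drift u(t, x). Returns (objective estimate, mean path KL).
    """
    rng = np.random.default_rng(seed)
    dt = T / n_steps
    x = np.full(n_samples, x0, dtype=float)
    kl = np.zeros(n_samples)
    for k in range(n_steps):
        t = k * dt
        extra = u(t, x)
        # KL between the two Gaussian Euler-Maruyama transitions, which share
        # variance sigma^2 * dt and differ in mean by extra * dt.
        kl += extra ** 2 * dt / (2.0 * sigma ** 2)
        noise = np.sqrt(dt) * rng.standard_normal(n_samples)
        x = x + (f_pre(t, x) + extra) * dt + sigma * noise
    return np.mean(log_reward(x) - alpha * kl), np.mean(kl)


# Hypothetical ingredients for illustration only.
def f_pre(t, x):
    return -x  # OU-like pre-trained drift


def log_reward(x):
    return -(x - 1.0) ** 2  # stands in for log p(y | x_T, c)


def no_change(t, x):
    return np.zeros_like(x)


def push_right(t, x):
    return np.full_like(x, 0.5)


obj0, kl0 = kl_regularized_objective(f_pre, no_change, log_reward, sigma=1.0, x0=0.0)
obj1, kl1 = kl_regularized_objective(f_pre, push_right, log_reward, sigma=1.0, x0=0.0)
print(obj0, kl0)
print(obj1, kl1)
```

With `u` identically zero, the penalty vanishes and the estimate is just the mean log-reward under the pre-trained model. A constant extra drift of 0.5 with `sigma = 1` over `T = 1` costs 0.125 nats of KL. It also moves samples toward the reward peak, so the regularized objective still improves. An RL fine-tuning loop would search over `u` to maximize this quantity.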

Abstract

Diffusion models are powerful generative models that allow for precise control over the characteristics of the generated samples. While these diffusion models trained on large datasets have achieved success, there is often a need to introduce additional controls in downstream fine-tuning processes, treating these powerful models as pre-trained diffusion models. This work presents a novel method based on reinforcement learning (RL) to add such controls using an offline dataset comprising inputs and labels. We formulate this task as an RL problem, with the classifier learned from the offline dataset and the KL divergence against pre-trained models serving as the reward functions. Our method, $\textbf{CTRL}$ ($\textbf{C}$onditioning pre-$\textbf{T}$rained diffusion models with $\textbf{R}$einforcement $\textbf{L}$earning), produces soft-optimal policies that maximize the abovementioned reward functions. We formally demonstrate that our method enables sampling from the conditional distribution with additional controls during inference. Our RL-based approach offers several advantages over existing methods. Compared to classifier-free guidance, it improves sample efficiency and can greatly simplify dataset construction by leveraging conditional independence between the inputs and additional controls. Additionally, unlike classifier guidance, it eliminates the need to train classifiers from intermediate states to additional controls. The code is available at https://github.com/zhaoyl18/CTRL.
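The claim that a KL-regularized RL optimum samples from the conditional distribution rests on a standard identity. Over distributions $q$, the maximizer of $\mathbb{E}_q[\gamma\, r(x)] - \alpha\,\mathrm{KL}(q \,\|\, p^{\mathrm{pre}})$ is $q^*(x) \propto p^{\mathrm{pre}}(x)\exp(\gamma\, r(x)/\alpha)$. Its optimal value is $\alpha \log \sum_x p^{\mathrm{pre}}(x) e^{\gamma r(x)/\alpha}$. Taking $r(x) = \log p(y \mid x, c)$ and $\gamma = \alpha = 1$ turns $q^*$ into Bayes' rule for $p(x \mid c, y)$.

The numpy check below works on a finite support. The names `tilted_distribution` and `kl_regularized_value` and the probabilities are illustrative. Reading $p_\gamma(x \mid c, y) \propto p^{\mathrm{pre}}(x \mid c)\, p(y \mid x, c)^{\gamma}$ into the paper's notation is our assumption.

```python
import numpy as np


def tilted_distribution(p_pre, log_reward, gamma=1.0, alpha=1.0):
    """Maximizer over q of E_q[gamma * log_reward] - alpha * KL(q || p_pre).

    On a finite support the closed form is
    q*(x) proportional to p_pre(x) * exp(gamma * log_reward(x) / alpha).
    Assumption: the target p_gamma(x | c, y) is proportional to
    p_pre(x | c) * p(y | x, c) ** gamma, i.e. alpha = 1, log_reward = log p(y | x, c).
    """
    p_pre = np.asarray(p_pre, dtype=float)
    log_reward = np.asarray(log_reward, dtype=float)
    logits = np.log(p_pre) + gamma * log_reward / alpha
    w = np.exp(logits - logits.max())  # shift by the max for numerical stability
    return w / w.sum()


def kl_regularized_value(q, p_pre, log_reward, gamma=1.0, alpha=1.0):
    """E_q[gamma * log_reward] - alpha * KL(q || p_pre); assumes q > 0 everywhere."""
    q, p_pre, log_reward = (np.asarray(a, dtype=float) for a in (q, p_pre, log_reward))
    kl = np.sum(q * (np.log(q) - np.log(p_pre)))
    return np.sum(q * gamma * log_reward) - alpha * kl


# Hypothetical numbers: pre-trained p(x | c) on four states and a classifier p(y | x, c).
p_pre = np.array([0.4, 0.3, 0.2, 0.1])
p_y_given_x = np.array([0.1, 0.5, 0.9, 0.3])
log_r = np.log(p_y_given_x)

q_star = tilted_distribution(p_pre, log_r)                  # gamma = alpha = 1
bayes = p_pre * p_y_given_x / np.sum(p_pre * p_y_given_x)   # exact p(x | c, y)
print(q_star)  # approximately [0.1, 0.375, 0.45, 0.075], same as bayes

# No other distribution scores higher; the optimum equals log p(y | c) = log 0.4.
rng = np.random.default_rng(0)
best = kl_regularized_value(q_star, p_pre, log_r)
others = [kl_regularized_value(rng.dirichlet(np.ones(4)), p_pre, log_r) for _ in range(1000)]
print(best, np.log(0.4), max(others) <= best + 1e-12)
```

Raising `gamma` above 1 concentrates mass further on the states the classifier favors. Remark 2 discusses adjusting guidance strength, but whether the paper parameterizes it through an exponent like this is an assumption.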

Paper Structure

This paper contains 66 sections, 5 theorems, 44 equations, 5 figures, 5 tables, and 2 algorithms.

Key Result

Lemma 1

For any $c \in \mathcal{C}$ and $y \in \mathcal{Y}$, suppose $x_t$ evolves from $t = 0$ to $t = T$ according to the SDE stated in the lemma. Then the marginal distribution of $x_T$, i.e., $p(x_T \mid c, y)$, equals the target distribution $p_{\gamma=1}(\cdot \mid c, y)$ (eq:target_distribution). Here, $\mathbb{P}^{\mathrm{pre}}$ denotes the distribution induced by the pre-trained diffusion model (eq:pre_trained).
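For orientation, here is the textbook form of Doob's h-transform applied to a pre-trained SDE $dx_t = f^{\mathrm{pre}}(t, x_t, c)\,dt + \sigma(t)\,dw_t$ run from $t = 0$ to $t = T$. The symbols $f^{\mathrm{pre}}$, $\sigma$, $w_t$, and $h_t$ are our notation. This is a generic statement assuming a fixed starting point $x_0$, not a reproduction of the paper's equation:

$$
dx_t = \Big[\, f^{\mathrm{pre}}(t, x_t, c) + \sigma^2(t)\, \nabla_{x} \log h_t(x_t) \,\Big]\, dt + \sigma(t)\, dw_t,
\qquad
h_t(x) = \mathbb{E}_{\mathbb{P}^{\mathrm{pre}}}\big[\, p(y \mid x_T, c) \;\big|\; x_t = x \,\big].
$$

Because $h_T(x) = p(y \mid x, c)$, the terminal law becomes $p^{\mathrm{pre}}(x_T \mid c)\, p(y \mid x_T, c) / h_0(x_0) \propto p(x_T \mid c, y)$. If $x_0$ is random, its initial distribution must also be reweighted by $h_0(x_0)$. The added drift $\sigma^2(t)\nabla_x \log h_t$ is the guidance term that classifier guidance obtains by training classifiers on intermediate states $x_t$ (compare Lemma 2). CTRL instead learns the corrected drift through its RL objective.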

Figures (5)

  • Figure 1: Results for conditioning on compressibility. Panel fig:comp_hist plots the histogram of samples generated by the pre-trained diffusion model. Panel fig:training_curves shows the mean compressibility during fine-tuning, with one line for each of the four conditions; through fine-tuning, CTRL aligns the generated samples with their target compressibility levels. Tables tab:scores and tab:comparison provide evaluation metrics, and panel fig:images shows images generated by a single model fine-tuned with CTRL.
  • Figure 2: Results for multi-task conditioning. Panel fig:multitask_hist plots the histogram of samples generated by the pre-trained diffusion model. Table fig:multitask_scores presents the evaluation statistics. Panel fig:multitask_images displays images generated by a single model fine-tuned with CTRL.
  • Figure 3: Confusion matrix for Reconstruction Guidance.
  • Figure 4: More images generated by CTRL in the compressibility task.
  • Figure 5: More images generated by CTRL for multi-task conditional generation.

Theorems & Definitions (11)

  • Remark 1
  • Lemma 1: Doob's h-transform (Rogers and Williams, 2000)
  • Theorem 1: Conditioning as RL
  • Remark 2: Using classifier-free guidance to adjust guidance strength
  • Remark 3: Choice of exploratory distribution $\Pi$
  • Lemma 2: Bridging RL-based conditioning with classifier guidance
  • Example 1: Scenario $Y \perp C | X$
  • Example 2: Scenario $Y_1 \perp Y_2 | X, C$
  • Remark 4: PPO
  • Lemma 3: KL-constrained reward
  • ...and 1 more
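Examples 1 and 2 reduce data requirements through conditional independence:

- Example 1 uses $Y \perp C \mid X$, which gives $p(y \mid x, c) = p(y \mid x)$. The classifier can therefore be fit on $(x, y)$ pairs alone.
- Example 2 uses $Y_1 \perp Y_2 \mid X, C$, which gives $p(x \mid c, y_1, y_2) \propto p(x \mid c)\, p(y_1 \mid x, c)\, p(y_2 \mid x, c)$. Two classifiers trained on separately labeled data can then be combined by adding their log-probabilities.

The numpy sketch below checks the Example 2 identity on a small random joint distribution that satisfies the independence by construction. Sizes, seed, and names such as `composed_posterior` are hypothetical. The script illustrates the identity rather than the paper's training pipeline.

```python
import numpy as np

rng = np.random.default_rng(0)
n_c, n_x, n_y1, n_y2 = 3, 5, 2, 2

# Hypothetical joint p(c, x, y1, y2) = p(c) p(x|c) p(y1|x,c) p(y2|x,c),
# so Y1 and Y2 are conditionally independent given (X, C) by construction.
p_c = rng.dirichlet(np.ones(n_c))                          # (n_c,)
p_x_c = rng.dirichlet(np.ones(n_x), size=n_c)              # (n_c, n_x)
p_y1 = rng.dirichlet(np.ones(n_y1), size=(n_c, n_x))       # (n_c, n_x, n_y1)
p_y2 = rng.dirichlet(np.ones(n_y2), size=(n_c, n_x))       # (n_c, n_x, n_y2)
joint = (p_c[:, None, None, None] * p_x_c[:, :, None, None]
         * p_y1[:, :, :, None] * p_y2[:, :, None, :])      # axes (c, x, y1, y2)

# Each ingredient is read off a marginal that involves at most one label,
# i.e. it could be fitted on a dataset annotated with y1 alone or y2 alone.
p_cx = joint.sum(axis=(2, 3))                              # p(c, x)
prior = p_cx / p_cx.sum(axis=1, keepdims=True)             # p(x | c)
clf1 = joint.sum(axis=3) / p_cx[:, :, None]                # p(y1 | x, c)
clf2 = joint.sum(axis=2) / p_cx[:, :, None]                # p(y2 | x, c)


def composed_posterior(c, y1, y2):
    """p(x | c, y1, y2) proportional to p(x | c) p(y1 | x, c) p(y2 | x, c), in log space."""
    logits = np.log(prior[c]) + np.log(clf1[c, :, y1]) + np.log(clf2[c, :, y2])
    w = np.exp(logits - logits.max())
    return w / w.sum()


# Largest gap between the exact posterior (from the joint) and the composed one.
max_err = max(
    np.abs(joint[c, :, a, b] / joint[c, :, a, b].sum() - composed_posterior(c, a, b)).max()
    for c in range(n_c) for a in range(n_y1) for b in range(n_y2)
)
print(max_err)  # floating-point round-off only
```

If the independence assumption fails, for instance if `y2` depends on `y1` given `x` and `c`, the composed posterior generally no longer matches the exact one. The assumption therefore has to hold for the data at hand.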