Maximum Entropy Inverse Reinforcement Learning of Diffusion Models with Energy-Based Models

Sangwoong Yoon, Himchan Hwang, Dohyun Kwon, Yung-Kyun Noh, Frank C. Park

TL;DR

DxMI enables training an EBM without MCMC, which stabilizes EBM training dynamics and improves anomaly detection performance. The paper also proposes Diffusion by Dynamic Programming (DxDP), a novel reinforcement learning algorithm for diffusion models, which serves as a subroutine in DxMI.

Abstract

We present a maximum entropy inverse reinforcement learning (IRL) approach for improving the sample quality of diffusion generative models, especially when the number of generation time steps is small. Just as IRL trains a policy based on a reward function learned from expert demonstrations, we train (or fine-tune) a diffusion model using the log probability density estimated from training data. Since we employ an energy-based model (EBM) to represent the log density, our approach amounts to the joint training of a diffusion model and an EBM. Our IRL formulation, named Diffusion by Maximum Entropy IRL (DxMI), is a minimax problem that reaches equilibrium when both models converge to the data distribution. Entropy maximization plays a key role in DxMI, facilitating exploration by the diffusion model and ensuring the convergence of the EBM. We also propose Diffusion by Dynamic Programming (DxDP), a novel reinforcement learning algorithm for diffusion models, as a subroutine in DxMI. DxDP makes the diffusion model update in DxMI efficient by transforming the original problem into an optimal control formulation in which value functions replace back-propagation through time. Our empirical studies show that diffusion models fine-tuned using DxMI can generate high-quality samples with only 4 or 10 generation steps. Additionally, DxMI enables the training of an EBM without MCMC, stabilizing EBM training dynamics and enhancing anomaly detection performance.
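The alternating scheme described in the abstract (and in Figure 1) can be illustrated with a deliberately tiny example. This is a minimal sketch under stated assumptions, not the authors' DxMI or DxDP. It assumes a one-dimensional energy $E(x) = (x - c)^2 / (2v)$ and replaces the multi-step diffusion sampler with a single Gaussian policy $\mathcal{N}(m, s^2)$. Under that assumption, the entropy-regularized policy step (maximize $\mathbb{E}_\pi[-E(x)] + H(\pi)$) has the closed form $m = c$, $s = \sqrt{v}$; in the paper, that step is handled by DxDP instead. The function `toy_dxmi` and all of its parameters are hypothetical. The EBM is updated using policy samples as negatives, with no MCMC. Setting `entropy=False` makes the policy collapse to the mode of $q$, and the EBM variance then keeps growing instead of settling. This is a small-scale echo of the role the abstract assigns to entropy maximization.

```python
import math
import random


def toy_dxmi(data, steps=2000, lr=0.05, n_neg=500, entropy=True, seed=0):
    """Hypothetical 1-D toy of alternating policy/EBM updates (not the paper's algorithm).

    EBM (assumed form): E(x) = (x - c)^2 / (2 v), v = exp(log_v), q(x) ∝ exp(-E(x)).
    Policy (assumed form): a single Gaussian N(m, s^2) standing in for a diffusion sampler.
    Returns the final EBM parameters (c, v).
    """
    rng = random.Random(seed)
    c, log_v = 0.0, 0.0  # EBM parameters, initialized so that q = N(0, 1)
    data_m1 = sum(data) / len(data)
    data_m2 = sum(x * x for x in data) / len(data)
    for _ in range(steps):
        v = math.exp(log_v)
        # Policy step: maximize E_pi[-E(x)] + H(pi) over Gaussians.
        # E_pi[E] = ((m - c)^2 + s^2) / (2v) and H = log s + const, so the
        # maximizer is m = c, s = sqrt(v), i.e. pi = q. Without the entropy
        # bonus the maximizer collapses to the mode of q (s = 0).
        m, s = c, (math.sqrt(v) if entropy else 0.0)
        # EBM step: samples from pi serve as negatives; no MCMC on q.
        neg = [rng.gauss(m, s) for _ in range(n_neg)]
        neg_m1 = sum(neg) / n_neg
        d_neg = sum((x - c) ** 2 for x in neg) / n_neg
        d_data = data_m2 - 2.0 * c * data_m1 + c * c  # E_data[(x - c)^2]
        # Gradient descent on E_data[E] - E_pi[E], with the policy samples held fixed.
        c += lr * (data_m1 - neg_m1) / v
        log_v += lr * (d_data - d_neg) / (2.0 * v)
    return c, math.exp(log_v)
```

For data drawn from $\mathcal{N}(1, 0.5^2)$, the entropy-regularized run should end with $c$ and $v$ close to the sample mean and variance, which is the equilibrium in which both "models" match the data. With `entropy=False`, $c$ still tracks the mean, but $v$ is driven far above the data variance.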

Paper Structure (37 sections, 19 equations, 6 figures, 6 tables, 2 algorithms)

Figures (6)

  • Figure 1: (Left) Overview of DxMI. The diffusion model $\pi(\mathbf{x})$ is trained using the energy of $q(\mathbf{x})$ as a reward. The EBM $q(\mathbf{x})$ is trained using samples from $\pi(\mathbf{x})$ as negative samples. (Right) ImageNet 64$\times$64 generation examples from a 10-step diffusion model before DxMI fine-tuning (top) and after fine-tuning (bottom). Only the last six of the ten steps are shown.
  • Figure 2: 2D density estimation on 8 Gaussians. Red shades indicate the energy (white is low), and the dots are generated samples.
  • Figure 3: Value functions at each time step ($\tau=0.1$ case). Blue indicates a low value.
  • Figure 4: Randomly selected samples from CIFAR-10 training data, SFT-PG ($T=10$, FID: 4.32), and DxMI ($T=10$, FID: 3.19).
  • Figure 5: Randomly selected samples from ImageNet 64$\times$64 training data, Consistency Model ($T=1$, FID: 6.20), DxMI ($T=4$, FID: 3.21), and DxMI ($T=10$, FID: 2.68). Note that the Consistency Model samples distort human faces, while the DxMI samples depict them in correct proportions.