Avoiding mode collapse in diffusion models fine-tuned with reinforcement learning

Roberto Barceló, Cristóbal Alcázar, Felipe Tobar

TL;DR

The paper tackles mode collapse and instability when fine-tuning diffusion models with reinforcement learning by introducing Hierarchical Reward Fine-tuning (HRF), a sliding-window, step-wise RL framework that exploits the hierarchical diffusion dynamics. HRF (and its dynamic variant HRF-D) performs trajectory-of-interest sampling across timesteps, enabling online RL updates at selected noise levels with appropriate diversity pressures, while preserving high-level semantics. Empirical results on a CelebA-HQ-based diffusion model across compressibility, incompressibility, and LAION aesthetic tasks show that HRF achieves reward levels comparable to DDPO but with substantially improved diversity, as evidenced by Inception Score and Vendi Score near baseline. Overall, HRF provides a robust, hierarchical, and tunable approach to aligning diffusion models to downstream objectives without sacrificing sample diversity, enabling safer and more reliable deployment in conditional generation tasks.

Abstract

Fine-tuning foundation models via reinforcement learning (RL) has proven promising for aligning to downstream objectives. In the case of diffusion models (DMs), though RL training improves alignment from early timesteps, critical issues such as training instability and mode collapse arise. We address these drawbacks by exploiting the hierarchical nature of DMs: we train them dynamically at each epoch with a tailored RL method, allowing for continual evaluation and step-by-step refinement of the model performance (or alignment). Furthermore, we find that not every denoising step needs to be fine-tuned to align DMs to downstream tasks. Consequently, in addition to clipping, we regularise model parameters at distinct learning phases via a sliding-window approach. Our approach, termed Hierarchical Reward Fine-tuning (HRF), is validated on the Denoising Diffusion Policy Optimisation method, where we show that models trained with HRF achieve better preservation of diversity in downstream tasks, thus enhancing fine-tuning robustness without compromising mean rewards.

Paper Structure

This paper contains 27 sections, 9 equations, 7 figures, 5 tables, 2 algorithms.

Figures (7)

  • Figure 1: Comparison of Image Synthesis Using CelebA-HQ-Based Models. 2D projection of CLIP embeddings for two sets of 1,000 samples: i) DDPM samples (black borders) and ii) DDPO samples fine-tuned with the LAION aesthetic reward (white borders). The DDPO samples were optimized to achieve a higher average aesthetic score ($5.58$ vs. $5.11$), indicating better aesthetic quality. Notably, the DDPO samples cluster more tightly (red ellipse) around the highest-scoring DDPM sample, indicating a mode collapse effect. Both sets of samples were generated using the same seed.
  • Figure 2: Equivalence of the backward process of a diffusion model as a sequential decision-making process. The initial state distribution of this MDP corresponds to an isotropic Gaussian, $\rho_{0}(s_{0})\sim\mathcal{N}(0, \mathrm{I})$, where we assign the noise instance to the initial state $s_{0}=\mathrm{x}_{T}$. The agent follows a sequence of decisions $a_{t}$ determined by the policy $\pi_{\theta}(a_{t}\mid s_{t}):=p_{\theta}(\mathrm{x}_{T-t-1}\mid \mathrm{x}_{T-t})$, moving from a noisy state $\mathrm{x}_{T-t}$ to a less noisy one $\mathrm{x}_{T-t-1}$ until it reaches the sample $\mathrm{x}_{0}$, illustrated as the terminal state $s_{T}$ in the diagram. This process generates the whole denoising trajectory $\tau=\{\mathrm{x}_{T}, \mathrm{x}_{T-1}, \mathrm{x}_{T-2}, \dots, \mathrm{x}_{0}\}$, which is associated with a reward. In the case of DDPO, the reward model $R$ depends only on the final sample, i.e., $r(\mathrm{x}_{0})$.
  • Figure 3: Hierarchical Reward Fine-tuning (HRF). In DDPO (purple), the entire denoising trajectory is influenced, affecting both high- and low-level features. In contrast, HRF intervenes at later timesteps of the trajectory, selecting a specific timestep $t$ based on a window selection scheme. At this point, an intermediate state $s_{t}\sim\rho_{t}$ is drawn from a new prior distribution, primarily adjusting low-level features while preserving high-level features and diversity, yet still achieving high rewards ($s_{T}^{1}$ yellow star). During policy rollouts, HRF generates divergent trajectories from $s_{t}$ by introducing noise at intermediate timesteps. This serves as a low-level feature exploration mechanism to discover regions with higher reward potential ($s_{T}^{2}$ yellow star) given the high-level information set by the new prior $\rho_{t}$. In both cases, the dataset of trajectories and rewards $\mathcal{D}^{\pi_{\theta}}$ is used to estimate the gradients via Monte Carlo sampling, which are then applied to update the diffusion model parameters. HRF-D dynamically adjusts the vertical window selection line during fine-tuning.
  • Figure 4: Alignment of diffusion model to downstream tasks. This figure shows the visual performance of a diffusion model on three tasks: aesthetic quality, incompressibility, and compressibility. DDPO and HRF achieve similar semantic changes, but HRF better preserves visual diversity, especially for aesthetic quality. While DDPO risks mode collapse by generating similar high-reward images, HRF maintains diversity while improving rewards. HRF-D shows a significant visual shift, with samples differing greatly from the originals but maintaining high diversity. Compressibility skews towards darker samples, yet retains diverse representations.
  • Figure 5: Rewards (y-axis) vs Diversity (x-axis, Vendi Score). Ablations on window selections for three hierarchical regimes: baseline (blue), early (orange) and late stages (green) over the three downstream tasks considered. Each result is reported as the average over three seeds.
  • ...and 2 more figures
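
The indexing in the Figure 2 and Figure 3 captions can be made concrete with a small toy sketch. It is not the authors' implementation: `policy_step`, `reward`, the scalar state, the value of `T`, and the choice `t = 6` are all hypothetical stand-ins, and the intermediate state $s_{t}$ is obtained here simply by denoising from $\mathrm{x}_{T}$, which is an assumption about how $\rho_{t}$ might be sampled. The sketch only illustrates that MDP step $t$ corresponds to $\mathrm{x}_{T-t}$, that the reward depends only on $\mathrm{x}_{0}$, and that several rollouts can branch from one shared intermediate state.

```python
import random

T = 10  # number of denoising steps (toy value, not taken from the paper)

def noise_level(t, total_steps=T):
    """MDP step t corresponds to the diffusion variable x_{T-t}."""
    return total_steps - t

def policy_step(x, k, rng):
    """Hypothetical stand-in for p_theta(x_{k-1} | x_k).

    A real diffusion model would use a neural network here; this toy
    version shrinks the state and adds noise that fades as k decreases.
    """
    return 0.9 * x + rng.gauss(0.0, 0.05 * k)

def rollout(x_start, k_start, rng):
    """Denoise from noise level k_start down to 0.

    Returns the list [x_{k_start}, ..., x_0].
    """
    trajectory = [x_start]
    for k in range(k_start, 0, -1):
        trajectory.append(policy_step(trajectory[-1], k, rng))
    return trajectory

def reward(x0):
    """Toy terminal reward r(x_0): it depends only on the final sample."""
    return -abs(x0 - 1.0)

rng = random.Random(0)

# DDPO-style rollout: the full trajectory starting at s_0 = x_T ~ N(0, 1).
x_T = rng.gauss(0.0, 1.0)
full = rollout(x_T, T, rng)  # full[i] holds x_{T-i}, i.e. the MDP state s_i

# HRF-style rollouts (assumed form): branch several trajectories from a
# shared intermediate state s_t, so only the later steps are explored.
t = 6                 # hypothetical window choice
k = noise_level(t)    # remaining noise level, here T - t = 4
s_t = full[t]
branches = [rollout(s_t, k, rng) for _ in range(3)]

# Dataset of (trajectory, reward) pairs, as used for Monte Carlo gradients.
dataset = [(traj, reward(traj[-1])) for traj in branches]
```

In this toy setting every branch shares the prefix up to $s_{t}$ and differs only through the noise injected in the remaining `k` steps, which mirrors the low-level exploration described in the Figure 3 caption.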