Diffusion Model Patching via Mixture-of-Prompts

Seokil Ham, Sangmin Woo, Jin-Young Kim, Hyojun Go, Byeongjun Park, Changick Kim

TL;DR

This work tackles the challenge of improving already-converged diffusion models without extensive retraining. It proposes Diffusion Model Patching (DMP), which injects a small pool of learnable prompts into the input and uses a dynamic gating mechanism to form timestep-specific mixtures, while keeping the backbone frozen and training only on the original pre-training data. Key contributions include zero-initialized prompts, prompt-balancing losses to prevent mode collapse, and a gating strategy that enables stage-aware denoising across thousands of steps, achieving a 10.38% FID improvement on FFHQ with just 50,000 iterations and a 1.43% parameter increase. The method generalizes across architectures (e.g., DiT-L/2, Stable Diffusion, and DiT-XL/2) and tasks, offering a practical, data-efficient, and scalable approach to enhance diffusion models in-domain.

Abstract

We present Diffusion Model Patching (DMP), a simple method to boost the performance of pre-trained diffusion models that have already reached convergence, with a negligible increase in parameters. DMP inserts a small, learnable set of prompts into the model's input space while keeping the original model frozen. The effectiveness of DMP is not merely due to the addition of parameters but stems from its dynamic gating mechanism, which selects and combines a subset of learnable prompts at every timestep (i.e., reverse denoising steps). This strategy, which we term "mixture-of-prompts", enables the model to draw on the distinct expertise of each prompt, essentially "patching" the model's functionality at every timestep with minimal yet specialized parameters. Uniquely, DMP enhances the model by further training on the original dataset already used for pre-training, even in a scenario where significant improvements are typically not expected due to model convergence. Notably, DMP significantly enhances the FID of converged DiT-L/2 by 10.38% on FFHQ, achieved with only a 1.43% parameter increase and 50K additional training iterations.
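The gating idea described in the abstract can be illustrated with a minimal sketch. It is not the authors' implementation: the sinusoidal timestep embedding, the softmax top-k gate, the additive way the mixture is applied to the input tokens, and all names (`gate`, `patch_input`, `NUM_PROMPTS`, `TOP_K`, and so on) are assumptions made for illustration. The frozen backbone is omitted; only the prompt pool and the gate would be trained.

```python
import math
import random

def timestep_embedding(t, dim):
    """Sinusoidal embedding of an integer timestep (a common choice, assumed here)."""
    half = dim // 2
    freqs = [math.exp(-math.log(10000.0) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def gate(t, gate_weights, top_k):
    """Return {prompt_index: weight} for the top_k prompts at timestep t."""
    emb = timestep_embedding(t, len(gate_weights[0]))
    logits = [sum(w * e for w, e in zip(row, emb)) for row in gate_weights]
    probs = softmax(logits)
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    return {i: probs[i] for i in chosen}

def patch_input(tokens, prompts, weights):
    """Add the gated mixture of prompts to every input token (assumed additive form)."""
    dim = len(prompts[0])
    mixture = [sum(w * prompts[i][d] for i, w in weights.items()) for d in range(dim)]
    return [[x + m for x, m in zip(tok, mixture)] for tok in tokens]

# Hypothetical sizes, chosen only to keep the example small.
rng = random.Random(0)
NUM_PROMPTS, TOKEN_DIM, EMB_DIM, TOP_K = 8, 4, 16, 2

# Zero-initialized prompt pool: the patched model starts out identical to the original.
prompts = [[0.0] * TOKEN_DIM for _ in range(NUM_PROMPTS)]
gate_weights = [[rng.gauss(0.0, 1.0) for _ in range(EMB_DIM)] for _ in range(NUM_PROMPTS)]
tokens = [[rng.gauss(0.0, 1.0) for _ in range(TOKEN_DIM)] for _ in range(3)]

weights_early = gate(900, gate_weights, TOP_K)  # high-noise timestep
weights_late = gate(10, gate_weights, TOP_K)    # low-noise timestep
patched = patch_input(tokens, prompts, weights_early)
```

Because the prompts start at zero, `patched` equals `tokens` before any training, so the frozen model's behavior is preserved at initialization. Since the gate depends only on the timestep, different noise levels can select different prompt subsets once the gate and prompts are trained.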

Paper Structure (34 sections, 13 equations, 11 figures, 9 tables)

Figures (11)

  • Figure 1: Further training of the fully converged DiT-L/2 model using the same dataset as the pre-training phase. Our method, DMP, achieves a 10.38% FID improvement in just 50K iterations, while other methods exhibit overfitting.
  • Figure 2: Overview of DMP. We take inspiration from prompt tuning (Lester et al., 2021) and aim to enhance already converged diffusion models. Our approach incorporates a pool of prompts within the input space, with each prompt learned to excel at certain stages of the denoising process. At every step, a unique blend of prompts (i.e., mixture-of-prompts) is constructed via dynamic gating based on the current noise level. This mechanism is similar to a skilled artist choosing the appropriate color combinations to refine different aspects of their artwork at specific moments. Importantly, our method keeps the diffusion model itself unchanged and uses only the original training dataset for further training.
  • Figure 3: DMP framework with DiT (Peebles & Xie, 2022). DMP is designed to adaptively generate optimal prompts tailored to specific timesteps. For fine-tuning, DMP uses the original training dataset that was previously used to pre-train the diffusion model. Operating entirely through prompt-based tuning in the input space, DMP eliminates the need for modifications to either the model architecture or the overall training process, ensuring seamless integration and efficiency.
  • Figure 4: Prompt depth.
  • Figure 5: Prompt activation. Brighter colors indicate stronger activation.
  • ...and 6 more figures