Decouple-Then-Merge: Finetune Diffusion Models as Multi-Task Learning

Qianli Ma, Xuefei Ning, Dongrui Liu, Li Niu, Linfeng Zhang

TL;DR

The Decouple-then-Merge (DeMe) framework starts from a pretrained diffusion model, finetunes separate models tailored to specific timesteps, and then merges them into a single model in the parameter space, keeping inference efficient and practical.

Abstract

Diffusion models are trained by learning a sequence of models that reverse each step of noise corruption. Typically, the model parameters are fully shared across timesteps to improve training efficiency. However, because the denoising task differs at each timestep, the gradients computed at different timesteps may conflict, potentially degrading image generation quality. To address this issue, this work proposes the Decouple-then-Merge (DeMe) framework, which begins with a pretrained model and finetunes separate models tailored to specific timesteps. We introduce several improved techniques during the finetuning stage to promote effective knowledge sharing while minimizing training interference across timesteps. After finetuning, these separate models can be merged into a single model in the parameter space, ensuring efficient and practical inference. Experimental results show significant improvements in generation quality on six benchmarks: Stable Diffusion on COCO30K, ImageNet1K, and PartiPrompts, and DDPM on LSUN Church, LSUN Bedroom, and CIFAR10. Code is available at https://github.com/MqLeet/DeMe.
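
The merge step can be illustrated with a minimal sketch. The abstract states only that the finetuned models are merged in the parameter space; the sketch below assumes a weighted sum of each finetuned model's offset from the pretrained weights (a task-vector-style merge), which is one common choice and not necessarily the authors' exact scheme. The function name `merge_in_parameter_space`, the `weights` coefficients, and the toy parameter dictionaries are all hypothetical.

```python
from typing import Dict, List, Sequence

Params = Dict[str, List[float]]  # toy stand-in for a model state dict


def merge_in_parameter_space(
    pretrained: Params,
    finetuned: Sequence[Params],
    weights: Sequence[float],
) -> Params:
    """Add a weighted sum of (finetuned - pretrained) offsets to the pretrained weights.

    Assumption: a task-vector-style merge; the coefficients are hypothetical.
    """
    if len(finetuned) != len(weights):
        raise ValueError("need exactly one weight per finetuned model")
    merged: Params = {}
    for name, base in pretrained.items():
        out = list(base)
        for model, w in zip(finetuned, weights):
            for i, (ft_val, base_val) in enumerate(zip(model[name], base)):
                out[i] += w * (ft_val - base_val)
        merged[name] = out
    return merged


pretrained = {"conv.weight": [1.0, 2.0]}
finetuned = [
    {"conv.weight": [1.5, 2.0]},  # e.g. finetuned on one set of timesteps
    {"conv.weight": [1.0, 3.0]},  # e.g. finetuned on another set of timesteps
]
merged = merge_in_parameter_space(pretrained, finetuned, weights=[1.0, 0.5])
print(merged)
```

Whatever the merging rule, the result has the same parameter shapes as a single model, which is why inference cost does not grow with the number of finetuned models.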

Paper Structure

This paper contains 32 sections, 23 equations, 14 figures, 4 tables.

Figures (14)

  • Figure 1: (a) Cosine similarity between gradients at different timesteps on CIFAR10, and the distribution of gradient similarity for $t\in [0,1000]$ and $t\in [0,250]$. Non-adjacent timesteps have low similarity, indicating conflicts during training, whereas adjacent timesteps have similar gradients. (b) & (c): Comparison between the traditional training paradigm and ours. The traditional paradigm trains one diffusion model on all timesteps, leading to conflicts across timesteps. Our method addresses this problem by decoupling the training of diffusion models into $N$ different timestep ranges.
  • Figure 2: Visualization of the difference between the pre-finetuned and post-finetuned diffusion models along the channel and spatial dimensions. We compute the difference in activation values before and after finetuning along the channel and spatial dimensions of the image. (a) Channel activation, spatial activation, and their differences between the pre-finetuned and post-finetuned models. (b) Distribution of the differences in channel and spatial activation values. Activation values change mostly along the channel dimension when finetuning on a subset of timesteps.
  • Figure 3: Pipeline of our framework. The following training techniques are incorporated into the finetuning process. Consistency loss preserves the original knowledge that the diffusion model learned at all timesteps by minimizing the difference between the pre-finetuned and post-finetuned diffusion models. Probabilistic sampling strategy samples from both the corresponding timesteps and the other timesteps with different probabilities, helping the diffusion model avoid forgetting knowledge from other timesteps. Channel-wise projection enables the diffusion model to directly capture feature differences along the channel dimension. Model merging scheme merges the parameters of all the finetuned models into one unified model to promote knowledge sharing across different timestep ranges.
  • Figure 4: Loss landscape of the pretrained diffusion model in different timestep ranges on CIFAR10. We use dimension reduction methods to handle high-dimensional neural networks. Contour line density reflects the frequency of loss variations (i.e., gradients), with blue representing low loss and red representing high loss. For the overall timestep range $t \in [0, 1000)$, the pretrained model resides at a critical point (with zero gradients) where contour lines are sparse. When the training process is decoupled, it tends to lie in regions with densely packed contour lines, suggesting that there still exist gradients that enable the pretrained model to escape from the critical point.
  • Figure 5: Qualitative comparison between DeMe and the original Stable Diffusion on various prompts. More images generated from various text prompts can be found in the supplementary material.
  • ...and 9 more figures
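
The probabilistic sampling strategy described in the Figure 3 caption can be sketched as follows. The caption states only that timesteps are drawn from both the model's own timesteps and the other timesteps with different probabilities; the in-range probability `p_own`, the uniform draw within each region, and the function name `sample_timestep` are assumptions made for illustration.

```python
import random


def sample_timestep(lo, hi, total_steps=1000, p_own=0.8, rng=random):
    """Draw a timestep for a model finetuned on the range [lo, hi).

    With probability p_own the timestep comes from [lo, hi); otherwise it
    comes uniformly from the remaining timesteps in [0, total_steps).
    Assumption: p_own and the uniform draws are hypothetical choices.
    """
    if not 0 <= lo < hi <= total_steps:
        raise ValueError("expected 0 <= lo < hi <= total_steps")
    n_other = total_steps - (hi - lo)
    if n_other == 0 or rng.random() < p_own:
        return rng.randrange(lo, hi)
    k = rng.randrange(n_other)  # index into the out-of-range timesteps
    return k if k < lo else k + (hi - lo)


rng = random.Random(0)
samples = [sample_timestep(250, 500, rng=rng) for _ in range(10_000)]
in_range = sum(250 <= t < 500 for t in samples) / len(samples)
print(f"fraction drawn from own range: {in_range:.2f}")
```

Setting `p_own` to 1 would train each model only on its own range, while lower values trade specialization for retention of what was learned at the other timesteps.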