
Switch Diffusion Transformer: Synergizing Denoising Tasks with Sparse Mixture-of-Experts

Byeongjun Park, Hyojun Go, Jin-Young Kim, Sangmin Woo, Seokil Ham, Changick Kim

TL;DR

This work introduces Switch Diffusion Transformer (Switch-DiT), which establishes inter-task relationships between conflicting tasks without compromising semantic information, and proposes a diffusion prior loss that encourages similar tasks to share their denoising paths while isolating conflicting ones.

Abstract

Diffusion models have achieved remarkable success across a range of generative tasks. Recent efforts to enhance diffusion model architectures have reimagined them as a form of multi-task learning, where each task corresponds to denoising at a specific noise level. These efforts have focused on parameter isolation and task routing, but the former falls short of capturing detailed inter-task relationships and the latter risks losing semantic information. In response, we introduce Switch Diffusion Transformer (Switch-DiT), which establishes inter-task relationships between conflicting tasks without compromising semantic information. To achieve this, we employ a sparse mixture-of-experts within each transformer block to utilize semantic information and to handle conflicts between tasks through parameter isolation. Additionally, we propose a diffusion prior loss that encourages similar tasks to share their denoising paths while isolating conflicting ones. As a result, each transformer block contains an expert shared across all tasks, and the combination of common and task-specific denoising paths enables the diffusion model to construct its own beneficial way of synergizing denoising tasks. Extensive experiments validate the effectiveness of our approach in improving both image quality and convergence rate, and further analysis demonstrates that Switch-DiT constructs tailored denoising paths across various generation scenarios.
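To make the per-block mechanism concrete, here is a minimal pure-Python sketch of an SMoE masking layer following the Figure 1 caption: a gating network reads the timestep embedding, keeps two of three experts, and their outputs form a mask m(z) that splits z into a part routed through the transformer block and a part that is skip-connected to the end. Everything beyond that caption is an assumption made for illustration, not the authors' implementation: the linear gate `w_gate`, the sigmoid-of-linear experts in `expert_weights`, renormalizing the two kept gate probabilities, and adding the skip path to the block output are all hypothetical choices. The shared-expert behavior induced by the diffusion prior loss is not modeled here.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Sequence[float]) -> Vector:
    """Dense matrix-vector product: one dot product per row of w."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def softmax(logits: Sequence[float]) -> Vector:
    """Numerically stable softmax."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def top_k_gate(e_t: Sequence[float], w_gate: Matrix, k: int = 2) -> Vector:
    """Hypothetical timestep-conditioned gate: softmax over the M experts,
    keep the k largest probabilities, renormalize them, zero the rest."""
    probs = softmax(matvec(w_gate, e_t))
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    kept_mass = sum(probs[i] for i in keep)
    return [probs[i] / kept_mass if i in keep else 0.0 for i in range(len(probs))]


def smoe_layer(
    z: Vector,
    e_t: Vector,
    w_gate: Matrix,
    expert_weights: List[Matrix],
    block: Callable[[Vector], Vector],
    k: int = 2,
) -> Vector:
    """Route z through `block` with a gated mask m(z); skip-connect the rest."""
    gates = top_k_gate(e_t, w_gate, k)
    mask = [0.0] * len(z)
    for gate, w_exp in zip(gates, expert_weights):
        if gate == 0.0:
            continue  # unselected experts are never evaluated (sparse routing)
        expert_mask = [sigmoid(v) for v in matvec(w_exp, z)]
        mask = [m + gate * em for m, em in zip(mask, expert_mask)]
    routed = [zi * mi for zi, mi in zip(z, mask)]           # z (.) m(z), element-wise
    skipped = [zi * (1.0 - mi) for zi, mi in zip(z, mask)]  # z (.) (1 - m(z))
    # Assumption: the skip path is added to the block output at the end.
    return [o + s for o, s in zip(block(routed), skipped)]
```

Because z ⊙ m(z) + z ⊙ (1 - m(z)) = z, passing an identity function as `block` returns z unchanged, which is a quick sanity check that the two paths partition the input.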

Paper Structure (45 sections, 11 equations, 10 figures, 8 tables)

Figures (10)

  • Figure 1: Switch Diffusion Transformer. $\odot$ represents element-wise multiplication. Switch-DiT is built upon the DiT peebles2022scalable architecture, whose blocks consist of self-attention and feedforward layers, both conditioned on timestep and label embeddings via adaLN-Zero. In the SMoE layer, the gating network takes the timestep embedding and selects two of three experts to output ${\bm{m}}({\bm{z}})$. Then, ${\bm{z}} \odot {\bm{m}}({\bm{z}})$ is used as the input to the transformer block, and ${\bm{z}} \odot (1 - {\bm{m}}({\bm{z}}))$ is skip-connected to the end of the block.
  • Figure 2: Gating Outputs Integration. For simplicity, we visualize the gating outputs for three experts and select the two largest elements within each transformer block. As a result, ${\bm{p}}_{tot}({\bm{e}}_{t})$ is the concatenation of the per-block probabilities ${\bm{p}}_{i}({\bm{e}}_{t})$ over all blocks $i$, and it is then used for the diffusion prior loss. In addition, ${\bm{w}}^{gate}_{t}$ is used, together with its counterpart derived similarly from DTR park2023denoising, to define the cost function of the bipartite matching.
  • Figure 3: Bipartite Matching. We show the stacked ${\bm{w}}^{gate}_{t}$ and ${\bm{w}}^{prior}_{t}$ for $N=24$, $M=3$, and $k=2$. Each row represents a concatenated activation map as shown in Fig. 2.
  • Figure 4: Relationship between GFLOPs and FID on ImageNet. Switch-DiT surpasses the GFLOPs/FID tradeoff curve of DiT.
  • Figure 5: Convergence comparison on ImageNet. Switch-DiT achieves the fastest convergence in diffusion training across model sizes (S, B, and XL).
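The gating integration and bipartite matching described in the Figure 2 and Figure 3 captions can be sketched in pure Python as well. The sketch assumes that each block's gating output over the M experts becomes a binary activation map by keeping its k largest entries, that these maps are concatenated across blocks into one row per timestep, and that the matching pairs each expert of a block with a slot of the DTR-style prior map by minimizing the Hamming distance between their activation columns over timesteps. The captions do not state the actual cost function, so the cost and the helper names `activation_map`, `concat_probs`, and `match_experts` are hypothetical. Brute force over permutations is exact for M = 3 (six candidates); for larger M, the Hungarian algorithm (for example `scipy.optimize.linear_sum_assignment`) would be the usual choice.

```python
from itertools import permutations
from typing import List, Sequence, Tuple


def top_k_indices(probs: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest entries (stable sort: ties keep lower index)."""
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]


def concat_probs(block_probs: Sequence[Sequence[float]]) -> List[float]:
    """p_tot(e_t): per-block gate probabilities p_i(e_t) concatenated over blocks."""
    return [p for probs in block_probs for p in probs]


def activation_map(block_probs: Sequence[Sequence[float]], k: int = 2) -> List[int]:
    """One row of w_t^gate: per-block top-k indicators, concatenated over blocks."""
    row: List[int] = []
    for probs in block_probs:
        keep = set(top_k_indices(probs, k))
        row.extend(1 if i in keep else 0 for i in range(len(probs)))
    return row


def match_experts(
    gate_rows: Sequence[Sequence[int]],
    prior_rows: Sequence[Sequence[int]],
    block: int,
    m: int,
) -> Tuple[Tuple[int, ...], int]:
    """Exact bipartite matching for one block by brute force over permutations.

    Returns (assignment, cost): assignment[e] is the prior slot given to
    expert e, and cost is the total Hamming distance of matched columns.
    """
    offset = block * m  # start of this block's slice in each concatenated row

    def column(rows: Sequence[Sequence[int]], j: int) -> List[int]:
        return [row[offset + j] for row in rows]

    # cost[e][s]: timesteps where expert e and prior slot s disagree
    cost = [
        [sum(a != b for a, b in zip(column(gate_rows, e), column(prior_rows, s)))
         for s in range(m)]
        for e in range(m)
    ]
    best = min(permutations(range(m)), key=lambda p: sum(cost[e][p[e]] for e in range(m)))
    return best, sum(cost[e][best[e]] for e in range(m))
```

In this toy setting, matching only permutes expert indices, which reflects the fact that expert labels within a block are arbitrary: the same routing pattern can be aligned to the prior regardless of which expert index it happens to occupy.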