Bring Metric Functions into Diffusion Models

Jie An, Zhengyuan Yang, Jianfeng Wang, Linjie Li, Zicheng Liu, Lijuan Wang, Jiebo Luo

TL;DR

The paper tackles how to leverage metric functions, notably LPIPS, to boost diffusion models. It introduces Cas-DM, a cascaded diffusion architecture that splits the task into predicting the added noise $\epsilon$ with a front-end network $\theta$ and refining the clean image $x_0$ with a back-end network $\phi$, using a dynamic weight $r_t$ to blend their contributions. By applying the metric function to the $x_0$ path while stopping gradients to the $\epsilon$ path, Cas-DM preserves stable noise prediction and achieves improved image fidelity (FID, sFID) and competitive diversity (IS) across CIFAR-10, CelebA-HQ, LSUN Bedroom, and ImageNet. Experimental results show that Cas-DM, especially with LPIPS, delivers state-of-the-art or near-state-of-the-art performance, validating the architectural design for integrating metric functions into diffusion training.

Abstract

We introduce a Cascaded Diffusion Model (Cas-DM) that improves a Denoising Diffusion Probabilistic Model (DDPM) by effectively incorporating additional metric functions in training. Metric functions such as the LPIPS loss have been proven highly effective in consistency models derived from score matching. However, for the diffusion counterparts, the methodology and efficacy of adding extra metric functions remain unclear. One major challenge is the mismatch between the noise predicted by a DDPM at each step and the desired clean image that the metric function works well on. To address this problem, we propose Cas-DM, a network architecture that cascades two network modules to effectively apply metric functions to the diffusion model training. The first module, similar to a standard DDPM, learns to predict the added noise and is unaffected by the metric function. The second cascaded module learns to predict the clean image, thereby facilitating the metric function computation. Experimental results show that the proposed diffusion model backbone enables the effective use of the LPIPS loss, leading to state-of-the-art image quality (FID, sFID, IS) on various established benchmarks.
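
The mismatch mentioned in the abstract can be made concrete with the standard DDPM forward process. The relation below is the textbook DDPM identity, not a formula quoted from this page; $\bar{\alpha}_t$ denotes the usual cumulative noise-schedule product, and $x_0^\star$ follows the notation of the Figure 2 caption.

$$
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon
\quad\Longrightarrow\quad
x_0^\star = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon^\prime}{\sqrt{\bar{\alpha}_t}}
$$

A perceptual metric such as LPIPS compares images, so it can only be applied after a noise estimate $\epsilon^\prime$ has been mapped to an image estimate such as $x_0^\star$. For small $\bar{\alpha}_t$ (large $t$), the division by $\sqrt{\bar{\alpha}_t}$ amplifies any error in $\epsilon^\prime$.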

Paper Structure

This paper contains 12 sections, 17 equations, 4 figures, and 5 tables.

Figures (4)

  • Figure 1: We introduce a cascaded diffusion model that can effectively incorporate metric functions in diffusion training. (a) DDPM outputs either $\epsilon^\prime$ or $x_0^\prime$ and uses the corresponding loss in training. (b) Dual Diffusion Model outputs both $\epsilon^\prime$ and $x_0^\prime$ simultaneously with a single network $\theta$, where applying metric functions on $x_0^\prime$ will inevitably influence the prediction of $\epsilon^\prime$. (c) Our Cas-DM cascades the main module $\theta$ with an extra network $\phi$, where $\theta$ is frozen for the $x_0^\prime$-related losses and metric functions.
  • Figure 2: Framework of the proposed Cas-DM. For each time step $t$ from $T$ to $1$, $\theta$ takes $x_t$ and $t$ as the inputs and estimates the added noise $\epsilon^\prime$, which is then converted into an estimation of the clean image $x_0^\star$. Next, $\phi$ outputs $x_0^\prime$ and $r_t$ based on $x_0^\star$ and $t$, where the former is the final clean image estimation. $r_t$ is then used to mix the $\mu$ estimations from $x_0^\prime$ and $\epsilon^\prime$. Cas-DM uses DDIM to run one backward step based on $\mu^{ddim}$, getting $x_{t-1}$. Cas-DM runs the above process for $T-1$ rounds and gradually generates a clean image starting from a noise sample.
  • Figure 3: Training process of Cas-DM. $\theta$ learns to estimate the added noise $\epsilon$ while $\phi$ is trained to predict the clean image $x_0$. We apply $L_t^\epsilon$ on $\theta$, and the gradients of all other losses are blocked for it. For $\phi$, we use the $L_t^{x_0}$, $L_t^{lpips}$, and $L_t^{\mu}$ losses, where the first two enforce $\phi$ to recover the clean image from $x_0^\star$, assisted by the LPIPS loss. $L_t^{\mu}$ trains the dynamic mixing weight, and its gradient is stopped before $\mu_{\epsilon^\prime}$ and $\mu_{x_0^\prime}$. Best viewed on screen by zoom-in.
  • Figure 4: Unconditional samples from Cas-DM trained with the LPIPS loss on the experimented datasets.
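
The data flow described in the Figure 2 caption can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: images are flat lists of floats, `theta` and `phi` are hypothetical callables standing in for the two networks, the mixing is assumed to be the convex combination $r_t\,\mu_{x_0^\prime} + (1 - r_t)\,\mu_{\epsilon^\prime}$, and the two means use the standard DDPM posterior formulas rather than the DDIM-based $\mu^{ddim}$ that the caption mentions.

```python
import math

def eps_to_x0(x_t, eps, alpha_bar_t):
    """Invert x_t = sqrt(a_bar)*x0 + sqrt(1-a_bar)*eps for x0 (standard DDPM identity)."""
    s, n = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    return [(x - n * e) / s for x, e in zip(x_t, eps)]

def mu_from_x0(x_t, x0, alpha_t, alpha_bar_t, alpha_bar_prev):
    """Mean of the DDPM posterior q(x_{t-1} | x_t, x0)."""
    beta_t = 1.0 - alpha_t
    c0 = math.sqrt(alpha_bar_prev) * beta_t / (1.0 - alpha_bar_t)
    ct = math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
    return [c0 * a + ct * x for a, x in zip(x0, x_t)]

def mu_from_eps(x_t, eps, alpha_t, alpha_bar_t):
    """The same posterior mean, written in terms of a noise estimate."""
    beta_t = 1.0 - alpha_t
    k = beta_t / math.sqrt(1.0 - alpha_bar_t)
    return [(x - k * e) / math.sqrt(alpha_t) for x, e in zip(x_t, eps)]

def cas_dm_mean(x_t, t, theta, phi, alpha_t, alpha_bar_t, alpha_bar_prev):
    """One cascaded mean estimate: theta -> x0_star -> phi -> mixed mu.

    theta(x_t, t) returns a noise estimate; phi(x0_star, t) returns
    (x0_pred, r_t). Both are hypothetical stand-ins for the networks.
    """
    eps_pred = theta(x_t, t)                          # front-end: noise estimate
    x0_star = eps_to_x0(x_t, eps_pred, alpha_bar_t)   # convert to image estimate
    x0_pred, r_t = phi(x0_star, t)                    # back-end: refined image + weight
    mu_x0 = mu_from_x0(x_t, x0_pred, alpha_t, alpha_bar_t, alpha_bar_prev)
    mu_eps = mu_from_eps(x_t, eps_pred, alpha_t, alpha_bar_t)
    # Assumed mixing rule: convex combination weighted by r_t.
    return [r_t * a + (1.0 - r_t) * b for a, b in zip(mu_x0, mu_eps)]
```

When `phi` returns $x_0^\star$ unchanged, the two means coincide and the mixing weight has no effect; the weight only matters once the back-end refinement moves $x_0^\prime$ away from $x_0^\star$. The gradient blocking from Figure 3 is not shown here, since it needs an autodiff framework; in such a framework it would typically amount to detaching $x_0^\star$ before passing it to $\phi$.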