
Pruning then Reweighting: Towards Data-Efficient Training of Diffusion Models

Yize Li, Yihua Zhang, Sijia Liu, Xue Lin

TL;DR

This work investigates efficient diffusion training from the perspective of dataset pruning by extending the data selection scheme used in GANs to DM training, where data features are encoded by a surrogate model, and a score criterion is then applied to select the coreset.
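The selection step described above (encode each image with a surrogate model, then score the features and keep a coreset) can be sketched as follows. This is a minimal illustration under stated assumptions, not the paper's implementation: the features are assumed to be precomputed by some surrogate encoder, and the score used here (log-density under a diagonal Gaussian fit to all features, keeping the most typical samples) is one density-based choice picked for illustration. The function names `gaussian_density_scores` and `select_coreset` are hypothetical.

```python
import math
from typing import List, Sequence


def gaussian_density_scores(features: Sequence[Sequence[float]], eps: float = 1e-6) -> List[float]:
    """Log-density of each feature vector under a diagonal Gaussian fit to all features.

    Assumed scoring criterion for illustration; the paper's criterion may differ.
    """
    n, d = len(features), len(features[0])
    mean = [sum(f[j] for f in features) / n for j in range(d)]
    # Per-dimension variance; eps avoids division by zero for constant dimensions.
    var = [sum((f[j] - mean[j]) ** 2 for f in features) / n + eps for j in range(d)]
    return [
        -0.5 * sum(math.log(2 * math.pi * var[j]) + (f[j] - mean[j]) ** 2 / var[j] for j in range(d))
        for f in features
    ]


def select_coreset(features: Sequence[Sequence[float]], keep_ratio: float) -> List[int]:
    """Return the sorted indices of the highest-scoring keep_ratio fraction of samples."""
    k = max(1, int(round(keep_ratio * len(features))))
    scores = gaussian_density_scores(features)
    # Rank all samples by score (highest first) and keep the top k.
    ranked = sorted(range(len(features)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])
```

For example, with `keep_ratio=0.1` the function keeps 10% of the data, the budget used for the CIFAR-10 experiments in Figure 2. Swapping in a different scoring function only requires replacing `gaussian_density_scores`; the ranking and top-k logic stay the same.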

Abstract

Despite the remarkable generation capabilities of Diffusion Models (DMs), training and inference remain computationally expensive. Previous works have focused on accelerating diffusion sampling, while data-efficient diffusion training has often been overlooked. In this work, we investigate efficient diffusion training from the perspective of dataset pruning. Inspired by the principles of data-efficient training for generative models such as generative adversarial networks (GANs), we first extend the data selection scheme used in GANs to DM training: data features are encoded by a surrogate model, and a score criterion is then applied to select the coreset. To further improve generation performance, we employ a class-wise reweighting approach that derives class weights through distributionally robust optimization (DRO) over a pre-trained reference DM. For a pixel-wise DM (DDPM) on CIFAR-10, experiments demonstrate the superiority of our methodology over existing approaches, achieving image synthesis comparable to that of the original full-data model with speed-ups of 2.34× to 8.32×. Additionally, our method can be generalized to latent DMs (LDMs), e.g., Masked Diffusion Transformer (MDT) and Stable Diffusion (SD), and achieves competitive generation capability on ImageNet. Code is available at https://github.com/Yeez-lee/Data-Selection-and-Reweighting-for-Diffusion-Models.
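The class-wise reweighting step derives class weights through DRO over a pre-trained reference DM. Because Figure 1 cites DoReMi (xie2023doremi), the sketch below assumes a DoReMi-style update. Each class's excess loss is the proxy DM's denoising loss minus the reference DM's, clipped at zero. The update multiplies class weights by `exp(step_size * excess)`, renormalizes, mixes with the uniform distribution, and averages the weights over all iterations. This is an illustration under those assumptions, not the authors' code. The per-class losses are taken as given inputs, and the names `dro_class_weight_step` and `average_class_weights` and the defaults for `step_size` and `smoothing` are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def dro_class_weight_step(
    weights: Sequence[float],
    proxy_losses: Sequence[float],
    ref_losses: Sequence[float],
    step_size: float = 1.0,
    smoothing: float = 1e-3,
) -> List[float]:
    """One assumed DoReMi-style exponentiated-gradient update of class weights."""
    k = len(weights)
    # Excess loss of the proxy DM over the reference DM per class, clipped at zero.
    excess = [max(p - r, 0.0) for p, r in zip(proxy_losses, ref_losses)]
    # Multiplicative update in log space: classes with larger excess loss gain weight.
    logits = [math.log(w) + step_size * e for w, e in zip(weights, excess)]
    m = max(logits)  # subtract the max for numerical stability
    unnorm = [math.exp(l - m) for l in logits]
    z = sum(unnorm)
    # Renormalize, then mix with the uniform distribution so no weight reaches zero.
    return [(1.0 - smoothing) * u / z + smoothing / k for u in unnorm]


def average_class_weights(
    loss_history: Sequence[Tuple[Sequence[float], Sequence[float]]],
    num_classes: int,
    step_size: float = 1.0,
    smoothing: float = 1e-3,
) -> List[float]:
    """Run the update over (proxy_losses, ref_losses) pairs and average the iterates."""
    weights = [1.0 / num_classes] * num_classes  # start from uniform class weights
    if not loss_history:
        return weights
    total = [0.0] * num_classes
    for proxy_losses, ref_losses in loss_history:
        weights = dro_class_weight_step(weights, proxy_losses, ref_losses, step_size, smoothing)
        total = [t + w for t, w in zip(total, weights)]
    return [t / len(loss_history) for t in total]
```

In a real run, `loss_history` would be collected while training the proxy DM on the pruned subset. The smoothing term keeps every class in play even if its excess loss stays at zero for many steps.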

Paper Structure

This paper contains 14 sections, 2 equations, 2 figures, and 5 tables.

Figures (2)

  • Figure 1: Overview of our Data-Efficient Diffusion Training approach. Given input images (e.g., 32×32 images for DDPM ho2020denoising and 256×256 images for MDT gao2023masked and SD rombach2021highresolution), we use a surrogate model as an encoder to obtain latent features, and a scoring function then prunes the dataset. The pre-trained reference DM guides the training of a proxy DM with distributionally robust optimization (DRO; sagawa2020distributionally, xie2023doremi) across classes to produce class weights. The DMs are then trained on the weighted subset.
  • Figure 2: Class-wise DDPM FIDs (lower is better) before and after reweighting on 10% of the CIFAR-10 training data. Reweighting improves generation performance across all 10 classes; the classes whose weights exceed the initial weight are the easily generated categories.
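The final step in Figure 1 trains the DM on the weighted subset. This text does not say how the class weights enter training, so the sketch below assumes one common option: weighted sampling. Each retained example of class y is drawn with probability proportional to w_y / n_y, where n_y is the number of retained examples of that class, so class y makes up roughly a fraction w_y of every minibatch. The function names are hypothetical. An equally plausible alternative is to multiply each example's denoising loss by w_y.

```python
import random
from collections import Counter
from typing import List, Sequence


def per_example_probs(labels: Sequence[int], class_weights: Sequence[float]) -> List[float]:
    """Spread each class weight evenly over that class's retained examples."""
    counts = Counter(labels)
    raw = [class_weights[y] / counts[y] for y in labels]
    total = sum(raw)
    # Normalize so the probabilities sum to 1 even if the weights do not.
    return [p / total for p in raw]


def sample_batch(
    labels: Sequence[int],
    class_weights: Sequence[float],
    batch_size: int,
    rng: random.Random,
) -> List[int]:
    """Draw minibatch indices (with replacement) so class y appears with frequency ~ w_y."""
    probs = per_example_probs(labels, class_weights)
    return rng.choices(range(len(labels)), weights=probs, k=batch_size)
```

For example, `sample_batch(labels, weights, 128, random.Random(0))` returns 128 indices into the pruned subset. Passing a seeded `random.Random` keeps the sampling reproducible across runs.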