Dataset Distillation via the Wasserstein Metric

Haoyang Liu, Yijiang Li, Tiancheng Xing, Peiran Wang, Vibhu Dalal, Luwei Li, Jingrui He, Haohan Wang

TL;DR

WMDD (Wasserstein Metric-based Dataset Distillation) is a straightforward yet powerful method that uses the Wasserstein metric to enhance distribution matching. It maintains the efficiency of distribution matching approaches while achieving state-of-the-art results across various high-resolution datasets.

Abstract

Dataset Distillation (DD) aims to generate a compact synthetic dataset that enables models to achieve performance comparable to training on the full large dataset, significantly reducing computational costs. Drawing from optimal transport theory, we introduce WMDD (Wasserstein Metric-based Dataset Distillation), a straightforward yet powerful method that employs the Wasserstein metric to enhance distribution matching. We compute the Wasserstein barycenter of features from a pretrained classifier to capture essential characteristics of the original data distribution. By optimizing synthetic data to align with this barycenter in feature space and leveraging per-class BatchNorm statistics to preserve intra-class variations, WMDD maintains the efficiency of distribution matching approaches while achieving state-of-the-art results across various high-resolution datasets. Our extensive experiments demonstrate WMDD's effectiveness and adaptability, highlighting its potential for advancing machine learning applications at scale.
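
As a rough illustration of the barycenter idea mentioned above, the following sketch computes a Wasserstein-2 barycenter in the one-dimensional case, where it reduces to averaging sorted samples (that is, averaging quantile functions). It is not the WMDD implementation: the method operates on high-dimensional classifier features, where this sorting shortcut does not apply. The function name `wasserstein_barycenter_1d` and the toy Gaussian data are assumptions made for this example.

```python
import numpy as np


def wasserstein_barycenter_1d(samples, weights=None):
    """W2 barycenter of 1-D empirical distributions with equal sample counts.

    Hypothetical helper for illustration only. In one dimension, the
    barycenter's quantile function is the weighted average of the input
    quantile functions, so sorting each sample set and averaging
    position-wise yields the barycenter's support points.
    """
    sorted_samples = np.sort(np.asarray(samples, dtype=float), axis=1)
    k = sorted_samples.shape[0]
    if weights is None:
        weights = np.full(k, 1.0 / k)  # uniform weights by default
    weights = np.asarray(weights, dtype=float)
    return weights @ sorted_samples


# Toy data (assumed): two well-separated Gaussians with the same spread.
rng = np.random.default_rng(0)
a = rng.normal(-2.0, 0.5, size=1000)
b = rng.normal(3.0, 0.5, size=1000)

bary = wasserstein_barycenter_1d([a, b])
mixture = np.concatenate([a, b])  # naive pooling of both sample sets

# The barycenter keeps the narrow shape of the inputs and sits between them,
# while the pooled mixture is bimodal and much more spread out.
print("barycenter mean/std:", bary.mean(), bary.std())
print("mixture    mean/std:", mixture.mean(), mixture.std())
```

The comparison with the pooled mixture is meant to show why an optimal-transport average can preserve the shape of the source distributions, whereas simply mixing them does not.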

Paper Structure (46 sections, 29 equations, 11 figures, 10 tables, 2 algorithms)

Figures (11)

  • Figure 1: Synthetic images distilled from ImageNet-1K using our WMDD method with ResNet-18, capturing essential class features aligned with human perception. We randomly sampled one image for each of the chosen categories from our output in the 10 IPC setting.
  • Figure 2: The capability of Wasserstein barycenter in condensing the core characteristics of distributions: (a) distributions defined on $\mathbb{R}^2$, concentrated on outlines of circles (blue) and crosses (green). Barycenters computed using: (b) KL divergence, (c) Maximum Mean Discrepancy (MMD), which operates in a kernel-induced feature space, and (d) Wasserstein distance, which preserves geometric structure through optimal transport. Color intensity represents probability density, while color hue shows different types of source distributions.
  • Figure 3: Diagram of our WMDD method. Real dataset $T$ and synthetic dataset $S$ pass through the feature network $f$ to obtain features. The features of the real dataset are used to compute the Wasserstein barycenter. The synthetic dataset is optimized via feature matching and loss computation (combining feature loss and BN regularization) to align with the barycenter, generating high-quality synthetic data for efficient model training.
  • Figure 4:
  • Figure 5: Distribution visualization of ImageNette. The dots represent the original dataset's distribution in the model's latent space (e.g., ResNet-101), and the triangles represent the distilled images. Left: data distilled by SRe$^2$L; Right: data distilled by our method.
  • ...and 6 more figures