
Group Distributionally Robust Dataset Distillation with Risk Minimization

Saeed Vahidian, Mingyu Wang, Jianyang Gu, Vyacheslav Kungurtsev, Wei Jiang, Yiran Chen

TL;DR

This work tackles generalization gaps in dataset distillation (DD) under subgroup and distributional shifts. It introduces a double distributionally robust optimization (DRO) framework that couples clustering with conditional value-at-risk (CVaR) based risk minimization. The method solves a two-level objective that minimizes the loss across latent clusters while safeguarding worst-case cluster performance, which improves both group robustness and within-group generalization. A theoretical argument based on large deviations principles links CVaR approximations to DRO guarantees. Extensive experiments across robustness settings, subpopulation shifts, and cross-architecture evaluations show consistent gains over strong baselines. The approach makes DD more practical by prioritizing representative coverage and resilience to domain shifts, and code and reproducibility details are provided.
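For reference, the standard (Rockafellar-Uryasev) definition of CVaR is given below, together with its dual form. Here α ∈ (0, 1] is the fraction of the upper tail that is averaged. The dual form is what ties CVaR to DRO: it equals the worst-case expected loss over every reweighting Q of the data distribution P whose density ratio is at most 1/α. The cluster-level line is a hypothetical illustration of "CVaR over latent clusters" in the spirit of the TL;DR. The symbols $\mathcal{L}_{\mathbf{c}_i}$ (distillation loss restricted to cluster $\mathbf{c}_i$) and $K$ (number of clusters) are assumptions, not the paper's notation. The paper's "CVaR ratio α" may also follow a different convention.

```latex
\mathrm{CVaR}_\alpha(Z)
  = \inf_{\eta \in \mathbb{R}} \Big\{ \eta + \tfrac{1}{\alpha}\, \mathbb{E}_P\big[(Z - \eta)_+\big] \Big\}
  = \sup_{Q \ll P,\; dQ/dP \,\le\, 1/\alpha} \mathbb{E}_Q[Z],
  \qquad \alpha \in (0, 1].

% Hypothetical cluster-level use: Z is uniform over the K per-cluster losses.
\min_{\mathcal{S}} \; \mathrm{CVaR}_\alpha\big(\mathcal{L}_{\mathbf{c}_1}(\mathcal{S}), \dots, \mathcal{L}_{\mathbf{c}_K}(\mathcal{S})\big)
```

When α = 1 this reduces to the average cluster loss. As α → 0 it approaches the worst single cluster's loss.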

Abstract

Dataset distillation (DD) has emerged as a widely adopted technique for crafting a synthetic dataset that captures the essential information of a training dataset, facilitating the training of accurate neural models. Its applications span various domains, including transfer learning, federated learning, and neural architecture search. The most popular methods for constructing the synthetic data rely on matching the convergence properties of training the model on the synthetic dataset and on the training dataset. However, the empirical loss used as this criterion should be regarded as a proxy, in the same sense that the training set is only an approximate substitute for the population distribution, which is the actual object of interest. Yet despite the popularity of DD, one aspect remains unexplored: its relationship to generalization, particularly across uncommon subgroups. That is, how can we ensure that a model trained on the synthetic dataset performs well on samples from regions with low population density? Here, the representativeness and coverage of the dataset matter more at inference than a guaranteed training error. Drawing inspiration from distributionally robust optimization, we introduce an algorithm that combines clustering with the minimization of a risk measure on the loss to conduct DD. We provide a theoretical rationale for our approach and demonstrate its effective generalization and robustness across subgroups through numerical experiments. The source code is available at https://github.com/Mming11/RobustDatasetDistillation.
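The abstract's core idea is to cluster the data and then minimize a risk measure of the loss rather than its mean. A minimal, stdlib-only sketch of that combination follows. It is not the authors' implementation, and several choices are assumptions:

- Plain Lloyd's k-means is used, initialized with the first k points.
- Each cluster is scored by its mean loss.
- The risk measure is an upper-tail CVaR that averages the worst ceil(α·K) cluster losses, a simple approximation of the exact empirical CVaR.

The names `empirical_cvar`, `kmeans`, and `group_cvar_objective` are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]


def empirical_cvar(losses: Sequence[float], alpha: float) -> float:
    """Upper-tail CVaR approximation: mean of the worst ceil(alpha * n) losses."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must lie in (0, 1]")
    ordered = sorted(losses, reverse=True)  # largest losses first
    m = max(1, math.ceil(alpha * len(ordered)))
    return sum(ordered[:m]) / m


def kmeans(points: Sequence[Point], k: int, iters: int = 20) -> List[int]:
    """Plain Lloyd's k-means, initialized with the first k points (an assumption)."""
    centers = [list(p) for p in points[:k]]
    labels = [0] * len(points)
    for _ in range(iters):
        # Assignment step: nearest center in squared Euclidean distance.
        labels = [
            min(range(k), key=lambda j: sum((a - b) ** 2 for a, b in zip(p, centers[j])))
            for p in points
        ]
        # Update step: move each non-empty center to the mean of its members.
        for j in range(k):
            members = [p for p, lab in zip(points, labels) if lab == j]
            if members:
                centers[j] = [sum(coord) / len(members) for coord in zip(*members)]
    return labels


def group_cvar_objective(points: Sequence[Point], losses: Sequence[float],
                         k: int, alpha: float) -> float:
    """Hypothetical objective: CVaR over the per-cluster mean losses."""
    labels = kmeans(points, k)
    cluster_losses = []
    for j in range(k):
        member = [loss for loss, lab in zip(losses, labels) if lab == j]
        if member:  # skip clusters that ended up empty
            cluster_losses.append(sum(member) / len(member))
    return empirical_cvar(cluster_losses, alpha)
```

Consider the points `[(0.0, 0.0), (0.1, 0.0), (5.0, 5.0), (5.1, 5.0)]` with losses `[0.1, 0.2, 1.0, 1.2]` and `k=2`. With `alpha=1.0` the objective is the plain average of the two cluster means (about 0.625). With `alpha=0.5` it focuses on the worse cluster (about 1.1). A smaller α shifts the optimization toward the hardest subgroup.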

Paper Structure (44 sections, 2 theorems, 26 equations, 9 figures, 14 tables, 2 algorithms)

Key Result

Theorem 3.3

If the Morse-Sard condition holds [souvcek1972morse], so that the optimal sets $\{\mathcal{S}^*(\theta)\}$ and $\{\theta^*(\mathcal{S})\}$ are compact (possibly finite) for all $\theta$ and $\mathcal{S}$, then Algorithm alg:ddopt converges to a fixed point of Equation eq:optprob.
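Theorem 3.3 concerns alternating between the synthetic set $\mathcal{S}$ and the model parameters $\theta$ until neither block changes, that is, a fixed point. The toy sketch below illustrates what such a fixed point is. It is only an illustration under a strong simplifying assumption. The real DD objective is replaced by a hypothetical convex scalar quadratic, f(s, t) = (s - t)^2 + (s - 1)^2 + t^2. Here s stands in for $\mathcal{S}$ and t for $\theta$, and each block has a closed-form best response. The functions `best_s`, `best_t`, and `alternate` are illustrative names, not the paper's Algorithm alg:ddopt.

```python
from typing import Tuple


def best_s(t: float) -> float:
    """argmin_s f(s, t) for the toy f(s, t) = (s - t)^2 + (s - 1)^2 + t^2."""
    return (t + 1.0) / 2.0


def best_t(s: float) -> float:
    """argmin_t f(s, t) for the same toy objective."""
    return s / 2.0


def alternate(s: float = 0.0, t: float = 0.0, tol: float = 1e-12,
              max_iter: int = 1000) -> Tuple[float, float, int]:
    """Alternate exact block minimization until the iterates stop moving."""
    for k in range(1, max_iter + 1):
        s_new = best_s(t)      # update the "synthetic set" block with t fixed
        t_new = best_t(s_new)  # update the "model" block with s fixed
        if abs(s_new - s) < tol and abs(t_new - t) < tol:
            return s_new, t_new, k  # fixed point: neither block wants to move
        s, t = s_new, t_new
    return s, t, max_iter
```

Substituting one best response into the other gives t_{k+1} = (t_k + 1) / 4, a contraction. The iterates therefore converge to the unique fixed point s = 2/3, t = 1/3, which here is also the joint minimizer. In the nonconvex DD setting, the theorem's compactness condition plays the role that this contraction plays in the toy case.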

Figures (9)

  • Figure 1: A visual representation of the robust inference task, in which the population distribution $\mathcal{X}$ is partially partitioned across subgroups $\{\mathbf{c}_i\}$. A classifier is considered robust when it performs well across all subgroups. As a hypothetical online-learning example, at a particular time the classifier may receive a steady stream of samples from $\mathcal{X}\vert \mathbf{c}_3$. The region of sample space defined by this subgroup is geometrically small, so it can be regarded as having low overall prior density. If this subgroup's behavior is particularly anomalous, a model (and any associated distilled dataset) trained only to minimize empirical risk may perform poorly on it.
  • Figure 2: Analysis of the CVaR ratio $\alpha$.
  • Figure 3: t-SNE visualization of the distribution of original samples (blue dots) and synthesized samples (orange dots) for the ImageNet bonnet class.
  • Figure 4: Synthesized sample visualization comparison between GLaD and our proposed method. The samples are initialized identically.
  • Figure 5: Top-1 robustness evaluation on IDM (CIFAR-10) and GLaD (TinyImageNet).
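Figure 1's criterion is that a classifier is robust when it performs well across all subgroups. A common way to quantify this is to report per-subgroup accuracy and its minimum (worst-group accuracy) alongside overall accuracy. The sketch below assumes discrete group labels are available for evaluation. The function names are hypothetical, and the paper's own robustness metrics (e.g., in Figure 5) may be defined differently.

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence


def per_group_accuracy(preds: Sequence[int], labels: Sequence[int],
                       groups: Sequence[Hashable]) -> Dict[Hashable, float]:
    """Accuracy computed separately inside each subgroup."""
    correct: Dict[Hashable, int] = defaultdict(int)
    total: Dict[Hashable, int] = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    return {g: correct[g] / total[g] for g in total}


def worst_group_accuracy(preds: Sequence[int], labels: Sequence[int],
                         groups: Sequence[Hashable]) -> float:
    """The lowest per-subgroup accuracy: small if any subgroup is neglected."""
    return min(per_group_accuracy(preds, labels, groups).values())
```

Consider predictions `[1, 1, 0, 0, 1]`, labels `[1, 1, 0, 1, 0]`, and groups `["c1", "c1", "c2", "c3", "c3"]`. Overall accuracy is 0.6, but worst-group accuracy is 0.0 because every sample from the small subgroup c3 is misclassified. This is exactly the failure mode Figure 1 describes.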

Theorems & Definitions (2)

  • Theorem 3.3
  • Proposition 3.4