Enhancing Generalization via Sharpness-Aware Trajectory Matching for Dataset Condensation

Boyan Gao, Bo Zhao, Shreyank N Gowda, Xingrun Xing, Yibo Yang, Timothy Hospedales, David A. Clifton

TL;DR

The paper tackles dataset condensation with long-horizon bilevel optimization, where learning synthetic data via trajectory matching often generalizes poorly and incurs heavy compute. It introduces Sharpness-Aware Trajectory Matching (SATM), which jointly minimizes outer-loop loss sharpness and distances between training trajectories, and pairs it with efficient hypergradient tricks: Truncated Unrolling Hypergradient (TUH) and Trajectory Reusing (TR), plus a smoothing mechanism that uses Gaussian perturbations. A closed-form learning-rate update further reduces backpropagation costs. Empirical results across CIFAR-10/100, Tiny ImageNet, and ImageNet subsets demonstrate improved in-domain and out-of-domain generalization with favorable time/memory trade-offs, including strong cross-architecture performance and continual learning gains, highlighting SATM’s practical impact for scalable dataset condensation.

Abstract

Dataset condensation aims to synthesize datasets with a few representative samples that can effectively represent the original datasets. This enables efficient training and produces models with performance close to those trained on the original sets. Most existing dataset condensation methods learn the dataset through bilevel (inner- and outer-loop) optimization. However, these methods generalize poorly because of the notoriously complicated loss landscape and the expensive time and space complexity of unrolling the inner loop. These issues worsen when the datasets are learned by matching the trajectories of networks trained on the real and synthetic datasets over a long-horizon inner loop. To address these issues, we introduce Sharpness-Aware Trajectory Matching (SATM), which enhances the generalization capability of learned synthetic datasets by optimizing the sharpness of the loss landscape and the objective simultaneously. Moreover, our approach is coupled with an efficient hypergradient approximation that is mathematically well supported and straightforward to implement, with controllable computational overhead. Empirical evaluations of SATM demonstrate its effectiveness across various applications, including in-domain benchmarks and out-of-domain settings. Its ease of implementation also affords flexibility, allowing it to integrate with other advanced sharpness-aware minimizers. Our code will be released.
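
To make the idea of optimizing sharpness and objective together more concrete, the following is a minimal sketch of a generic sharpness-aware minimization (SAM) step on a toy quadratic. It is not the SATM procedure itself: the names `sam_step`, `toy_loss`, `toy_grad` and the values of `lr` and `rho` are assumptions chosen for illustration, and the parameters here are plain weights rather than a synthetic dataset.

```python
import math

def sam_step(params, grad_fn, lr=0.1, rho=0.05):
    """One generic SAM step on a list of floats (illustrative, not SATM)."""
    g = grad_fn(params)
    norm = math.sqrt(sum(x * x for x in g))
    if norm == 0.0:
        return list(params)  # stationary point: no ascent direction exists
    # Ascent: move to the approximate worst-case point inside a rho-ball.
    perturbed = [p + rho * x / norm for p, x in zip(params, g)]
    # Descent: update the ORIGINAL parameters with the perturbed-point gradient.
    g_sharp = grad_fn(perturbed)
    return [p - lr * x for p, x in zip(params, g_sharp)]

# Hypothetical toy loss f(w) = 0.5 * (10 * w0^2 + w1^2).
def toy_loss(w):
    return 0.5 * (10.0 * w[0] ** 2 + w[1] ** 2)

def toy_grad(w):
    return [10.0 * w[0], w[1]]

w = [1.0, 1.0]
for _ in range(50):
    w = sam_step(w, toy_grad, lr=0.05, rho=0.05)
final_loss = toy_loss(w)
```

Each step costs two gradient evaluations, one at the current point and one at the perturbed point. In a bilevel setting each of those would be a hypergradient, which is why the abstract pairs the sharpness objective with a cheaper hypergradient approximation.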

Paper Structure

This paper contains 25 sections, 4 theorems, 25 equations, 7 figures, 11 tables, 2 algorithms.

Key Result

Proposition 4.0

Assume $\mathcal{L}_{CE}$ is $K$-smooth, twice differentiable, and locally $J$-strongly convex in $\theta$ around $\{\theta_{\iota+1}, \ldots, \theta_N\}$. Let $\Xi(\theta, \phi) = \theta - \alpha \nabla\mathcal{L}_{CE}(\theta, \phi)$. For $\alpha \leq \frac{1}{K}$, then where $\frac{\partial F(\phi)}{\partial \phi}$ denotes the untruncated hypergradient.
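
The setting of this proposition can be illustrated with a small sketch of truncated unrolling on a scalar toy problem. Everything here is an assumption made for illustration: the inner loss is taken to be $0.5\,a\,(\theta - \phi)^2$ (so $K = J = a$), the function names are hypothetical, and `keep_last` plays the role of $N - \iota$, the number of final inner steps that are differentiated. The sketch does not reproduce the paper's bound; it only shows the truncation error shrinking as more steps are retained in this toy case.

```python
def xi(theta, phi, a, alpha):
    """Inner update Xi(theta, phi) = theta - alpha * dL/dtheta for the toy loss."""
    return theta - alpha * a * (theta - phi)

def truncated_hypergradient(phi, theta0, a, alpha, n_steps, keep_last):
    """Return (theta_N, d theta_N / d phi) using only the last `keep_last` steps."""
    # Forward pass: unroll the inner loop for n_steps updates.
    theta = theta0
    for _ in range(n_steps):
        theta = xi(theta, phi, a, alpha)
    # Partial derivatives of Xi; constant because the toy loss is quadratic.
    dxi_dtheta = 1.0 - alpha * a
    dxi_dphi = alpha * a
    # Backward pass: accumulate from step N backwards, stopping early.
    grad, carry = 0.0, 1.0
    for _ in range(min(keep_last, n_steps)):
        grad += carry * dxi_dphi
        carry *= dxi_dtheta
    return theta, grad

a, alpha, n_steps = 2.0, 0.25, 20  # alpha <= 1/K with K = a
_, full = truncated_hypergradient(0.3, 1.0, a, alpha, n_steps, keep_last=n_steps)
errors = []
for k in (1, 5, 10):
    _, approx = truncated_hypergradient(0.3, 1.0, a, alpha, n_steps, keep_last=k)
    errors.append(abs(full - approx))
```

Because each inner step contracts by the factor $1 - \alpha a$ in this toy problem, the contribution of early steps to the hypergradient decays geometrically, so dropping them changes the result only slightly while saving the memory needed to backpropagate through them.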

Figures (7)

  • Figure 1: Test accuracy (%) comparison on continual learning. Left: 5-step class-incremental learning on CIFAR-10 with 50 IPC. Middle: 10-step class-incremental learning on Tiny ImageNet with 3 IPC. Right: 20-step class-incremental learning on Tiny ImageNet with 3 IPC.
  • Figure 2: Sharpness analysis by visualisation: hypergradient norm comparison between MTT and SATM. Top: hypergradient norm on CIFAR-100 with 10 IPC. Middle: hypergradient norm on Tiny ImageNet with 3 IPC. Bottom: sharpness dynamics on Tiny ImageNet with 3 IPC.
  • Figure 3: Comparison of the learning-rate learning dynamics under first- and second-order differentiation when condensing CIFAR-100 with 10 IPC.
  • Figure 4: CIFAR-10 with 1 IPC
  • Figure 5: CIFAR-10 with 3 IPC
  • ...and 2 more figures

Theorems & Definitions (6)

  • Proposition 4.0
  • Theorem 4.1
  • Theorem 1.1
  • proof
  • Proposition 1.0
  • proof