Enhancing Generalization via Sharpness-Aware Trajectory Matching for Dataset Condensation
Boyan Gao, Bo Zhao, Shreyank N Gowda, Xingrun Xing, Yibo Yang, Timothy Hospedales, David A. Clifton
TL;DR
The paper tackles dataset condensation posed as long-horizon bilevel optimization, where learning synthetic data via trajectory matching often generalizes poorly and is computationally expensive. It introduces Sharpness-Aware Trajectory Matching (SATM), which jointly minimizes the sharpness of the outer-loop loss and the distance between training trajectories. SATM is paired with efficient hypergradient approximations, Truncated Unrolling Hypergradient (TUH) and Trajectory Reusing (TR), plus a smoothing mechanism based on Gaussian perturbations. A closed-form learning-rate update further reduces backpropagation costs. Experiments on CIFAR-10/100, Tiny ImageNet, and ImageNet subsets show improved in-domain and out-of-domain generalization with favorable time/memory trade-offs, including strong cross-architecture performance and gains in continual learning, which supports SATM as a practical route to scalable dataset condensation.
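The combination of trajectory matching, a sharpness-aware outer step, and truncated unrolling can be illustrated on a toy problem. This is a minimal sketch, not the authors' implementation. It assumes a scalar linear model, an MTT-style normalized matching loss, the standard SAM perturbation, and a forward-mode hypergradient that only differentiates through the last `k_unroll` inner steps, which is our reading of truncated unrolling. The function names (`inner_unroll`, `matching_loss_and_grad`, `sam_step`) and all numbers are hypothetical. Trajectory Reusing, the Gaussian smoothing, and the closed-form learning-rate update are omitted.

```python
import math


def inner_unroll(w, xs, ys, lr, steps):
    """Run `steps` gradient-descent steps of the scalar model w*x on the
    synthetic set (xs, ys) with inner loss mean_i 0.5*(w*x_i - y_i)**2.

    Alongside w, propagate dw/dx_j by forward-mode differentiation starting
    from zero: the incoming w is treated as a constant, so the returned
    derivative only accounts for the steps unrolled in this call."""
    n = len(xs)
    dw = [0.0] * n
    mean_x2 = sum(x * x for x in xs) / n
    for _ in range(steps):
        # Inner gradient: grad_w = mean_i (w*x_i - y_i) * x_i.
        grad_w = sum((w * x - y) * x for x, y in zip(xs, ys)) / n
        # Total derivative of grad_w w.r.t. x_j (w itself depends on x_j).
        d_grad = [dw[j] * mean_x2 + (2 * w * xs[j] - ys[j]) / n for j in range(n)]
        dw = [dw[j] - lr * d_grad[j] for j in range(n)]
        w = w - lr * grad_w
    return w, dw


def matching_loss_and_grad(xs, ys, w_start, w_target, lr, n_steps, k_unroll):
    """Normalized trajectory-matching loss and its truncated hypergradient.

    The derivative from the first n_steps - k_unroll inner steps is discarded
    (the intermediate weight is treated as a constant); only the last
    k_unroll steps contribute to the returned gradient."""
    w_mid, _ = inner_unroll(w_start, xs, ys, lr, n_steps - k_unroll)
    w_end, dw = inner_unroll(w_mid, xs, ys, lr, k_unroll)
    denom = (w_start - w_target) ** 2  # expert movement over the segment
    loss = (w_end - w_target) ** 2 / denom
    grad = [2 * (w_end - w_target) / denom * d for d in dw]
    return loss, grad


def sam_step(xs, ys, w_start, w_target, lr, n_steps, k_unroll, rho, eta):
    """One sharpness-aware (SAM-style) outer update of the synthetic inputs."""
    _, g = matching_loss_and_grad(xs, ys, w_start, w_target, lr, n_steps, k_unroll)
    norm = math.sqrt(sum(v * v for v in g)) + 1e-12
    # Move to the approximate worst-case point in an L2 ball of radius rho.
    xs_adv = [x + rho * v / norm for x, v in zip(xs, g)]
    _, g_adv = matching_loss_and_grad(xs_adv, ys, w_start, w_target, lr, n_steps, k_unroll)
    # Descend from the original point using the gradient at the perturbed point.
    return [x - eta * v for x, v in zip(xs, g_adv)]


# Toy usage: learn synthetic inputs (labels fixed) so that 20 inner steps from
# a hypothetical expert start w=0.0 land on a hypothetical expert target w=1.5.
xs, ys = [0.5, -1.0, 2.0], [1.0, -2.0, 4.0]
cfg = dict(w_start=0.0, w_target=1.5, lr=0.1, n_steps=20, k_unroll=5)
initial_loss, _ = matching_loss_and_grad(xs, ys, **cfg)
for _ in range(100):
    xs = sam_step(xs, ys, rho=0.05, eta=0.5, **cfg)
final_loss, _ = matching_loss_and_grad(xs, ys, **cfg)
```

Setting `k_unroll = n_steps` recovers the exact unrolled hypergradient, which makes the sketch easy to check against finite differences. In reverse-mode autodiff, truncating to the last K steps means only those K steps of intermediate state must be stored, which is where the memory saving of truncated unrolling comes from.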
Abstract
Dataset condensation aims to synthesize a small set of representative samples that can stand in for the original dataset. This enables efficient training and produces models whose performance is close to that of models trained on the original set. Most existing dataset condensation methods learn the synthetic dataset through bilevel (inner- and outer-loop) optimization. However, these methods achieve limited dataset generalization because of the notoriously complicated loss landscape and the high time and memory cost of unrolling the inner loop of the bilevel optimization. These issues worsen when the datasets are learned by matching the trajectories of networks trained on the real and synthetic datasets with a long-horizon inner loop. To address them, we introduce Sharpness-Aware Trajectory Matching (SATM), which enhances the generalization of the learned synthetic datasets by minimizing the sharpness of the loss landscape and the matching objective simultaneously. Moreover, our approach is coupled with an efficient hypergradient approximation that is mathematically well-supported, straightforward to implement, and has controllable computational overhead. Empirical evaluations demonstrate the effectiveness of SATM across various applications, including in-domain benchmarks and out-of-domain settings. Its ease of implementation also makes it flexible enough to integrate with other advanced sharpness-aware minimizers. Our code will be released.
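For reference, here is one way to write a sharpness-aware trajectory-matching objective with a truncated hypergradient. This assumes SATM follows the standard SAM formulation (Foret et al.) applied to the outer loop, together with the normalized trajectory-matching loss of MTT (Cazenavette et al.). The symbols (synthetic set $\mathcal{S}$, expert parameters $\theta^*$, $N$ inner steps with step size $\alpha$, expert segment length $M$, truncation length $K$, radius $\rho$, outer step size $\eta$) are our notation and not necessarily the paper's.

```latex
\hat\theta_{0} = \theta^*_{t}, \qquad
\hat\theta_{k+1} = \hat\theta_{k} - \alpha \nabla_{\theta}\ell(\hat\theta_{k};\mathcal{S}), \quad k = 0,\dots,N-1

\mathcal{L}_{\mathrm{TM}}(\mathcal{S}) =
  \frac{\lVert \hat\theta_{N} - \theta^*_{t+M} \rVert_2^2}
       {\lVert \theta^*_{t} - \theta^*_{t+M} \rVert_2^2}

\min_{\mathcal{S}} \; \max_{\lVert \epsilon \rVert_2 \le \rho} \; \mathcal{L}_{\mathrm{TM}}(\mathcal{S} + \epsilon),
\qquad
\hat\epsilon = \rho \, \frac{\nabla_{\mathcal{S}} \mathcal{L}_{\mathrm{TM}}(\mathcal{S})}
                           {\lVert \nabla_{\mathcal{S}} \mathcal{L}_{\mathrm{TM}}(\mathcal{S}) \rVert_2}

\mathcal{S} \leftarrow \mathcal{S} - \eta \, \nabla_{\mathcal{S}} \mathcal{L}_{\mathrm{TM}}(\mathcal{S})\big|_{\mathcal{S} + \hat\epsilon}

\frac{d\hat\theta_{N}}{d\mathcal{S}} \approx
  \sum_{k=N-K}^{N-1} J_{N-1} J_{N-2} \cdots J_{k+1} \, B_{k},
\qquad
J_{j} = \frac{\partial \hat\theta_{j+1}}{\partial \hat\theta_{j}}, \quad
B_{k} = \frac{\partial \hat\theta_{k+1}}{\partial \mathcal{S}}\Big|_{\hat\theta_{k}\ \text{fixed}}
```

For $k = N-1$ the Jacobian product is empty and equals the identity. Taking $K = N$ recovers the exact unrolled hypergradient, while smaller $K$ drops the contributions of the earliest inner steps.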
