
Towards Mitigating Architecture Overfitting on Distilled Datasets

Xuyang Zhong, Chen Liu

TL;DR

This work tackles architecture overfitting in dataset distillation, where distilled data crafted for a specific training architecture fails to generalize to other architectures. It introduces smoothing-driven methods that treat the larger test model as an implicit ensemble of sub-networks via DropPath and regularize sub-networks with knowledge distillation from a smaller teacher, complemented by three-phase keep-rate scheduling, improved shortcuts, and stronger augmentation. Across multiple dataset distillation methods and datasets, the approach significantly reduces cross-architecture gaps and even yields superior performance when test networks are larger than the training network. The findings highlight improved transferability of distilled datasets and suggest practical benefits for training with limited real data as well. Overall, the paper advances cross-architecture robustness in dataset distillation through plug-and-play, smoothing-based techniques with broad applicability.

Abstract

Dataset distillation methods have demonstrated remarkable performance for neural networks trained with very limited training data. However, a significant challenge arises in the form of architecture overfitting: a distilled training dataset synthesized with a specific network architecture (i.e., the training network) yields poor performance when used to train other network architectures (i.e., test networks), especially when the test networks have a larger capacity than the training network. This paper introduces a series of approaches to mitigate this issue. Among them, DropPath turns the large model into an implicit ensemble of its sub-networks, and knowledge distillation ensures that each sub-network behaves similarly to the small but well-performing teacher network. These methods, characterized by their smoothing effects, significantly mitigate architecture overfitting. We conduct extensive experiments to demonstrate the effectiveness and generality of our methods. In particular, across various scenarios involving different tasks and different sizes of distilled data, our approaches significantly mitigate architecture overfitting. Furthermore, our approaches achieve comparable or even superior performance when the test network is larger than the training network.
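
The knowledge-distillation idea in the abstract, where a larger student is pulled towards a small teacher's predictions, can be illustrated with a minimal sketch. It assumes the standard temperature-scaled formulation (cross-entropy on the label plus a weighted KL term towards the teacher's softened outputs); the exact loss used in the paper is not given in this summary, and the names `kd_loss`, `weight`, and `temperature` are illustrative choices rather than the authors' API.

```python
import math

def softmax(logits, temperature=1.0):
    # Numerically stable softmax over temperature-scaled logits.
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, label):
    # Standard cross-entropy against a hard class label.
    return -math.log(softmax(logits)[label])

def kl_divergence(p, q):
    # KL(p || q) for two discrete distributions with matching support.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def kd_loss(student_logits, teacher_logits, label, weight=0.5, temperature=2.0):
    # Assumed form: (1 - weight) * CE + weight * T^2 * KL(teacher || student).
    # The T^2 factor is the usual gradient-scale correction; this is an
    # assumption, not a formula taken from the paper.
    ce = cross_entropy(student_logits, label)
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = kl_divergence(p_teacher, p_student)
    return (1 - weight) * ce + weight * temperature ** 2 * kl
```

With `weight=0` the sketch reduces to plain cross-entropy, and with `weight=1` it vanishes when the student reproduces the teacher's logits exactly, which is the sense in which the teacher regularizes each sub-network.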

Paper Structure (17 sections, 5 equations, 6 figures, 8 tables, 1 algorithm)


Figures (6)

  • Figure 1: Effectiveness of our method on different architectures, different dataset distillation methods, and different images per class (IPCs) on CIFAR10. We use a 3-layer CNN as the training network, so it performs the best among various architectures in baselines (dashed lines). Our methods (solid lines) can significantly narrow down the performance gap between the 3-layer CNN and other architectures.
  • Figure 2: (a) The DropPath used for multi-branch residual blocks during training; it does not block the shortcut path. (b) The DropPath used for single-branch networks during training. Here, $m=\mathtt{Bernoulli}(p) \in \{0,1\}$, and $p\in[0,1]$ denotes the keep rate. The virtual shortcut is activated only when the main path is pruned ($m = 0$), and vice versa. DropPath is always deactivated, i.e., $p=1$, during inference. (c) The original architecture of a shortcut connection that downsamples feature maps, which consists of a $1\times1$ convolution layer with a stride of $2$ and a normalization layer. (d) The improved architecture of a shortcut connection that downsamples feature maps, which is a sequence of a $2 \times 2$ max pooling layer, a $1 \times 1$ convolution layer with a stride of $1$, and a normalization layer.
  • Figure 3: Test accuracies obtained from training on different fractions of CIFAR10; the shaded region indicates the standard deviation. We compare the test accuracies (a) between ResNet18 (RN18) and 3-layer CNN (CNN), (b) between ResNet50 (RN50) and CNN, (c) between VGG11 and CNN, and (d) between ResNet50 (RN50) and ResNet18, respectively. The x-axis denotes the fraction of training data, and DP+KD denotes that the network is trained with DropPath and knowledge distillation. The model enclosed in the brackets after KD represents the teacher model used. Note that we run the experiments three times with different random seeds.
  • Figure 4: Visualization of the smoothing effect induced by proposed methods. (a) Top 20 eigenvalues of Hessian matrix for ResNet18 trained with different settings, including full setting with Lion optimizer (Full w/ L), full setting with AdamW (Full w/ A), w/o DP, w/o KD and w/o DP&KD. For w/o DP, w/o KD and w/o DP&KD, Lion is adopted by default. (b)-(f) Loss landscape $\mathcal{L}_{CE}(\theta + \alpha_1\mathbf{v}_1 + \alpha_2\mathbf{v}_2)$ of ResNet18 around the minima found by models with different settings, where $\mathbf{v}_1$ and $\mathbf{v}_2$ are the eigenvectors corresponding to the top two eigenvalues of Hessian matrices, respectively. Note that the training data is 100 (IPC=10) distilled images of CIFAR10 by FRePo. ResNet18 is trained with DropPath and knowledge distillation. 3-layer CNN serves as the teacher model where knowledge distillation is adopted.
  • Figure 5: Ablation studies on minimum keep rate, final keep rate, period of decay, weight and temperature of knowledge distillation (KD). (a) Test accuracies of different minimum keep rates. (b) Test accuracies of different keep rates at the final phase. (c) Test accuracies of different periods of decay. (d) Test accuracies of different KD weights. (e) Test accuracies of different KD temperatures. Regardless of the variation of hyperparameters, ResNet18 trained with our approach generally outperforms 3-layer CNN trained with baseline (orange dashed line) and that trained with better optimization and data augmentation (green dashed line).
  • ...and 1 more figures
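
The two DropPath variants described in the Figure 2 caption can be sketched in a few lines. This is a minimal illustration on plain Python lists, not the authors' implementation: the function names are hypothetical, the mask is sampled once per call, and no rescaling by $1/p$ is applied because the caption does not mention one (many stochastic-depth implementations do rescale). The single-branch variant also assumes the layer preserves the input shape, so that the virtual shortcut can pass the input through unchanged.

```python
import random

def sample_mask(keep_rate, training, rng):
    # m = Bernoulli(p) during training; DropPath is off (p = 1) at inference.
    if not training:
        return 1
    return 1 if rng.random() < keep_rate else 0

def drop_path_residual(x, branch, keep_rate, training=True, rng=random):
    # Variant (a): multi-branch residual block. The shortcut is never
    # blocked; only the residual branch is dropped when m = 0.
    m = sample_mask(keep_rate, training, rng)
    fx = branch(x)
    return [xi + m * fi for xi, fi in zip(x, fx)]

def drop_path_single(x, layer, keep_rate, training=True, rng=random):
    # Variant (b): single-branch network. A virtual shortcut (identity)
    # is used only when the main path is pruned (m = 0).
    m = sample_mask(keep_rate, training, rng)
    return layer(x) if m == 1 else list(x)
```

Because each training step samples a different mask per block, the large network is effectively trained as a collection of shallower sub-networks, which is the implicit-ensemble view the paper relies on.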