Dataset Distillation via Knowledge Distillation: Towards Efficient Self-Supervised Pre-Training of Deep Networks

Siddharth Joshi, Jiayi Ni, Baharan Mirzasoleiman

TL;DR

The paper addresses dataset distillation (DD) for SSL pre-training and shows that naively applying gradient- or trajectory-matching methods from supervised DD fails because SSL gradients have high variance. It introduces MKDT, a KD-based trajectory-matching approach: a teacher is trained with SSL, multiple student encoders are trained to match the teacher's representations, which yields low-variance expert trajectories, and synthetic data is then distilled by aligning training trajectories on the synthetic set with these KD trajectories. Empirically, MKDT improves downstream linear-probe accuracy by up to 13% over strong baselines on CIFAR10/100 and TinyImageNet, and it generalizes to larger architectures and multiple SSL frameworks (BarlowTwins, SimCLR). The work demonstrates that KD can enable data-efficient SSL pre-training, reducing data and compute requirements while preserving representation quality. Code is available for reproduction.

Abstract

Dataset distillation (DD) generates small synthetic datasets that can efficiently train deep networks with a limited amount of memory and compute. Despite the success of DD methods for supervised learning, DD for self-supervised pre-training of deep models has remained unaddressed. Pre-training on unlabeled data is crucial for efficiently generalizing to downstream tasks with limited labeled data. In this work, we propose the first effective DD method for SSL pre-training. First, we show, theoretically and empirically, that naive application of supervised DD methods to SSL fails, due to the high variance of the SSL gradient. Then, we address this issue by relying on insights from knowledge distillation (KD) literature. Specifically, we train a small student model to match the representations of a larger teacher model trained with SSL. Then, we generate a small synthetic dataset by matching the training trajectories of the student models. As the KD objective has considerably lower variance than SSL, our approach can generate synthetic datasets that can successfully pre-train high-quality encoders. Through extensive experiments, we show that our distilled sets lead to up to 13% higher accuracy than prior work, on a variety of downstream tasks, in the presence of limited labeled data. Code at https://github.com/BigML-CS-UCLA/MKDT.
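
The two ingredients described in the abstract, a KD objective that regresses student representations onto fixed teacher representations and a trajectory-matching loss on student parameters, can be illustrated with a small NumPy sketch. Everything here is an assumption made for illustration and is not taken from the paper's code: the student is a single linear map, the "teacher" is a random linear map standing in for an SSL-trained encoder, and the normalized parameter distance follows the form commonly used in trajectory-matching methods rather than a confirmed MKDT objective. All function names are hypothetical.

```python
import numpy as np

def kd_loss(student_w, x, teacher_repr):
    """Mean squared error between student outputs (x @ student_w) and fixed teacher targets."""
    diff = x @ student_w - teacher_repr
    return float(np.mean(np.sum(diff ** 2, axis=1)))

def kd_grad(student_w, x, teacher_repr):
    """Gradient of kd_loss with respect to student_w."""
    diff = x @ student_w - teacher_repr
    return 2.0 * x.T @ diff / x.shape[0]

def train_student(x, teacher_repr, steps, lr, seed=0):
    """Full-batch gradient descent on the KD loss; returns every parameter checkpoint."""
    rng = np.random.default_rng(seed)
    w = 0.01 * rng.standard_normal((x.shape[1], teacher_repr.shape[1]))
    trajectory = [w.copy()]
    for _ in range(steps):
        w = w - lr * kd_grad(w, x, teacher_repr)
        trajectory.append(w.copy())
    return trajectory

def trajectory_matching_loss(w_synth, w_start, w_target):
    """Squared distance to the expert target, normalized by the distance the expert travelled."""
    num = np.sum((w_synth - w_target) ** 2)
    den = np.sum((w_start - w_target) ** 2)
    return float(num / den)

# Toy data: 256 "real" inputs with 8 features, teacher representations with 4 dimensions.
rng = np.random.default_rng(0)
x_real = rng.standard_normal((256, 8))
teacher_w = rng.standard_normal((8, 4))   # assumption: stand-in for an SSL-trained teacher
teacher_repr = x_real @ teacher_w

# Expert trajectory: a student trained with KD on the real data.
expert = train_student(x_real, teacher_repr, steps=50, lr=0.05)

# A student that has not moved from the expert's start has loss 1.0;
# a student that lands exactly on the expert's later checkpoint has loss 0.0.
loss_no_progress = trajectory_matching_loss(expert[0], expert[0], expert[10])
loss_perfect = trajectory_matching_loss(expert[10], expert[0], expert[10])
```

In a full distillation loop, `w_synth` would come from training the student on the synthetic set for a few steps, and the synthetic inputs (and their target representations) would be updated to reduce this loss. That outer optimization is omitted here.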

Paper Structure (20 sections, 2 theorems, 49 equations, 4 figures, 12 tables, 1 algorithm)

This paper contains 20 sections, 2 theorems, 49 equations, 4 figures, 12 tables, 1 algorithm.

Key Result

Theorem 4.1

Let $D = \{(x_i, y_i)\}_{i=1}^n$ be a dataset with $n$ examples, where $x_i$ is the $i$-th input and $y_i \in \{0, 1\}$ is its corresponding class label. Assume the data $x_i$ are generated using the sparse coding model [xue2023features, pmlr-v238-joshi24a]: for class 0, $x_i = e_0 + \epsilon_i$, and for [...] where $e_{y_i}$ is the one-hot encoded vector for class $y_i$, and $B$ is a mini-batch. The SSL loss [...]

Figures (4)

  • Figure 1: Challenges of MTT for SSL (Dataset: CIFAR100 (1%); Arch: 3-layer ConvNet)
  • Figure 2: Examples of Distilled Images for CIFAR10
  • Figure 3: Examples of Distilled Images for CIFAR100
  • Figure 4: Examples of Distilled Images for TinyImageNet

Theorems & Definitions (4)

  • Theorem 4.1
  • proof
  • Proposition D.1
  • proof