Dataset Distillation via Knowledge Distillation: Towards Efficient Self-Supervised Pre-Training of Deep Networks
Siddharth Joshi, Jiayi Ni, Baharan Mirzasoleiman
TL;DR
The paper applies dataset distillation (DD) to self-supervised learning (SSL) pre-training. It shows that naive gradient-matching and trajectory-matching methods from supervised DD fail because the SSL gradient has high variance. It then introduces MKDT (Matching Knowledge Distillation Trajectories), a KD-based trajectory-matching approach. A teacher encoder is trained with SSL, and multiple student encoders are trained to match the teacher's representations, which yields low-variance expert trajectories. The synthetic data is then distilled by matching training trajectories on it to these KD trajectories. Empirically, MKDT improves downstream linear-probe accuracy by up to 13% over strong baselines on CIFAR10/100 and TinyImageNet. It also generalizes to larger architectures and to other SSL frameworks (Barlow Twins, SimCLR). The work shows that KD can make SSL pre-training data-efficient, reducing data and compute requirements while preserving representation quality. Code is available for reproduction.
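The trajectory-matching step summarized above can be illustrated with a small NumPy sketch. It is not the authors' implementation, and it rests on several assumptions:

- the student is linear;
- the distillation loss is mean squared error;
- training uses full-batch gradient descent;
- the matching loss is the normalized parameter distance used by MTT-style methods.

The random "teacher", the data, and all function names are hypothetical stand-ins. The synthetic targets are simply the teacher representations of the real examples the synthetic set was initialized from. A real implementation would differentiate this loss with respect to the synthetic examples through the unrolled steps using an autodiff framework; this sketch only evaluates it.

```python
import numpy as np

def kd_grad(W, X, T):
    # Gradient of the assumed KD loss ||X @ W - T||^2 / n for a linear student W,
    # where T holds the frozen teacher representations of the inputs X.
    return 2.0 * X.T @ (X @ W - T) / X.shape[0]

def train(W0, X, T, lr, n_steps):
    # Full-batch gradient descent on the KD loss; returns every checkpoint.
    traj = [W0.copy()]
    for _ in range(n_steps):
        traj.append(traj[-1] - lr * kd_grad(traj[-1], X, T))
    return traj

def trajectory_matching_loss(W_start, W_target, X_syn, T_syn, lr, n_steps):
    # Start from expert checkpoint W_start, take n_steps on the synthetic set,
    # and measure the squared distance to the later expert checkpoint W_target,
    # normalized by how far the expert itself moved between the two.
    W_hat = train(W_start, X_syn, T_syn, lr, n_steps)[-1]
    return np.sum((W_hat - W_target) ** 2) / np.sum((W_start - W_target) ** 2)

rng = np.random.default_rng(0)
X_real = rng.normal(size=(512, 32))          # stand-in for real unlabeled data
teacher = rng.normal(size=(32, 16))          # hypothetical frozen teacher weights
T_real = np.tanh(X_real @ teacher)           # teacher representations

# Expert trajectory: a student distilled from the teacher on the real data.
expert = train(0.1 * rng.normal(size=(32, 16)), X_real, T_real, lr=0.05, n_steps=20)
start, M = 5, 10                             # match checkpoint 5 -> checkpoint 15

# Tiny synthetic set, initialized from real examples and their teacher targets.
X_syn, T_syn = X_real[:10].copy(), T_real[:10].copy()
loss = trajectory_matching_loss(expert[start], expert[start + M], X_syn, T_syn,
                                lr=0.05, n_steps=M)
print(f"trajectory-matching loss: {loss:.3f}")
```

A value of 1.0 corresponds to not moving away from the starting checkpoint. A value of 0.0 means the synthetic steps land exactly on the expert's later checkpoint, which happens here if the full real set is used in place of the synthetic one.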
Abstract
Dataset distillation (DD) generates small synthetic datasets that can efficiently train deep networks with a limited amount of memory and compute. Despite the success of DD methods for supervised learning, DD for self-supervised pre-training of deep models has remained unaddressed. Pre-training on unlabeled data is crucial for efficiently generalizing to downstream tasks with limited labeled data. In this work, we propose the first effective DD method for self-supervised learning (SSL) pre-training. First, we show, theoretically and empirically, that naive application of supervised DD methods to SSL fails due to the high variance of the SSL gradient. Then, we address this issue by relying on insights from the knowledge distillation (KD) literature. Specifically, we train small student models to match the representations of a larger teacher model trained with SSL. Then, we generate a small synthetic dataset by matching the training trajectories of the student models. As the KD objective has considerably lower variance than SSL, our approach can generate synthetic datasets that can successfully pre-train high-quality encoders. Through extensive experiments, we show that our distilled sets lead to up to 13% higher accuracy than prior work on a variety of downstream tasks when labeled data is limited. Code at https://github.com/BigML-CS-UCLA/MKDT.
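The first stage of the pipeline described above can be sketched in NumPy. In this stage, small students are trained to match a teacher's representations and their training trajectories are recorded. This is an illustrative sketch under stated assumptions, not the authors' code:

- the teacher is a random two-layer map (a hypothetical stand-in for an SSL-trained encoder);
- the students are linear;
- the KD objective is taken to be the mean squared error between student and teacher representations.

Each student starts from a different random initialization and saves a checkpoint after every epoch.

```python
import numpy as np

def kd_loss(W, X, T):
    # Assumed KD objective: mean squared distance between the student's
    # representations X @ W and the frozen teacher representations T.
    return np.mean(np.sum((X @ W - T) ** 2, axis=1))

def kd_grad(W, X, T):
    # Gradient of kd_loss with respect to the linear student weights W.
    return 2.0 * X.T @ (X @ W - T) / X.shape[0]

def expert_trajectories(X, T, d_out, n_students=3, epochs=5, batch=64, lr=0.05, seed=0):
    # Train several students from different random initializations with
    # minibatch SGD on the KD loss, saving the weights after every epoch.
    rng = np.random.default_rng(seed)
    trajectories = []
    for _ in range(n_students):
        W = 0.1 * rng.normal(size=(X.shape[1], d_out))
        traj = [W.copy()]
        for _ in range(epochs):
            order = rng.permutation(len(X))
            for i in range(0, len(X), batch):
                idx = order[i:i + batch]
                W = W - lr * kd_grad(W, X[idx], T[idx])
            traj.append(W.copy())
        trajectories.append(traj)
    return trajectories

rng = np.random.default_rng(1)
X = rng.normal(size=(1024, 32))                     # stand-in for real unlabeled data
A, B = rng.normal(size=(32, 64)), rng.normal(size=(64, 16))
T = np.maximum(X @ A, 0.0) @ B / 8.0                # hypothetical two-layer "teacher"

trajs = expert_trajectories(X, T, d_out=16)
for k, traj in enumerate(trajs):
    print(f"student {k}: KD loss {kd_loss(traj[0], X, T):.1f} -> {kd_loss(traj[-1], X, T):.1f}")
```

Because the teacher targets `T` are fixed per example, each minibatch gradient here is an average of independent per-example terms. It does not depend on random augmentations or on the other examples in the batch, unlike contrastive or redundancy-reduction SSL losses such as SimCLR or Barlow Twins.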
