Algorithmic Guarantees for Distilling Supervised and Offline RL Datasets
Aaryan Gupta, Rishi Saket, Aravindan Raghuveer
TL;DR
This work develops a provable, training-free dataset distillation framework for both supervised regression and offline RL. For linear regression in $\mathbb{R}^d$, it shows that $\tilde{O}(d^2)$ randomly sampled regressors suffice to ensure the distilled synthetic data yields MSE losses nearly identical to those on the full training data, with a matching $\Omega(d^2)$ lower bound. It extends the approach to offline RL by distilling via the Bellman loss using random linear $Q$-value predictors; under decomposable state-action embeddings, the sample complexity improves to $\tilde{O}(d^2)$ and the optimization can be convex. Experiments validate the theory, demonstrating improved performance and data efficiency over baselines in both supervised and offline RL settings, including non-linear learners in practice.
Abstract
Given a training dataset, the goal of dataset distillation is to derive a synthetic dataset such that models trained on the latter perform as well as those trained on the training dataset. In this work, we develop and analyze an efficient dataset distillation algorithm for supervised learning, specifically regression in $\mathbb{R}^d$, based on matching the losses on the training and synthetic datasets with respect to a fixed set of randomly sampled regressors, without any model training. Our first key contribution is a novel performance guarantee proving that our algorithm needs only $\tilde{O}(d^2)$ sampled regressors to derive a synthetic dataset on which the MSE loss of any bounded linear model is nearly the same as its MSE loss on the given training data. In particular, the model optimized on the synthetic data has close to minimum loss on the training data, and thus performs nearly as well as the model optimized on the latter. Complementing this, we also prove a matching lower bound of $\Omega(d^2)$ on the number of sampled regressors, showing the tightness of our analysis. Our second contribution is to extend our algorithm to offline RL dataset distillation by matching the Bellman loss, unlike previous works, which used a behavioral cloning objective. This is the first such method that leverages both the rewards and the next-state information available in offline RL datasets, without any policy model optimization. Our algorithm generates a synthetic dataset whose Bellman loss with respect to any linear action-value predictor is close to that predictor's Bellman loss on the offline RL training dataset. Therefore, a policy associated with an action-value predictor optimized on the synthetic dataset performs nearly as well as one derived from the predictor optimized on the training data. We conduct experiments to validate our theoretical guarantees and observe performance gains over baselines in both the supervised and offline RL settings.
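
To make the loss-matching idea concrete for the regression case, the following is a minimal sketch and not the authors' algorithm: the abstract does not specify how the synthetic data is optimized, so this version swaps in a closed-form route. It relies on the fact that the MSE of a linear model $w$ on data $(X, y)$ equals $u^\top M u$ with $u = [w; -1]$ and $M$ the $(d+1) \times (d+1)$ second-moment matrix of $[X \mid y]$, so each sampled regressor yields one linear equation in the $(d+1)(d+2)/2$ free entries of $M$. The function names (`mse`, `distill`), the Gaussian sampling of regressors, and the choice of $d+1$ synthetic points are all assumptions made for illustration.

```python
import numpy as np

def mse(Z, w):
    """MSE of the linear model w on data Z = [X | y] (last column is the label)."""
    X, y = Z[:, :-1], Z[:, -1]
    return float(np.mean((X @ w - y) ** 2))

def distill(Z, num_regressors, rng):
    """Build a synthetic dataset whose MSE matches Z's on random regressors.

    Hypothetical closed-form sketch: recover the second-moment matrix M from
    the training losses of the sampled regressors, then factor M into d + 1
    synthetic rows. No model is trained at any point.
    """
    p = Z.shape[1]                      # p = d + 1
    d = p - 1
    # Sample random regressors (Gaussian sampling is an assumption here).
    W = rng.standard_normal((num_regressors, d))
    # With u = [w; -1], the residual vector is Z @ u and the MSE is u^T M u.
    U = np.hstack([W, -np.ones((num_regressors, 1))])
    train_losses = np.mean((Z @ U.T) ** 2, axis=0)

    # u^T M u is linear in the upper-triangular entries of the symmetric M;
    # off-diagonal entries appear twice, hence the factor 2.
    iu = np.triu_indices(p)
    scale = np.where(iu[0] == iu[1], 1.0, 2.0)
    feats = U[:, iu[0]] * U[:, iu[1]] * scale
    m_vec, *_ = np.linalg.lstsq(feats, train_losses, rcond=None)

    # Rebuild the symmetric matrix from its upper triangle.
    M = np.zeros((p, p))
    M[iu] = m_vec
    M = M + M.T - np.diag(np.diag(M))

    # Factor M so that Z_syn.T @ Z_syn / p == M (clip tiny negative eigenvalues).
    evals, evecs = np.linalg.eigh(M)
    evals = np.clip(evals, 0.0, None)
    return np.sqrt(p) * (evecs * np.sqrt(evals)).T

# Usage: distill 500 points in R^5 down to 6 synthetic points.
rng = np.random.default_rng(0)
d, n = 5, 500
X = rng.standard_normal((n, d))
y = X @ rng.standard_normal(d) + 0.1 * rng.standard_normal(n)
Z = np.hstack([X, y[:, None]])
Z_syn = distill(Z, num_regressors=60, rng=rng)

w_new = rng.standard_normal(d)          # a regressor not used for matching
print(mse(Z, w_new), mse(Z_syn, w_new))
```

In this idealized, noiseless-matching setting, once the number of sampled regressors reaches $(d+1)(d+2)/2$ the linear system generically pins down $M$, so the synthetic and training losses agree for every linear model, not only the sampled ones. This is only an intuition for why a quadratic number of regressors is the natural scale; the paper's $\tilde{O}(d^2)$ upper bound and $\Omega(d^2)$ lower bound are established by its own analysis.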
