Dataset Distillation for Quantum Neural Networks
Koustubh Phalak, Junde Li, Swaroop Ghosh
TL;DR
The paper tackles the high cost of training quantum neural networks (QNNs) on large classical datasets by applying dataset distillation to a quantum LeNet model. It introduces a quantum-classical LeNet variant with a residual connection and a trainable Hermitian observable in the parametric quantum circuit, enabling a compact synthetic dataset that preserves inference performance. Empirical results show post-distillation accuracy of 91.9% on MNIST and 50.3% on CIFAR-10 for the quantum model, close to the classical LeNet baselines of 94% and 54%. A non-trainable Hermitian variant is introduced to stabilize distillation, with a marginal reduction of up to 1.8% on MNIST and 1.3% on CIFAR-10. The work highlights practical trade-offs and sets the stage for enhanced quantum feature extraction and more robust distillation methods in QNNs.
Abstract
Training Quantum Neural Networks (QNNs) on large amounts of classical data can be both time-consuming and expensive. More training data requires more gradient descent steps to reach convergence, which in turn means more quantum executions for the QNN and a higher overall execution cost. In this work, we propose performing dataset distillation for QNNs, using a novel quantum variant of the classical LeNet model that contains a residual connection and a trainable Hermitian observable in the Parametric Quantum Circuit (PQC) of the QNN. This approach yields a small number of highly informative training samples that achieve performance similar to that of the original data. We perform distillation on the MNIST and CIFAR-10 datasets and observe that the post-inference accuracy on quantum LeNet (91.9% MNIST, 50.3% CIFAR-10) is reasonably close to that on classical LeNet (94% MNIST, 54% CIFAR-10). We also introduce a non-trainable Hermitian observable to ensure stability in the distillation process and note a marginal reduction of up to 1.8% (1.3%) for the MNIST (CIFAR-10) dataset.
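To make the idea of a trainable Hermitian observable concrete, the following is a minimal NumPy sketch, not the authors' implementation. It assumes one common parametrization, H = (M + M†)/2 for an unconstrained complex matrix M, and evaluates the expectation value ⟨ψ|H|ψ⟩ on a random state vector that stands in for the PQC output. The function names, the 2-qubit size, and the random state are illustrative assumptions; the abstract does not specify how the observable is parametrized.

```python
import numpy as np


def hermitian_from_params(real_part, imag_part):
    """Build H = (M + M^dagger) / 2 from an unconstrained complex matrix M.

    Assumed parametrization for illustration only; any square real arrays
    of matching shape yield a Hermitian result.
    """
    m = real_part + 1j * imag_part
    return 0.5 * (m + m.conj().T)


def expectation(state, observable):
    """Return <psi|H|psi> for a normalized state; real when H is Hermitian."""
    value = np.vdot(state, observable @ state)  # vdot conjugates its first argument
    return float(value.real)


rng = np.random.default_rng(0)
n_qubits = 2            # hypothetical circuit width
dim = 2 ** n_qubits

# Parameters that an optimizer would update in the trainable case.
real_part = rng.normal(size=(dim, dim))
imag_part = rng.normal(size=(dim, dim))
observable = hermitian_from_params(real_part, imag_part)

# Random normalized state standing in for the output of the PQC.
state = rng.normal(size=dim) + 1j * rng.normal(size=dim)
state = state / np.linalg.norm(state)

value = expectation(state, observable)
eigenvalues = np.linalg.eigvalsh(observable)  # real spectrum, ascending order
```

In this sketch, the trainable case would correspond to updating `real_part` and `imag_part` alongside the circuit parameters, while a non-trainable observable would keep them fixed after initialization. Because H is Hermitian by construction, `value` is real and lies between the smallest and largest entries of `eigenvalues`.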
