Online Importance Sampling for Stochastic Gradient Optimization
Corentin Salaün, Xingchang Huang, Iliyan Georgiev, Niloy J. Mitra, Gurprit Singh
TL;DR
This work tackles the gradient variance problem in stochastic gradient optimization by introducing a practical online importance sampling framework that computes per-sample importance on-the-fly, removing the need for expensive dataset preprocessing. It proposes a novel metric based on the loss derivative at the network output to prioritize influential data points and to guide both sampling and online data pruning. The method uses a lightweight online algorithm that maintains persistent unnormalized importance scores, reuses forward-pass computations, and optionally prunes low-importance samples during training, yielding faster convergence with minimal accuracy loss. Empirical results across classification and regression tasks demonstrate improved gradient estimation and training efficiency, with consistent gains over DLIS and LOW baselines and effective data pruning strategies that reduce training time while preserving performance. The approach offers a practical and scalable path to faster training in large-scale settings without heavy preprocessing.
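The online loop summarized above can be illustrated with a minimal, hypothetical Python sketch; it is not the authors' implementation. The class name `OnlineImportanceSampler`, the exponential-moving-average blending controlled by `smoothing`, the score floor, and the fraction-based `prune` rule are all illustrative assumptions. The sketch shows persistent unnormalized scores, mini-batch sampling with probability proportional to score, the 1/(N p_i) weights that keep the mini-batch gradient estimate unbiased, and a score update that consumes per-sample metric values already produced by the forward pass.

```python
import random


class OnlineImportanceSampler:
    """Hypothetical sketch: persistent, unnormalized per-sample importance scores."""

    def __init__(self, num_samples, init_score=1.0, smoothing=0.9, seed=0):
        # Every sample starts with the same score, so the first draws are uniform.
        self.scores = [init_score] * num_samples
        self.active = list(range(num_samples))  # indices that have not been pruned
        self.smoothing = smoothing  # assumed EMA factor for blending new metric values
        self.rng = random.Random(seed)

    def sample(self, batch_size):
        """Draw a mini-batch with replacement, with p_i = score_i / sum of active scores."""
        weights = [self.scores[i] for i in self.active]
        total = sum(weights)
        batch = self.rng.choices(self.active, weights=weights, k=batch_size)
        n = len(self.active)
        # Weight each draw by 1 / (n * p_i) so the weighted gradient average is an
        # unbiased estimate of the mean gradient over the active samples.
        correction = [total / (n * self.scores[i]) for i in batch]
        return batch, correction

    def update(self, indices, new_importance):
        """Blend freshly computed metric values (reused from the forward pass) into the scores."""
        for i, value in zip(indices, new_importance):
            blended = self.smoothing * self.scores[i] + (1.0 - self.smoothing) * value
            # Assumed floor: keeps every active sample drawable and avoids division by zero.
            self.scores[i] = max(blended, 1e-8)

    def prune(self, fraction):
        """Drop the given fraction of active samples that have the lowest scores."""
        keep = int(round(len(self.active) * (1.0 - fraction)))
        ranked = sorted(self.active, key=lambda i: self.scores[i], reverse=True)
        self.active = ranked[:keep]
```

In a training step, one would multiply each sampled example's loss by its correction weight before averaging, back-propagate, and then pass that batch's per-sample metric values to `update`. Note that after `prune`, the estimate is unbiased only with respect to the remaining active samples, which is the trade-off pruning accepts in exchange for shorter epochs.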
Abstract
Machine learning optimization often depends on stochastic gradient descent, where the precision of gradient estimation is vital for model performance. Gradients are calculated from mini-batches formed by uniformly selecting data samples from the training dataset. However, not all data samples contribute equally to gradient estimation. To address this, various importance sampling strategies have been developed to prioritize more significant samples. Despite these advancements, all current importance sampling methods face challenges related to computational efficiency and seamless integration into practical machine learning pipelines. In this work, we propose a practical algorithm that efficiently computes data importance on-the-fly during training, eliminating the need for dataset preprocessing. We also introduce a novel metric based on the derivative of the loss w.r.t. the network output, designed for mini-batch importance sampling. Our metric prioritizes influential data points, thereby enhancing gradient estimation accuracy. We demonstrate the effectiveness of our approach across various applications. We first evaluate it on classification and regression tasks, demonstrating improvements in accuracy. Then, we show how our approach can also be used for online data pruning by identifying and discarding data samples that contribute minimally towards the training loss. This significantly reduces training time with negligible loss in the accuracy of the model.
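To make the output-derivative metric concrete, the following hypothetical Python sketch assumes the per-sample importance is the Euclidean norm of the derivative of the loss with respect to the network output z; the paper's exact definition and normalization may differ, and the function names are illustrative. For softmax cross-entropy this derivative has the closed form softmax(z) - onehot(y), and for the squared-error loss 0.5 ||z - y||^2 it is z - y, so both can be computed from outputs already available after the forward pass, without an extra backward pass through the network.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]


def cross_entropy_output_grad_norm(logits, label):
    """||dL/dz|| for L = -log softmax(z)[label], using dL/dz = softmax(z) - onehot(label)."""
    probs = softmax(logits)
    grad = [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]
    return math.sqrt(sum(g * g for g in grad))


def mse_output_grad_norm(prediction, target):
    """||dL/dz|| for L = 0.5 * ||z - y||^2, using dL/dz = z - y."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(prediction, target)))
```

Under this metric, a confidently correct prediction such as logits [10, 0, 0] with label 0 receives a score near zero, while the same logits with label 1 receive a score near sqrt(2), so misclassified or uncertain samples are drawn more often and confidently learned ones become natural pruning candidates.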
