Online Importance Sampling for Stochastic Gradient Optimization

Corentin Salaün, Xingchang Huang, Iliyan Georgiev, Niloy J. Mitra, Gurprit Singh

TL;DR

This work tackles the gradient variance problem in stochastic gradient optimization by introducing a practical online importance sampling framework that computes per-sample importance on-the-fly, removing the need for expensive dataset preprocessing. It proposes a novel metric based on the loss derivative at the network output to prioritize influential data points and to guide both sampling and online data pruning. The method uses a lightweight online algorithm that maintains persistent unnormalized importance scores, reuses forward-pass computations, and optionally prunes low-importance samples during training, yielding faster convergence with minimal accuracy loss. Empirical results across classification and regression tasks demonstrate improved gradient estimation and training efficiency, with consistent gains over DLIS and LOW baselines and effective data pruning strategies that reduce training time while preserving performance. The approach offers a practical and scalable path to faster training in large-scale settings without heavy preprocessing.
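The sampling loop summarized above can be sketched roughly as follows. This is a minimal illustration of standard importance sampling with a persistent score table, not the authors' implementation: the function names `sample_minibatch` and `update_scores`, the `eps` floor, and sampling with replacement are all assumptions made for the example.

```python
import numpy as np

def sample_minibatch(scores, batch_size, rng):
    """Draw indices with probability proportional to unnormalized scores.

    Also returns the weights 1 / (N * p_i), which keep the weighted
    mini-batch mean an unbiased estimate of the full-dataset mean.
    """
    probs = scores / scores.sum()  # normalize only at sampling time
    idx = rng.choice(len(scores), size=batch_size, replace=True, p=probs)
    weights = 1.0 / (len(scores) * probs[idx])
    return idx, weights

def update_scores(scores, idx, new_importance, eps=1e-8):
    """Overwrite the stored scores of the samples seen in this step.

    `eps` (an assumption of this sketch) keeps every score positive so
    that no sample becomes unreachable.
    """
    scores = scores.copy()
    scores[idx] = new_importance + eps
    return scores
```

In a training loop under these assumptions, each step would call `sample_minibatch`, multiply the per-sample losses by `weights` before averaging, and then call `update_scores` with importance values computed from the same forward pass.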

Abstract

Machine learning optimization often depends on stochastic gradient descent, where the precision of gradient estimation is vital for model performance. Gradients are calculated from mini-batches formed by uniformly selecting data samples from the training dataset. However, not all data samples contribute equally to gradient estimation. To address this, various importance sampling strategies have been developed to prioritize more significant samples. Despite these advancements, all current importance sampling methods encounter challenges related to computational efficiency and seamless integration into practical machine learning pipelines. In this work, we propose a practical algorithm that efficiently computes data importance on-the-fly during training, eliminating the need for dataset preprocessing. We also introduce a novel metric based on the derivative of the loss w.r.t. the network output, designed for mini-batch importance sampling. Our metric prioritizes influential data points, thereby enhancing gradient estimation accuracy. We demonstrate the effectiveness of our approach across various applications. We first perform classification and regression tasks to demonstrate improvements in accuracy. Then, we show how our approach can also be used for online data pruning by identifying and discarding data samples that contribute minimally towards the training loss. This significantly reduces training time with negligible loss in the accuracy of the model.
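To make the idea of an importance metric based on the derivative of the loss w.r.t. the network output concrete, here is a small sketch for softmax cross-entropy, where that derivative has the closed form softmax(z) - one_hot(y). Taking the L2 norm of this vector as the per-sample score is an assumption of this example; the abstract does not state the exact reduction, and the function names are hypothetical.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax with the usual max-shift for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def output_gradient_importance(logits, labels):
    """Per-sample norm of d(cross-entropy)/d(logits).

    For softmax cross-entropy this gradient is softmax(z) - one_hot(y),
    so it is available from the forward pass alone. Using the L2 norm
    as the score is an assumption of this sketch.
    """
    grad = softmax(logits)
    grad[np.arange(len(labels)), labels] -= 1.0
    return np.linalg.norm(grad, axis=1)
```

Under this choice, a sample the network already classifies confidently and correctly receives a score near zero, while a confidently misclassified sample receives a score close to the square root of 2, so the latter would be drawn more often.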

Paper Structure (33 sections, 11 equations, 8 figures, 3 tables, 2 algorithms)

This paper contains 33 sections, 11 equations, 8 figures, 3 tables, 2 algorithms.

Figures (8)

  • Figure 1: Visualization of importance sampling at three different epochs, together with the underlying classification task. For each epoch shown, 800 data points are drawn with transparency proportional to the weight assigned by our method. At epoch 800, our weights closely resemble those of the DLIS method; the two methods do differ in practice, but the discrepancies are not visible in this simple example.
  • Figure 2: Evolution of gradient variance for various importance sampling strategies on polynomial regression and MNIST classification tasks. In both cases, the optimization is performed on a network with three fully connected layers. The variance of each method is estimated on the same network at each epoch, using a mini-batch of size 16. The computation time for each metric is reported in the additional results appendix (computation-time table).
  • Figure 3: Evaluation of the impact of the amount of data pruned during training on an MNIST classification task. The left panel shows the evolution of the pruned data over time, while the right panel presents the final accuracy, the average training set size during training, the data remaining at the end of training, the total training time, and the computation time of pruning. The figure compares uniform sampling without data pruning, random pruning with $60\%$, $43\%$, and $35\%$ of data pruned, the method of [yang2022dataset] at the same pruning rates, and our approach using a dynamic reduction factor $K$. Results indicate that pruning more data accelerates execution. Our online pruning method offers greater adaptability during training while maintaining high accuracy and a minimal difference between training time and total execution time.
  • Figure 4: Comparison at an equal number of steps for 2D image regression. The left side shows the convergence plot, while the right side displays the absolute error of the regression and a close-up view. Our method with data pruning achieves the lowest error on this problem while pruning $45\%$ of the data during training. Our method using only importance sampling and DLIS run with our algorithm perform similarly, but DLIS with its full method performs worse than default optimization. The images show that our method with pruning recovers the finest details of the fur and whiskers.
  • Figure 5: Comparisons on CIFAR-10 using a Vision Transformer (ViT) [dosovitskiy2020image]. The results show consistent improvement of Ours IS and Ours IS + Data pruning over LOW [santiago2021low] and DLIS [katharopoulos2018dlis] at both equal epochs and equal time.
  • ...and 3 more figures