Multiple Importance Sampling for Stochastic Gradient Estimation

Corentin Salaün, Xingchang Huang, Iliyan Georgiev, Niloy J. Mitra, Gurprit Singh

TL;DR

This work tackles high-variance gradient estimation in SGD by introducing a self-adaptive importance sampling framework that dynamically evolves the sampling distribution during training. It extends importance sampling to vector-valued gradient estimation through multiple importance sampling (MIS) and optimal MIS (OMIS), enabling jointly weighted gradient contributions from multiple distributions without resampling. The authors propose practical algorithms (IS and OMIS) with momentum-based stabilization and gradient-based importance functions, achieving faster convergence on classification, regression, and point-cloud tasks. The approach yields improved gradient estimates with manageable overhead and is demonstrated to approach or match exact-gradient performance in controlled experiments, suggesting broad applicability for efficient optimization in neural networks.

Abstract

We introduce a theoretical and practical framework for efficient importance sampling of mini-batch samples for gradient estimation from single and multiple probability distributions. To handle noisy gradients, our framework dynamically evolves the importance distribution during training by utilizing a self-adaptive metric. Our framework combines multiple, diverse sampling distributions, each tailored to specific parameter gradients. This approach facilitates the importance sampling of vector-valued gradient estimation. Rather than naively combining multiple distributions, our framework involves optimally weighting data contribution across multiple distributions. This adapted combination of multiple importance yields superior gradient estimates, leading to faster training convergence. We demonstrate the effectiveness of our approach through empirical evaluations across a range of optimization tasks like classification and regression on both image and point cloud datasets.
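
To make the idea of importance sampling mini-batch samples concrete, the following is a minimal sketch, not the authors' algorithm. It assumes a toy one-parameter least-squares model and a hypothetical importance function proportional to the absolute per-sample gradient; the function names (`per_sample_grad`, `importance_probs`, `is_grad_estimate`) are illustrative only. The point it shows is the reweighting by $1/(N p_i)$, which keeps the mini-batch estimate unbiased for the mean gradient whatever distribution the samples are drawn from.

```python
import random

def per_sample_grad(w, x, y):
    # Gradient of 0.5 * (w * x - y)^2 with respect to the scalar weight w.
    return (w * x - y) * x

def exact_grad(w, data):
    # Mean gradient over the full dataset (the quantity SGD estimates).
    return sum(per_sample_grad(w, x, y) for x, y in data) / len(data)

def importance_probs(w, data, eps=1e-8):
    # Assumed importance function: probability proportional to |gradient|.
    # eps keeps every probability strictly positive.
    scores = [abs(per_sample_grad(w, x, y)) + eps for x, y in data]
    total = sum(scores)
    return [s / total for s in scores]

def is_grad_estimate(w, data, probs, batch_size, rng):
    # Draw a mini-batch from probs (with replacement) and reweight each
    # sampled gradient by 1 / (N * p_i) so the estimate stays unbiased.
    n = len(data)
    idx = rng.choices(range(n), weights=probs, k=batch_size)
    return sum(per_sample_grad(w, *data[i]) / (n * probs[i]) for i in idx) / batch_size

# Usage: compare one importance-sampled estimate with the exact gradient.
rng = random.Random(0)
data = [(float(x), 2.0 * x) for x in range(1, 6)]
w = 0.0
probs = importance_probs(w, data)
print(exact_grad(w, data), is_grad_estimate(w, data, probs, 4, rng))
```

In this toy setup every per-sample gradient has the same sign, so sampling proportionally to its magnitude makes each reweighted term nearly identical and the estimate has almost no variance. In a real network the true gradients are unknown before sampling, which is why a cheap, adaptive importance metric is needed in practice.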

Paper Structure (33 sections, 17 equations, 12 figures, 6 algorithms)

Figures (12)

  • Figure 1: We visualize different importance sampling distributions for a simple classification task. We propose to use the output layer gradients for importance sampling, as shown in the network diagram (a). For a given ground-truth classification (top) and training dataset (bottom) shown in (b), it is possible to importance sample from the $L_2$ norm of the output-layer gradients (c) or from three different sampling distributions derived from the gradient norms of individual output nodes (d). The bottom row shows sample weights from each distribution.
  • Figure 2: Convergence comparison for polynomial regression of order 6 using different methods. Exact gradient denotes full gradient descent, shown as a baseline alongside classical SGD. For our method, we compare importance sampling and OMIS using $n=2$ or $4$ importance distributions. Balance-heuristic MIS is also shown. Our method using OMIS achieves the same convergence as the exact gradient.
  • Figure 3: Classification error convergence on MNIST for various methods. Both DLIS (katharopoulos2018dlis) and resampling SGD rely on resampling. In comparison, our two methods use the presented algorithm without resampling. While DLIS performs similarly to our IS at equal epochs, its overhead makes both our IS and OMIS noticeably better at equal time.
  • Figure 4: On CIFAR-100, we use the DLIS importance metric in our algorithm (\ref{alg:IS}) instead of the DLIS resampling algorithm. The zoom-in highlights show error drops when the learning rate decreases after epoch 100. Our method (Our IS) outperforms LOW santiago2021low and DLIS weights at equal epochs (left). It also converges faster than LOW and DLIS weights at equal time (right).
  • Figure 5: Comparisons on CIFAR-10 using a Vision Transformer (ViT) dosovitskiy2020image. The results show that our importance sampling scheme (Our IS) can improve over classical SGD, LOW santiago2021low, and DLIS katharopoulos2018dlis on a modern transformer architecture.
  • ...and 7 more figures
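
The distributions described in the Figure 1 caption can be sketched in a few lines. This is an illustration under stated assumptions rather than the paper's implementation: it assumes a softmax output with cross-entropy loss, for which the output-layer gradient is `softmax(z) - one_hot(label)`, and all function names are hypothetical. It builds one distribution from the $L_2$ norm of each data point's output-layer gradient, one distribution per output node from the absolute gradient of that node, and then combines the per-node distributions with the standard balance heuristic (equal sample counts assumed), which is the baseline combination rather than the optimal OMIS weighting.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def output_layer_grad(logits, label):
    # Assumption: softmax + cross-entropy, so dL/dz = softmax(z) - one_hot(label).
    probs = softmax(logits)
    return [p - (1.0 if j == label else 0.0) for j, p in enumerate(probs)]

def normalize(scores, eps=1e-12):
    # Turn non-negative scores into a probability distribution over data points.
    shifted = [s + eps for s in scores]
    total = sum(shifted)
    return [s / total for s in shifted]

def single_distribution(grads):
    # One distribution: proportional to the L2 norm of each output gradient.
    return normalize([math.sqrt(sum(g * g for g in grad)) for grad in grads])

def per_node_distributions(grads):
    # One distribution per output node j: proportional to |dL/dz_j|.
    num_nodes = len(grads[0])
    return [normalize([abs(grad[j]) for grad in grads]) for j in range(num_nodes)]

def balance_heuristic_weights(distributions, i):
    # Balance heuristic with equal sample counts: w_j(i) = p_j(i) / sum_k p_k(i).
    total = sum(p[i] for p in distributions)
    return [p[i] / total for p in distributions]

# Usage: three data points, three classes (hypothetical logits and labels).
batch = [([2.0, 0.0, 0.0], 0), ([0.0, 1.0, 0.0], 1), ([0.5, 0.5, 3.0], 0)]
grads = [output_layer_grad(z, y) for z, y in batch]
print(single_distribution(grads))
print(balance_heuristic_weights(per_node_distributions(grads), 2))
```

Because these quantities only need the network outputs and labels, they can be computed from a forward pass alone, which is what makes an output-layer metric cheap compared with full per-sample parameter gradients.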