Joint control variate for faster black-box variational inference

Xi Wang, Tomas Geffner, Justin Domke

TL;DR

A new joint control variate is proposed that jointly reduces variance from both sources of noise, leading to faster optimization in several applications.

Abstract

Black-box variational inference performance is sometimes hindered by the use of gradient estimators with high variance. This variance comes from two sources of randomness: Data subsampling and Monte Carlo sampling. While existing control variates only address Monte Carlo noise, and incremental gradient methods typically only address data subsampling, we propose a new "joint" control variate that jointly reduces variance from both sources of noise. This significantly reduces gradient variance, leading to faster optimization in several applications.
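The two noise sources named above can be separated with the law of total variance: the total variance of a doubly stochastic gradient estimate equals the expected Monte Carlo variance for a fixed datum plus the variance, across data, of the Monte Carlo-averaged gradient. The following is a minimal sketch of that decomposition on a toy problem, not the paper's model or estimator. The one-dimensional Gaussian model, the helper names (`grad_term`, `naive_estimate`, `variance_decomposition`) and all constants are assumptions made for illustration.

```python
import random

def grad_term(w, x_n, eps):
    """Reparameterized gradient of -0.5 * (x_n - z)**2 w.r.t. w, with z = w + eps.

    Toy assumption: q(z) = N(w, 1) and a unit-variance Gaussian likelihood.
    """
    return x_n - (w + eps)

def naive_estimate(w, data, rng):
    """Estimate sum_n E_eps[grad_term] from one random datum and one eps draw."""
    x_n = rng.choice(data)        # data subsampling noise
    eps = rng.gauss(0.0, 1.0)     # Monte Carlo noise
    return len(data) * grad_term(w, x_n, eps)

def variance_decomposition(w, data):
    """Closed-form split of the naive estimator's variance for this toy model."""
    N = len(data)
    # E_eps[g | n] = N * (x_n - w); its spread over n is the subsampling variance.
    cond_means = [N * (x - w) for x in data]
    mean = sum(cond_means) / N
    subsampling_var = sum((m - mean) ** 2 for m in cond_means) / N
    # Var_eps[g | n] = N**2 for every n, so its average over n is also N**2.
    mc_var = float(N ** 2)
    return subsampling_var, mc_var

rng = random.Random(0)
data = [rng.gauss(1.0, 2.0) for _ in range(20)]   # hypothetical dataset
w = 0.3                                            # hypothetical parameter value

samples = [naive_estimate(w, data, rng) for _ in range(100_000)]
sample_mean = sum(samples) / len(samples)
empirical_total = sum((s - sample_mean) ** 2 for s in samples) / len(samples)

subsampling_var, mc_var = variance_decomposition(w, data)
print("empirical total variance:", empirical_total)
print("subsampling + Monte Carlo:", subsampling_var + mc_var)
```

In this toy setting, a method that removes only one of the two terms leaves the other untouched. That is the motivation the abstract gives for controlling both sources jointly.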

Paper Structure (28 sections, 48 equations, 9 figures, 3 tables)


Figures (9)

  • Figure 1: The contributions of subsampling and Monte Carlo noise vary by problem. The proposed joint estimator reduces both. Orange lines denote variance from data subsampling ($n$), and green lines denote Monte Carlo noise variance ($\epsilon$). We use a batch size of 5. For the Sonar dataset, both sources show similar scales. For the Australian dataset, subsampling noise dominates. Regardless, our proposed gradient estimator $g_\mathtt{joint}$ (red line, Eq. \ref{eq:g_dual}) mitigates subsampling noise and controls MC noise, aligning closely with or below the green lines (i.e. the variance without data subsampling) in both datasets.
  • Figure 2: In practice, cv and inc reduce variance nearly as much as theoretically possible. The joint estimator variance is lower than these bounds. The $\mathtt{naive}$ gradient estimator (Eq. \ref{eq:g_naive_n_eps}) is the baseline, while the $\mathtt{cv}$ estimator (Eq. \ref{eq:g_cv_w_n_eps}) controls the Monte Carlo noise, the $\mathtt{inc}$ estimator (Eq. \ref{eq:g_inc_w_n_eps}) controls the subsampling noise, and the proposed $\mathtt{joint}$ estimator (Eq. \ref{eq:g_dual}) controls both. The variances of $\mathtt{cv}$ and $\mathtt{inc}$ are lower-bounded by the dotted lines, as shown in Eq. \ref{eq:g_cv_var_final} and Eq. \ref{eq:g_inc_var_final}, while $\mathtt{joint}$ can reduce the variance to significantly lower values, leading to better and faster convergence (first two grids in Fig. \ref{fig:big_results}).
  • Figure 3: On various tasks, the proposed $\mathtt{joint}$ control variate leads to faster convergence by controlling both Monte Carlo and subsampling noise. Compared to the $\mathtt{naive}$ estimator, $\mathtt{cv}$ controls only Monte Carlo noise, while $\mathtt{inc}$ and SMISO control only subsampling noise. Our proposed $\mathtt{joint}$ estimator converges faster than $\mathtt{naive}$ and $\mathtt{cv}$ on all tasks. The step sizes for SMISO are rescaled for each model for visualization. On PPCA and MovieLens, SMISO has not converged enough to appear; see Fig. \ref{fig:sgd_vs_adam} in Appendix \ref{sec:sgd_results} for full results. In Tennis, many estimators show periodic behavior because gradients have correlated noise that cancels out at the end of each epoch; the $\mathtt{joint}$ estimator largely cancels this. All lines show the average of multiple trials (5 for Sonar and Australian, 10 for the rest), with shaded areas showing one standard deviation.
  • Figure 4: The SVRG version of $\mathtt{joint}$ shows performance similar to the SAGA version on Australian. The original SAGA-based $\mathtt{joint}$ control variate requires $O(ND)$ memory. This additional memory cost can be avoided by using the SVRG version of $\mathtt{joint}$, which needs no extra memory but requires extra gradient evaluations at each step. In the experiments above, we update the SVRG cache every epoch, equivalent to 1 extra gradient evaluation per iteration. Overall, we observe that $\mathtt{joint}$ (SVRG) shows results similar to the SAGA version of $\mathtt{joint}$.
  • Figure 5: The joint estimator leads to improved convergence at higher learning rates on Gaussian dropout on CIFAR-10. (For small enough learning rates, optimization speed is limited by the learning rate itself and so all estimators perform identically.) The first column shows the trace of the objective (logistic loss) under a learning rate of $0.05$. The second column shows the trace of the objective under the best learning rate chosen retrospectively at each iteration. The final two columns show the objective as a function of different learning rates at two different numbers of iterations. Note that $\mathtt{joint}$ has its best performance at a higher learning rate than the other estimators. (The $\mathtt{inc}$ estimator is too expensive to be included here.)
  • ...and 4 more figures