Sampling from Gaussian Process Posteriors using Stochastic Gradient Descent

Jihao Andreas Lin, Javier Antorán, Shreyas Padhy, David Janz, José Miguel Hernández-Lobato, Alexander Terenin

TL;DR

This paper addresses the computational bottleneck of Gaussian Process posterior sampling by reframing GP conditioning as stochastic optimization, enabling low-cost posterior mean and sample computations via SGD. It introduces a stochastic objective for the posterior mean and a variance-reducing approach for posterior samples, augmented with Random Fourier Features and inducing-point extensions to scale to large datasets. Theoretical analysis reveals an implicit bias where SGD rapidly converges along top kernel-spectrum directions and slowly along low-eigenvalue directions, with a three-region posterior geometry (prior, interpolation, extrapolation). Empirically, SGD-based GP posteriors achieve competitive or state-of-the-art predictive performance on large-scale or ill-conditioned tasks and provide well-calibrated uncertainty for parallel Thompson sampling, offering practical scalability for uncertainty quantification in real-world settings.
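The posterior samples described above rest on pathwise conditioning (Matheron's rule): a posterior sample is a prior sample plus a data-dependent correction, $f(\cdot) + K_{(\cdot)\mathbf{x}}(K_{\mathbf{x}\mathbf{x}} + \sigma^2\mathbf{I})^{-1}(\mathbf{y} - f(\mathbf{x}) - \boldsymbol{\varepsilon})$ with $\boldsymbol{\varepsilon} \sim \mathrm{N}(\mathbf{0}, \sigma^2\mathbf{I})$. The NumPy sketch below checks that identity on a tiny problem. It substitutes exact Cholesky prior draws for the random Fourier features and a direct linear solve for SGD. The kernel, lengthscale, noise level and data are hypothetical.

```python
import numpy as np


def se_kernel(a, b, lengthscale=0.5):
    """Unit-variance squared exponential kernel between 1-D input arrays."""
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)


rng = np.random.default_rng(0)
noise = 0.1                        # observation noise variance (hypothetical)
x = np.linspace(-2, 2, 8)          # training inputs
y = np.sin(2 * x)                  # training targets
x_test = np.array([-1.0, 0.5, 2.5])
n, S = len(x), 20_000              # S = number of posterior samples

# Joint prior draws at training and test inputs. An exact Cholesky factor
# stands in for random-Fourier-feature prior samples.
z = np.concatenate([x, x_test])
L = np.linalg.cholesky(se_kernel(z, z) + 1e-9 * np.eye(len(z)))
F = L @ rng.standard_normal((len(z), S))    # each column is one prior function
f_train, f_test = F[:n], F[n:]
eps = np.sqrt(noise) * rng.standard_normal((n, S))

# Pathwise update: f_post = f + K(., x) (K + noise I)^{-1} (y - f(x) - eps).
# An SGD solver would approximate these representer weights; here we solve directly.
K_noisy = se_kernel(x, x) + noise * np.eye(n)
K_sx = se_kernel(x_test, x)
weights = np.linalg.solve(K_noisy, y[:, None] - f_train - eps)
post = f_test + K_sx @ weights

# Exact posterior moments for comparison with the Monte Carlo estimates.
exact_mean = K_sx @ np.linalg.solve(K_noisy, y)
exact_cov = se_kernel(x_test, x_test) - K_sx @ np.linalg.solve(K_noisy, K_sx.T)
mc_mean, mc_cov = post.mean(axis=1), np.cov(post)
print(np.abs(mc_mean - exact_mean).max(), np.abs(mc_cov - exact_cov).max())
```

Replacing the direct solve with an iterative optimizer changes only how `weights` is obtained. The pathwise structure of the sample stays the same.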

Abstract

Gaussian processes are a powerful framework for quantifying uncertainty and for sequential decision-making but are limited by the requirement of solving linear systems. In general, this has a cubic cost in dataset size and is sensitive to conditioning. We explore stochastic gradient algorithms as a computationally efficient method of approximately solving these linear systems: we develop low-variance optimization objectives for sampling from the posterior and extend these to inducing points. Counterintuitively, stochastic gradient descent often produces accurate predictions, even in cases where it does not converge quickly to the optimum. We explain this through a spectral characterization of the implicit bias from non-convergence. We show that stochastic gradient descent produces predictive distributions close to the true posterior both in regions with sufficient data coverage, and in regions sufficiently far away from the data. Experimentally, stochastic gradient descent achieves state-of-the-art performance on sufficiently large-scale or ill-conditioned regression tasks. Its uncertainty estimates match the performance of significantly more expensive baselines on a large-scale Bayesian optimization task.
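Below is a minimal NumPy sketch of computing the posterior mean with stochastic gradients. It assumes the objective $L(\mathbf{v}) = \|\mathbf{y} - K\mathbf{v}\|^2/(2\sigma^2) + \mathbf{v}^\top K\mathbf{v}/2$, whose minimizer is the representer-weight vector $\mathbf{v}^* = (K + \sigma^2\mathbf{I})^{-1}\mathbf{y}$. Several choices belong to this sketch rather than to the authors' settings: minibatching only the data-fit term, computing the regularizer exactly rather than estimating it stochastically, the conservative step-size rule, and tail averaging of the iterates. The sketch also computes the representer-weight error and the RKHS error from Figure 3, using $\|\sum_i e_i k(\cdot, x_i)\|_{H_k} = \sqrt{\mathbf{e}^\top K\mathbf{e}}$. The data and kernel settings are hypothetical.

```python
import numpy as np


def se_kernel(a, b, lengthscale=0.5):
    """Unit-variance squared exponential kernel between 1-D input arrays."""
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)


def sgd_representer_weights(K, y, noise, batch_size=10, steps=20_000, seed=0):
    """Minibatch SGD on L(v) = ||y - K v||^2 / (2 noise) + v^T K v / 2.

    The minimizer is v* = (K + noise I)^{-1} y. The data-fit term is estimated
    from a random minibatch of rows of K; the regularizer is computed exactly.
    Returns the average of the iterates over the second half of the run.
    """
    rng = np.random.default_rng(seed)
    n = len(y)
    # Conservative step size: makes every minibatch update non-expansive.
    max_row_sq = np.max(np.sum(K**2, axis=1))
    eta = 1.0 / (n * max_row_sq / noise + np.linalg.eigvalsh(K)[-1])
    v = np.zeros(n)  # start from zero representer weights
    v_avg = np.zeros(n)
    tail = steps - steps // 2
    for t in range(steps):
        idx = rng.choice(n, size=batch_size, replace=False)
        K_b = K[idx]  # minibatch rows of K
        grad = (n / batch_size) * K_b.T @ (K_b @ v - y[idx]) / noise + K @ v
        v = v - eta * grad
        if t >= steps // 2:  # tail (Polyak-style) averaging
            v_avg += v / tail
    return v_avg


# Toy problem (hypothetical settings): 100 noisy observations of sin(2x) + cos(5x).
rng = np.random.default_rng(1)
n, noise = 100, 0.5
x = rng.uniform(-3, 3, size=n)
y = np.sin(2 * x) + np.cos(5 * x) + np.sqrt(noise) * rng.standard_normal(n)
x_test = np.linspace(-3, 3, 50)

K = se_kernel(x, x)
K_test = se_kernel(x_test, x)
v_exact = np.linalg.solve(K + noise * np.eye(n), y)
v_sgd = sgd_representer_weights(K, y, noise)

mean_exact, mean_sgd = K_test @ v_exact, K_test @ v_sgd
rmse_to_exact = np.sqrt(np.mean((mean_sgd - mean_exact) ** 2))
e = v_sgd - v_exact
weight_err = np.linalg.norm(e)    # ||v - v*||_2
rkhs_err = np.sqrt(e @ K @ e)     # ||mu_exact - mu_sgd||_{H_k}
print(rmse_to_exact, weight_err, rkhs_err)
```

On this toy problem the predictions and the RKHS error end up close to the exact solution, while $\|\mathbf{v} - \mathbf{v}^*\|_2$ stays large. The weights along directions of $K$ with tiny eigenvalues barely move, and those directions hardly affect predictions. This is a small-scale illustration of the implicit bias the abstract describes.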

Paper Structure (39 sections, 10 theorems, 79 equations, 10 figures, 4 tables)

Key Result

proposition 1

Let $\delta>0$. Let $\boldsymbol{\Sigma} = \sigma^2\mathbf{I}$ for $\sigma^2 > 0$. Let $\mu_{\mathrm{SGD}}$ be the predictive mean obtained by Polyak-averaged SGD after $t$ steps, starting from an initial set of representer weights equal to zero, and using a sufficiently small learning rate of $0 < \eta <\frac{\sigma^2}
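The proposition is cut off above, so its exact learning-rate condition is not reproduced here. The spectral picture behind it can be illustrated with a simplified, deterministic stand-in: full-batch gradient descent from zero weights on the quadratic $L(\mathbf{v}) = \|\mathbf{y} - K\mathbf{v}\|^2/(2\sigma^2) + \mathbf{v}^\top K\mathbf{v}/2$ (an assumed scaling). Its Hessian is $K(K + \sigma^2\mathbf{I})/\sigma^2$. Along the $i$-th eigenvector of $K$, the error therefore shrinks by the factor $1 - \eta\lambda_i(\lambda_i + \sigma^2)/\sigma^2$ per step. That factor is small for large $\lambda_i$ and close to one for tiny $\lambda_i$. The sketch checks this closed form numerically. The kernel, lengthscale, noise level and step size are hypothetical.

```python
import numpy as np


def gd_spectral_errors(K, y, noise, eta, steps):
    """Full-batch gradient descent from v = 0 on
    L(v) = ||y - K v||^2 / (2 noise) + v^T K v / 2.

    Returns the eigenvalues of K, the per-direction contraction factors
    rate_i = 1 - eta * lam_i * (lam_i + noise) / noise, the final error
    v_t - v* in K's eigenbasis, and its closed-form prediction rate**t * (0 - v*).
    """
    lam, U = np.linalg.eigh(K)  # ascending eigenvalues, orthonormal eigenvectors
    n = len(y)
    v_star = np.linalg.solve(K + noise * np.eye(n), y)
    v = np.zeros(n)
    for _ in range(steps):
        v = v - eta * (K @ (K @ v - y) / noise + K @ v)
    err = U.T @ (v - v_star)
    rate = 1.0 - eta * lam * (lam + noise) / noise
    predicted = rate**steps * (U.T @ -v_star)
    return lam, rate, err, predicted


x = np.linspace(-3, 3, 40)
K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.5**2)
y = np.sin(2 * x)
noise = 0.1
lam_max = np.linalg.eigvalsh(K)[-1]
# A quarter of the stability limit 2 * noise / (lam_max * (lam_max + noise)).
eta = 0.5 * noise / (lam_max * (lam_max + noise))

lam, rate, err, predicted = gd_spectral_errors(K, y, noise, eta, steps=200)
# Fraction of the initial error left along each eigendirection after 200 steps.
remaining = rate**200
print("top direction:", remaining[-1], "bottom direction:", remaining[0])
```

With this step size, the top eigendirection halves its error every step, while the bottom one is essentially frozen. Minibatch noise and Polyak averaging, which the proposition covers, are deliberately left out of this sketch.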

Figures (10)

  • Figure 1: Comparison of SGD, CG (Wang et al., 2019) and SVGP (Hensman et al., 2013) for GP inference with a squared exponential kernel on $10\text{k}$ datapoints from $\sin(2x)+\cos(5x)$ with observation noise $\mathrm{N}(0, 0.5)$. We draw 2000 function samples with all methods by running them for 10 minutes on an RTX 2070 GPU. Infill asymptotics considers $x_i \sim \mathrm{N}(0,1)$. A large number of points near zero results in a very ill-conditioned kernel matrix, preventing CG from converging. SGD converges in all of input space except at the edges of the data. SVGP can summarize the data with only 20 inducing points. Note that CG converges to the exact solution if one uses more compute, but produces significant errors if stopped too early, as occurs under the given compute budget. Large domain asymptotics considers data on a regular grid with fixed spacing. This problem is better conditioned, allowing SGD and CG to recover the exact solution. However, 1024 inducing points are not enough for SVGP to summarize the data.
  • Figure 2: Left: gradient variance throughout optimization for a single-sample minibatch estimator ($D=1$) of the sampling objective labeled Loss 1 and of the proposed variance-reduced sampling objective labeled Loss 2, on the elevators dataset ($N\approx16\text{k}$). Middle: test RMSE and negative log-likelihood (NLL) obtained by SGD and its inducing-point variants, as a function of time on an A100 GPU, on the houseelectric dataset ($N\approx2$M). Right: the decreasing numbers of inducing points used by these variants.
  • Figure 3: Convergence of the GP posterior mean with SGD and CG as a function of time (on an A100 GPU) on the elevators dataset ($N \approx 16\text{k}$), with the noise scale set (i) to maximize the exact GP marginal likelihood and (ii) to $10^{-3}$, labeled low noise. From left to right, we plot test RMSE, RMSE to the exact GP mean at the test inputs, representer weight error $\|\mathbf{v} - \mathbf{v}^{*}\|_{2}$, and RKHS error $\|\mu_{f\mid\mathbf{y}} - \mu_{\mathrm{SGD}}\|_{H_k}$. In the latter two plots, the low-noise setting is shown on the bottom.
  • Figure 4: SGD error and spectral basis functions. Top-left: SGD (blue) and exact GP (black, dashed) fit to a $N = 10\text{k}$, $d=1$ toy regression dataset. Top-right: 2-Wasserstein distance (W2) between the marginals of both processes. The W2 values are low near the data (interpolation region) and far away from the training data (prior region). The error concentrates at the edges of the data (extrapolation region). Bottom: the low-index spectral basis functions lie in the interpolation region, where the W2 error is low, while functions of index $10$ and larger lie in the extrapolation region, where the error is large.
  • Figure 5: Test RMSE and NLL as a function of compute time on a TPUv2 core for CG and SGD.
  • ...and 5 more figures
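Figure 1 attributes CG's failure under infill asymptotics to an ill-conditioned kernel matrix. The NumPy sketch below compares the condition number of $K + \sigma^2\mathbf{I}$ for inputs drawn from $\mathrm{N}(0,1)$ against a regular grid with fixed spacing. It assumes a unit-variance squared exponential kernel with lengthscale 0.5 and noise variance $\sigma^2 = 0.5$, reading the caption's $\mathrm{N}(0, 0.5)$ as a variance. The lengthscale, the grid spacing of 0.1 and the sizes are hypothetical. Under infill, nearby points are highly correlated, so the largest eigenvalue grows roughly linearly in $N$. On the fixed-spacing grid it stays roughly constant.

```python
import numpy as np


def se_kernel(a, b, lengthscale=0.5):
    """Unit-variance squared exponential kernel (lengthscale is an assumption)."""
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)


def condition_number(x, noise=0.5):
    """Spectral condition number of K + noise * I (symmetric, so eigvalsh suffices)."""
    eig = np.linalg.eigvalsh(se_kernel(x, x) + noise * np.eye(len(x)))
    return eig[-1] / eig[0]


rng = np.random.default_rng(0)
cond_infill, cond_grid = {}, {}
for n in (500, 1000):
    # Infill: more and more points packed into the same region around zero.
    cond_infill[n] = condition_number(rng.standard_normal(n))
    # Large domain: fixed spacing 0.1, so the grid's extent grows with n.
    cond_grid[n] = condition_number(0.1 * np.arange(n))
    print(n, round(cond_infill[n]), round(cond_grid[n]))
```

Doubling $N$ roughly doubles the infill condition number and leaves the grid's essentially unchanged. The condition number is only a rough proxy for CG's difficulty, which also depends on how the eigenvalues are spread.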

Theorems & Definitions (24)

  • proposition 1
  • lemma 1
  • proof
  • definition 1
  • definition 2
  • definition 3
  • definition 4
  • lemma 2
  • proof
  • lemma 3
  • ...and 14 more