
Turbocharging Gaussian Process Inference with Approximate Sketch-and-Project

Pratik Rathore, Zachary Frangella, Sachin Garg, Shaghayegh Fazliani, Michał Dereziński, Madeleine Udell

TL;DR

This work tackles the scalability and conditioning challenges of Gaussian process inference by introducing ADASAP, an approximate, distributed, accelerated sketch-and-project method. By combining Nyström-based approximate subspace preconditioning, distributed matrix operations, and Nesterov acceleration within the sketch-and-project framework, ADASAP converges quickly along the top spectral directions while scaling to datasets with hundreds of millions of points. The theoretical guarantees show condition-number-free progress along the dominant RKHS directions in the early phase, complemented by linear convergence in later phases. Empirically, ADASAP achieves better RMSE, NLL, and wall-clock efficiency than competing solvers on large-scale GP tasks, including a transportation dataset with more than 300 million points, and performs strongly in Bayesian optimization. The approach comes with practical default settings, scales across multiple GPUs, and provides a principled alternative to PCG and SDD for large-scale GP inference that remains robust in ill-conditioned settings.
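One ingredient named above is Nyström-based approximate subspace preconditioning. As a hedged illustration of that building block (not the authors' ADASAP code), the sketch below computes a standard randomized Nyström approximation $A \approx U \operatorname{diag}(\lambda) U^\top$ of a PSD matrix, such as a sampled kernel block, and uses the Woodbury identity to apply the damped inverse $(U \operatorname{diag}(\lambda) U^\top + \rho I)^{-1}$ cheaply. The function names, the stability shift `nu`, and the damping `rho` are assumptions made for this sketch.

```python
import numpy as np


def randomized_nystrom(A, rank, rng):
    """Rank-`rank` randomized Nystrom approximation A ~= U @ diag(lam) @ U.T of PSD A.

    Hypothetical helper: uses a small shift `nu` for numerical stability,
    then removes it from the eigenvalues at the end.
    """
    n = A.shape[0]
    # Orthonormal Gaussian test matrix Omega (n x rank).
    omega, _ = np.linalg.qr(rng.standard_normal((n, rank)))
    Y = A @ omega                                   # sketch of A
    nu = np.sqrt(n) * np.finfo(float).eps * np.linalg.norm(Y, 2)
    Y_nu = Y + nu * omega                           # sketch of A + nu * I
    core = omega.T @ Y_nu
    core = 0.5 * (core + core.T)                    # symmetrize before Cholesky
    C = np.linalg.cholesky(core)                    # core = C @ C.T
    B = np.linalg.solve(C, Y_nu.T).T                # B = Y_nu @ inv(C).T
    U, s, _ = np.linalg.svd(B, full_matrices=False)
    lam = np.maximum(s**2 - nu, 0.0)                # undo the shift, clip at zero
    return U, lam


def apply_damped_inverse(U, lam, rho, v):
    """Apply (U diag(lam) U^T + rho I)^{-1} to a 1-D vector v (Woodbury identity)."""
    Utv = U.T @ v
    # Component in range(U) is scaled by 1/(lam + rho); the orthogonal rest by 1/rho.
    return U @ (Utv / (lam + rho)) + (v - U @ Utv) / rho
```

Only a rank-`rank` factorization is stored, so applying the damped inverse costs $O(n \cdot \text{rank})$ per vector instead of a dense solve; this is the usual motivation for Nyström preconditioners.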

Abstract

Gaussian processes (GPs) play an essential role in biostatistics, scientific machine learning, and Bayesian optimization for their ability to provide probabilistic predictions and model uncertainty. However, GP inference struggles to scale to large datasets (which are common in modern applications), since it requires the solution of a linear system whose size scales quadratically with the number of samples in the dataset. We propose an approximate, distributed, accelerated sketch-and-project algorithm ($\texttt{ADASAP}$) for solving these linear systems, which improves scalability. We use the theory of determinantal point processes to show that the posterior mean induced by sketch-and-project rapidly converges to the true posterior mean. In particular, this yields the first efficient, condition-number-free algorithm for estimating the posterior mean along the top spectral basis functions, showing that our approach is principled for GP inference. $\texttt{ADASAP}$ outperforms state-of-the-art solvers based on conjugate gradient and coordinate descent across several benchmark datasets and a large-scale Bayesian optimization task. Moreover, $\texttt{ADASAP}$ scales to a dataset with $> 3 \cdot 10^8$ samples, a feat not previously accomplished in the literature.
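To make the linear system concrete: the GP posterior mean at test points $X_*$ is $K(X_*, X)\,w$, where $w$ solves $(K + \lambda I) w = y$. The sketch below is a plain, exact, unaccelerated sketch-and-project solver that samples coordinate blocks uniformly at random; it is a simplified stand-in, not $\texttt{ADASAP}$ itself (which adds Nyström-based approximate projections, distribution, and Nesterov acceleration, and whose theory uses DPP sampling). The RBF kernel, lengthscale, block size, and iteration count are illustrative assumptions.

```python
import numpy as np


def rbf_kernel(X1, X2, lengthscale=1.0):
    """Squared-exponential kernel k(x, x') = exp(-||x - x'||^2 / (2 * lengthscale^2))."""
    sq = np.sum(X1**2, 1)[:, None] + np.sum(X2**2, 1)[None, :] - 2.0 * X1 @ X2.T
    return np.exp(-np.maximum(sq, 0.0) / (2.0 * lengthscale**2))


def sap_solve(A, y, block_size, n_iters, rng):
    """Exact sketch-and-project with uniformly sampled coordinate blocks, for SPD A.

    With sketch S = I[:, B], the update w <- w + S (S^T A S)^{-1} S^T (y - A w)
    changes only w[B], so each step is a small block_size x block_size solve.
    """
    n = A.shape[0]
    w = np.zeros(n)
    for _ in range(n_iters):
        B = rng.choice(n, size=block_size, replace=False)
        r_B = y[B] - A[B, :] @ w                    # residual restricted to the block
        w[B] += np.linalg.solve(A[np.ix_(B, B)], r_B)
    return w


# Usage: small 1-D regression problem (illustrative data only).
rng = np.random.default_rng(0)
X = np.linspace(0.0, 5.0, 100)[:, None]
y = np.sin(X[:, 0]) + 0.1 * rng.standard_normal(100)
lam = 1.0                                           # likelihood variance (lambda)
K = rbf_kernel(X, X, lengthscale=0.3)
w = sap_solve(K + lam * np.eye(100), y, block_size=50, n_iters=2000, rng=rng)
X_test = np.linspace(0.0, 5.0, 20)[:, None]
posterior_mean = rbf_kernel(X_test, X, lengthscale=0.3) @ w
```

Each iteration touches only `block_size` rows of the kernel matrix, which is what lets sketch-and-project methods avoid forming or factoring the full $n \times n$ system.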

Paper Structure

This paper contains 49 sections, 13 theorems, 62 equations, 7 figures, 2 tables, 7 algorithms.

Key Result

Theorem 3.2

Suppose we have a kernel matrix $K \in \mathbb R^{n \times n}$, observations $y \in \mathbb R^n$ from a GP prior with likelihood variance $\lambda \geq 0$, and let $\phi(\cdot,\cdot)$ denote the smoothed condition number of $K + \lambda I$. Given any $\ell\in[n]$, let $\mathrm{proj}_\ell$ denote the orthogonal projection …
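The statement above is truncated, but the theorem concerns progress along the top-$\ell$ spectral directions. Assuming $\mathrm{proj}_\ell$ projects onto the span of the top-$\ell$ eigenvectors of $K$ (an assumption here), the sketch below shows how such a projector and the corresponding projected error could be computed for a small dense problem. Because $K$ and $K + \lambda I$ have the same eigenvectors for any $\lambda \geq 0$, the projector is the same for either matrix. Both function names are hypothetical.

```python
import numpy as np


def top_ell_projector(K, ell):
    """Orthogonal projector onto the span of the top-`ell` eigenvectors of symmetric K."""
    evals, evecs = np.linalg.eigh(K)        # eigenvalues in ascending order
    V = evecs[:, -ell:]                     # eigenvectors of the ell largest eigenvalues
    return V @ V.T


def error_along_top_ell(K, w, w_star, ell):
    """Norm of the error w - w_star after projecting onto the top-ell eigenspace of K."""
    P = top_ell_projector(K, ell)
    return np.linalg.norm(P @ (w - w_star))
```

Forming the full eigendecomposition is only feasible for small $n$; it is meant for diagnosing how an iterative solver's error splits between top and tail directions, not for large-scale use.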

Figures (7)

  • Figure 1: ADASAP attains lower root mean square error (RMSE) and mean negative log likelihood (NLL) than the state-of-the-art methods SDD (lin2024stochastic) and PCG on the houseelec dataset. SDD-1 and SDD-10 correspond to two particular stepsize selections for the SDD method. The solid lines indicate the mean performance of each method, while the shaded regions indicate the range between worst and best performance of each method over five random splits of the data.
  • Figure 2: Performance of ADASAP and competitors on RMSE and mean NLL, as a function of time, for benzene, malonaldehyde, and houseelec. The solid curve indicates mean performance over random splits of the data; the shaded regions indicate the range between the worst and best performance over random splits of the data. ADASAP outperforms the competition.
  • Figure 3: Comparison between ADASAP and competitors on transportation data analysis. ADASAP attains the lowest RMSE and obtains a $1.8\times$ speedup over the second-best method, SDD-10. SDD-100 diverges and PCG runs out of memory, so they do not appear in the figure.
  • Figure 4: Multi-GPU scaling of ADASAP on the taxi dataset. ADASAP obtains near-linear scaling with the number of GPUs.
  • Figure 5: Performance of ADASAP with and without tail averaging. One "data pass" corresponds to one pass through the kernel matrix. Tail averaging does not improve convergence by a substantial margin.
  • ...and 2 more figures
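Figure 5 compares runs with and without tail averaging, i.e., reporting the average of the iterates from some point onward instead of the last iterate. A minimal sketch of that bookkeeping follows; the generic `step` callback and the fixed `tail_start` index are assumptions for illustration, not the paper's exact averaging schedule. Any iterative solver update, such as a sketch-and-project step, could serve as `step`.

```python
import numpy as np


def run_with_tail_average(step, w0, n_iters, tail_start):
    """Run w_{t+1} = step(w_t, t); return the last iterate and the tail average.

    The tail average is the mean of the iterates produced at steps t >= tail_start
    (hypothetical helper; falls back to the last iterate if the tail is empty).
    """
    w = np.array(w0, dtype=float)
    avg = np.zeros_like(w)
    count = 0
    for t in range(n_iters):
        w = step(w, t)
        if t >= tail_start:
            count += 1
            avg += (w - avg) / count   # running mean, no need to store all iterates
    return w, (avg if count > 0 else w.copy())


# Usage with a toy update that adds 1 each step.
w_last, w_avg = run_with_tail_average(lambda w, t: w + 1.0, np.zeros(3), n_iters=10, tail_start=5)
```

The running-mean update keeps memory at one extra vector, which matters when each iterate has hundreds of millions of entries.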

Theorems & Definitions (23)

  • Definition 3.1: Smoothed condition number
  • Theorem 3.2
  • Corollary 3.3
  • Remark 3.4
  • Lemma B.1: Expected projection under $2b$-DPP($A$), adapted from Lemma 4.1 of derezinski2024solving
  • Lemma B.2: Sampling from $b$-DPP($A$), adapted from anari2024optimal
  • Theorem B.3: Fast convergence along top-$\ell$ subspace
  • Lemma B.4: Linear convergence with SAP
  • Lemma B.5: Convergence of expected iterates along top-$\ell$ subspace
  • Proof
  • ...and 13 more