Scaling Gaussian Processes for Learning Curve Prediction via Latent Kronecker Structure

Jihao Andreas Lin, Sebastian Ament, Maximilian Balandat, Eytan Bakshy

TL;DR

By interpreting the joint covariance matrix of observed values as the projection of a latent Kronecker product, a GP model can handle missing learning curve values efficiently and match the performance of a Transformer on a learning curve prediction task.

Abstract

A key task in AutoML is to model learning curves of machine learning models jointly as a function of model hyper-parameters and training progression. While Gaussian processes (GPs) are suitable for this task, naïve GPs require $\mathcal{O}(n^3m^3)$ time and $\mathcal{O}(n^2 m^2)$ space for $n$ hyper-parameter configurations and $\mathcal{O}(m)$ learning curve observations per hyper-parameter. Efficient inference via Kronecker structure is typically incompatible with early-stopping due to missing learning curve values. We impose $\textit{latent Kronecker structure}$ to leverage efficient product kernels while handling missing values. In particular, we interpret the joint covariance matrix of observed values as the projection of a latent Kronecker product. Combined with iterative linear solvers and structured matrix-vector multiplication, our method only requires $\mathcal{O}(n^3 + m^3)$ time and $\mathcal{O}(n^2 + m^2)$ space. We show that our GP model can match the performance of a Transformer on a learning curve prediction task.
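The combination of a projected Kronecker product with an iterative solver can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the RBF kernels, the helper names (`rbf`, `latent_kronecker_matvec`, `conjugate_gradient`), the noise level, and the random data are all assumptions chosen for illustration. The key point is that the matrix-vector product scatters the observed values into the latent $n \times m$ grid, applies the two small kernel matrices, and gathers the observed entries again, so the $nm \times nm$ Kronecker product is never formed.

```python
import numpy as np

def rbf(a, b, lengthscale=1.0):
    """Squared-exponential kernel matrix between the rows of a and b (assumed kernel)."""
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2 / lengthscale**2)

def latent_kronecker_matvec(K_x, K_t, mask, v, noise):
    """Compute (P (K_x kron K_t) P^T + noise * I) v using only K_x and K_t.

    mask is an (n, m) boolean array marking observed entries; v holds one
    value per observed entry in row-major order.
    """
    V = np.zeros(mask.shape)
    V[mask] = v                 # P^T v: scatter into the latent n x m grid
    W = K_x @ V @ K_t.T         # Kronecker product applied to row-major vec(V)
    return W[mask] + noise * v  # P (...): gather the observed entries

def conjugate_gradient(matvec, b, tol=1e-10, max_iter=1000):
    """Plain conjugate gradients for a symmetric positive definite operator."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(max_iter):
        Ap = matvec(p)
        step = rs / (p @ Ap)
        x += step * p
        r -= step * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

rng = np.random.default_rng(0)
n, m, noise = 6, 8, 0.1                 # toy sizes, chosen for illustration
X = rng.normal(size=(n, 3))             # hypothetical hyper-parameter configurations
t = np.linspace(0.0, 1.0, m)[:, None]   # training progression
K_x, K_t = rbf(X, X), rbf(t, t, lengthscale=0.3)

# Early stopping: curve i is only observed for its first stop[i] steps.
stop = rng.integers(2, m + 1, size=n)
mask = np.arange(m)[None, :] < stop[:, None]
y = rng.normal(size=mask.sum())

weights = conjugate_gradient(
    lambda v: latent_kronecker_matvec(K_x, K_t, mask, v, noise), y
)

# Dense reference, feasible only at toy scale: select observed rows and columns.
idx = mask.ravel()
K_dense = np.kron(K_x, K_t)[np.ix_(idx, idx)] + noise * np.eye(idx.sum())
reference = np.linalg.solve(K_dense, y)
```

At this toy scale the iterative solution can be compared against the dense solve; for larger problems only the structured product would be practical.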

Paper Structure

This paper contains 17 sections, 5 equations, 4 figures.

Figures (4)

  • Figure 1: Learning curve predictions on the Fashion-MNIST data from the LCBench benchmark [ZimLin2021a]. The GP is fit to 16 partially observed learning curves (black). Their ground truth continuations (orange) are contained within the spread of posterior samples (blue). A typical learning curve which is observed close to convergence is predicted with confidence (left). Observing a smaller fraction of the learning curve leads to increased uncertainty in the prediction (left middle). The model also adapts well to less common noisy and spiky learning curves (right).
  • Figure 2: Selecting the joint covariance matrix using projections of the latent Kronecker product after observing $\{ ({\mathbf x}_1, t_1), ({\mathbf x}_1, t_2), ({\mathbf x}_2, t_1), ({\mathbf x}_2, t_2), ({\mathbf x}_2, t_3) \}$, two learning curve values from a first hyper-parameter configuration (blue) and three values from a second configuration (orange).
  • Figure 3: Time and memory consumption as a function of training data size, where size refers to $n = m$. Training consists of optimizing noise $\sigma^2$ and kernel parameters ${\boldsymbol \theta}$. Prediction consists of sampling full learning curves for 512 hyper-parameter configurations. Measurements include constant overheads, such as computations performed by the optimizer or memory reserved by CUDA drivers.
  • Figure 4: Mean-square-errors (MSE) and log-likelihoods (LLH) of predicted final validation accuracy given partially observed learning curves (mean $\pm$ standard error over 100 random seeds), where # of training examples refers to the total number of observed values across hyper-parameters and progression. Log-likelihood values for DPL are omitted because they are not competitive enough.
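The observation pattern described in the Figure 2 caption can be written out explicitly as a small sketch. The kernel matrix values below are made up for illustration, and the names `P`, `K_latent`, `K_joint`, and `K_direct` are assumptions rather than the paper's notation. With two configurations and three progression steps, the latent Kronecker product is $6 \times 6$; the projection drops the unobserved pair $({\mathbf x}_1, t_3)$, leaving a $5 \times 5$ joint covariance matrix that agrees with evaluating the product kernel directly on the observed pairs.

```python
import numpy as np

# Illustrative kernel matrices (values are assumptions, not from the paper).
K_x = np.array([[1.0, 0.5],
                [0.5, 1.0]])            # 2 hyper-parameter configurations
K_t = np.array([[1.0, 0.8, 0.6],
                [0.8, 1.0, 0.8],
                [0.6, 0.8, 1.0]])       # 3 progression steps

# Observed (configuration, step) pairs, zero-indexed:
# (x1,t1), (x1,t2), (x2,t1), (x2,t2), (x2,t3).
observed = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]

n, m = K_x.shape[0], K_t.shape[0]
P = np.zeros((len(observed), n * m))
for row, (i, j) in enumerate(observed):
    P[row, i * m + j] = 1.0             # select entry (i, j) of the latent grid

K_latent = np.kron(K_x, K_t)            # 6 x 6 latent Kronecker product
K_joint = P @ K_latent @ P.T            # 5 x 5 covariance of observed values

# The same matrix from the product kernel evaluated on observed pairs only.
K_direct = np.array([[K_x[i, k] * K_t[j, l] for (k, l) in observed]
                     for (i, j) in observed])
```

Forming `P` and `K_latent` densely is only sensible at this size; it serves to make the selection explicit.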