Convolutional Deep Kernel Machines
Edward Milsom, Ben Anson, Laurence Aitchison
TL;DR
This work extends the deep kernel machine (DKM) framework to convolutional settings, introducing convolutional DKMs that retain representation learning in an infinite-width limit. A novel inter-domain inducing-point scheme makes training scalable on datasets such as CIFAR; it combines a CNN-inspired kernel, constructed from a convolutional BNN layer, with a learnable patch-based inducing mechanism. Empirical results show state-of-the-art performance for kernel methods on MNIST and CIFAR (around 99% on MNIST, 92.7% on CIFAR-10, and 72% on CIFAR-100) with roughly 77 GPU hours of training, which highlights the practical potential of DKMs as scalable kernel methods that learn representations. The work also outlines how kernel methods might close the remaining gap to neural networks, notes current limitations on very large or high-resolution datasets, and suggests directions such as precision tricks and multi-GPU setups.
Abstract
Standard infinite-width limits of neural networks sacrifice the ability of intermediate layers to learn representations from data. Recent work (Yang et al. 2023, "A theory of representation learning gives a deep generalisation of kernel methods") modified the Neural Network Gaussian Process (NNGP) limit of Bayesian neural networks so that representation learning is retained. Furthermore, they found that applying this modified limit to a deep Gaussian process gives a practical learning algorithm, which they dubbed the deep kernel machine (DKM). However, they only considered the simplest possible setting: regression in small, fully connected networks with, e.g., 10 input features. Here, we introduce convolutional deep kernel machines. This required us to develop a novel inter-domain inducing point approximation, and to introduce and experimentally assess a number of techniques not previously seen in DKMs, including analogues to batch normalisation, different likelihoods, and different types of top layer. The resulting model trains in roughly 77 GPU hours, achieving around 99% test accuracy on MNIST, 72% on CIFAR-100, and 92.7% on CIFAR-10, which is state of the art (SOTA) for kernel methods.
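To make the idea of a convolution acting on a kernel (rather than on features) concrete, the following is a minimal sketch of the standard NNGP-style construction, in which each output Gram entry averages input Gram entries over matching patch offsets. It is an illustration under assumptions, not the paper's implementation: it uses a 1D spatial axis, zero padding, and no nonlinearity, and it omits the inducing-point approximation and the representation-learning objective entirely. The function name `conv_kernel_1d` and its arguments are hypothetical.

```python
import numpy as np


def conv_kernel_1d(K, patch=3):
    """Apply a 1D 'convolution' directly to a Gram tensor.

    K has shape (N, S, N, S): K[i, r, j, s] is the similarity between
    spatial location r of input i and location s of input j.
    Returns a tensor of the same shape with
        out[i, r, j, s] = (1 / patch) * sum_d K[i, r + d, j, s + d],
    where d ranges over the patch offsets and out-of-range terms are
    treated as zero (zero padding). Assumes an odd patch size.
    """
    if patch % 2 != 1:
        raise ValueError("this sketch assumes an odd patch size")
    N, S, _, _ = K.shape
    half = patch // 2
    # Zero-pad both spatial axes so shifted slices stay in bounds.
    Kp = np.pad(K, ((0, 0), (half, half), (0, 0), (half, half)))
    out = np.zeros(K.shape, dtype=float)
    for d in range(patch):
        # The same offset d is applied to both spatial axes.
        out += Kp[:, d:d + S, :, d:d + S]
    return out / patch


# Usage: build a Gram tensor from toy features, then "convolve" it.
rng = np.random.default_rng(0)
X = rng.normal(size=(2, 5, 3))  # (inputs, locations, channels)
K = np.einsum("irc,jsc->irjs", X, X) / X.shape[-1]
K_conv = conv_kernel_1d(K, patch=3)
print(K_conv.shape)
```

In this toy setting the result coincides with the Gram tensor of the zero-padded, flattened patches of `X`, which is why the operation can be carried out on kernels without ever forming features.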
