Attentive Neural Processes

Hyunjik Kim, Andriy Mnih, Jonathan Schwarz, Marta Garnelo, Ali Eslami, Dan Rosenbaum, Oriol Vinyals, Yee Whye Teh

TL;DR

The paper tackles underfitting in Neural Processes by introducing Attentive Neural Processes (ANPs), which use differentiable attention to let each target selectively attend to relevant context points. By adding self-attention over context representations and cross-attention from targets to those representations, ANPs achieve more accurate context reconstructions, faster training, and greater expressiveness. The cost is higher computational complexity: prediction grows from O(n+m) to O(n(n+m)) for n context points and m targets. Experiments on 1D GP-like regression and 2D image regression (MNIST/CelebA) show better predictive distributions and sharper reconstructions, with additional benefits in resolution mapping and simple Bayesian optimization. The work positions ANPs as a scalable, flexible approach for learning conditional distributions over regression functions with both local and global structure, bridging ideas from Gaussian Processes and meta-learning with modern attention mechanisms.
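The two attention steps described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: there are no learned parameters, self-attention uses scaled dot-product scores on the raw representations, and cross-attention uses Laplace scores (negative L1 distance) on the raw inputs. In the paper, queries and keys are learned embeddings and multihead attention is also used. All function names here are hypothetical.

```python
import math

def softmax(scores):
    # Subtract the max score for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot_score(q, k):
    # Scaled dot-product similarity.
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(len(k))

def laplace_score(q, k):
    # Negative L1 distance: keys close to the query get larger weights.
    return -sum(abs(a - b) for a, b in zip(q, k))

def attend(queries, keys, values, score):
    """Return one softmax-weighted average of `values` per query."""
    dim = len(values[0])
    outputs = []
    for q in queries:
        weights = softmax([score(q, k) for k in keys])
        outputs.append([sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)])
    return outputs

def anp_deterministic_path(context_x, context_r, target_x):
    # 1) Self-attention: each context representation attends to all the others.
    r_self = attend(context_r, context_r, context_r, dot_score)
    # 2) Cross-attention: each target input queries the context inputs (keys)
    #    and reads out the self-attended representations (values).
    return attend(target_x, context_x, r_self, laplace_score)

# Toy usage: 1D inputs wrapped as 1-element vectors, 2D representations.
context_x = [[-1.0], [0.0], [1.5]]
context_r = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
target_x = [[-1.0], [1.5]]
r_star = anp_deterministic_path(context_x, context_r, target_x)
```

With these toy inputs the two targets receive different representations, each leaning toward the context point nearest to it. Averaging `context_r`, as a plain NP does, would give `[0.5, 0.5]` to both targets. Computing the self-attention step costs O(n²) and the cross-attention step O(nm), which matches the O(n(n+m)) prediction cost.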

Abstract

Neural Processes (NPs) (Garnelo et al., 2018a;b) approach regression by learning to map a context set of observed input-output pairs to a distribution over regression functions. Each function models the distribution of the output given an input, conditioned on the context. NPs have the benefit of fitting observed data efficiently with linear complexity in the number of context input-output pairs, and can learn a wide family of conditional distributions; they learn predictive distributions conditioned on context sets of arbitrary size. Nonetheless, we show that NPs suffer from a fundamental drawback of underfitting, giving inaccurate predictions at the inputs of the observed data they condition on. We address this issue by incorporating attention into NPs, allowing each input location to attend to the relevant context points for the prediction. We show that this greatly improves the accuracy of predictions, results in noticeably faster training, and expands the range of functions that can be modelled.
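The efficiency properties in the abstract follow from how an NP summarises its context: each pair is encoded independently and the encodings are averaged. A minimal sketch, assuming a hypothetical hand-written `encode_pair` in place of the learned encoder MLP, shows why this is linear in n, invariant to the order of the context, and defined for any context size. It also shows the limitation: every target is decoded from the same summary `r`, which the authors point to as a likely cause of the underfitting.

```python
def encode_pair(x, y):
    # Hypothetical stand-in for the learned NP encoder MLP h(x_i, y_i).
    return [x + y, x * y, y]

def np_aggregate(context):
    """Mean of per-pair encodings: O(n), order-invariant, works for any n >= 1."""
    reps = [encode_pair(x, y) for x, y in context]
    n = len(reps)
    return [sum(r[j] for r in reps) / n for j in range(len(reps[0]))]

def np_target_inputs(context, target_xs):
    # Every target is paired with the same summary r; the decoder sees (r, x*).
    r = np_aggregate(context)
    return [(r, x) for x in target_xs]
```

Because the decoder input differs across targets only through x*, the model cannot emphasise the context points closest to a given target. Attention replaces this single average with a target-specific weighted average.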

Paper Structure

This paper contains 13 sections, 6 equations, 19 figures.

Figures (19)

  • Figure 1: Comparison of predictions given by a fully trained NP and Attentive NP (ANP) in 1D function regression (left) / 2D image regression (right). The contexts (crosses/top half pixels) are used to predict the target outputs ($y$-values of all $x \in [-2,2]$/all pixels in image). At the context points, the ANP predictions are noticeably more accurate than the NP predictions.
  • Figure 2: Model architecture for the NP (left) and Attentive NP (right)
  • Figure 3: Qualitative and quantitative results of different attention mechanisms for 1D GP function regression with random kernel hyperparameters. Left: moving average of context reconstruction error (top) and target negative log-likelihood (NLL) given contexts (bottom), plotted against training iterations (left) and wall-clock time (right). $d$ denotes the bottleneck size, i.e., the hidden layer size of all MLPs and the dimensionality of $r$ and $z$. Right: predictive mean and variance of different attention mechanisms given the same context. Best viewed in colour.
  • Figure 4: Qualitative and quantitative results on test set for 2D CelebA function regression.
  • Figure 5: Reconstruction of full image from top half. The CelebA results use the same models (with the same parameter values) as Figure 4.
  • ...and 14 more figures
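Figure 3 tracks the target NLL given the contexts. A minimal sketch of that metric, assuming the decoder outputs an independent Gaussian mean and standard deviation for each target and that a single latent sample is used (the exact estimator behind the figure is not given in this section). `gaussian_nll` is a hypothetical helper name.

```python
import math

def gaussian_nll(y_true, mean, std):
    """Average negative log-likelihood of targets under independent Gaussians."""
    total = 0.0
    for y, mu, s in zip(y_true, mean, std):
        # -log N(y; mu, s^2) = 0.5 * log(2*pi*s^2) + (y - mu)^2 / (2*s^2)
        total += 0.5 * math.log(2 * math.pi * s * s) + (y - mu) ** 2 / (2 * s * s)
    return total / len(y_true)
```

The metric penalises both inaccurate means and miscalibrated variances: a confident prediction (small `std`) that misses the target scores much worse than a cautious one, which is why predictive variance matters alongside the mean in the right-hand panels.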