
Bridge the Inference Gaps of Neural Processes via Expectation Maximization

Qi Wang, Marco Federici, Herke van Hoof

TL;DR

This work addresses the inference suboptimality of vanilla Neural Processes (NPs) by introducing Self-normalized Importance Weighted Neural Processes (SI-NP), trained with a principled EM-based surrogate objective for the meta-dataset log-likelihood $\mathcal{L}(\vartheta)$. By combining a variational EM framework with self-normalized importance sampling, SI-NP learns a richer functional prior $p(z\vert\mathcal{D}_{\tau}^{C};\vartheta)$ and comes with an improvement guarantee on the target likelihood. The paper also connects SI-NPs to CNPs, showing that the two are equivalent in a certain limit, and reports competitive performance on synthetic regression and image completion tasks, where attention-based inductive biases further improve results. Overall, SI-NP offers a principled optimization perspective on learning distributions over functions and a scalable path to better uncertainty modeling in neural processes.
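Self-normalized importance sampling reweights latent samples drawn from a proposal distribution so that the weights sum to one. The following is a minimal sketch of that generic estimator for a single task, assuming a hypothetical one-dimensional latent model with a Gaussian prior, Gaussian likelihood, and Gaussian proposal. In SI-NP these densities would come from learned networks, and the exact objective should be taken from the paper; the function names here are illustrative only.

```python
import numpy as np


def log_normal(x, mean, std):
    """Elementwise log density of N(mean, std**2)."""
    return -0.5 * np.log(2 * np.pi * std**2) - 0.5 * ((x - mean) / std) ** 2


def normalize_log_weights(log_w):
    """Self-normalize importance weights in log space (a stable softmax)."""
    w = np.exp(log_w - np.max(log_w))
    return w / w.sum()


def snis_surrogate(y_target, num_samples=64, seed=0):
    """Toy self-normalized importance-weighted objective for one task.

    Hypothetical 1-D latent model, for illustration only:
      prior      p(z | D^C) = N(0, 1)
      likelihood p(y_i | z) = N(z, 0.5**2), i.i.d. over target values y_i
      proposal   q(z)       = N(mean(y_target), 1)
    """
    rng = np.random.default_rng(seed)
    q_mean, q_std = float(np.mean(y_target)), 1.0
    z = rng.normal(q_mean, q_std, size=num_samples)  # z_k ~ q

    log_prior = log_normal(z, 0.0, 1.0)  # ln p(z_k | D^C)
    log_lik = log_normal(y_target[None, :], z[:, None], 0.5).sum(axis=1)  # ln p(D^T | z_k)
    log_q = log_normal(z, q_mean, q_std)  # ln q(z_k)

    # Unnormalized log weights target the posterior p(z | D^C, D^T).
    weights = normalize_log_weights(log_lik + log_prior - log_q)

    # Weighted complete-data log-likelihood. In an EM-style update the weights
    # would be held fixed while the model parameters are optimized.
    surrogate = float(np.sum(weights * (log_lik + log_prior)))
    return z, weights, surrogate
```

Because this toy model is conjugate, the exact posterior mean for `y_target = np.array([0.8, 1.1, 0.9])` is $11.2/13 \approx 0.862$, and `np.sum(weights * z)` with a large `num_samples` should land close to it. Normalizing in log space avoids overflow when the log weights are large.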

Abstract

The neural process (NP) is a family of computationally efficient models for learning distributions over functions. However, it suffers from under-fitting and shows suboptimal performance in practice. Researchers have primarily focused on incorporating diverse structural inductive biases, e.g., attention or convolution, into the model. Inference suboptimality, and the analysis of NPs from the perspective of the optimization objective, have hardly been studied in earlier work. To address this gap, we propose a surrogate objective for the target log-likelihood of the meta dataset within the expectation maximization framework. The resulting model, referred to as the Self-normalized Importance weighted Neural Process (SI-NP), can learn a more accurate functional prior and has an improvement guarantee with respect to the target log-likelihood. Experimental results show that SI-NP performs competitively against other NP objectives and that structural inductive biases, such as attention modules, can further augment our method to achieve SOTA performance. Our code is available at https://github.com/hhq123gogogo/SI_NPs.
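In the NP family, the functional prior $p(z\vert\mathcal{D}^C;\vartheta)$ is usually parameterized by an encoder that embeds each context pair and aggregates the embeddings with a permutation-invariant operation such as a mean, so the prior does not depend on the order of the context points. The sketch below illustrates that generic construction with a diagonal Gaussian over $z$. The class name, layer sizes, and random untrained weights are hypothetical choices for illustration and do not reproduce the SI-NP architecture.

```python
import numpy as np


class ContextPriorEncoder:
    """Hypothetical NP-style encoder for p(z | D^C) = N(mu, diag(sigma**2)).

    Each (x, y) context pair is embedded by a one-layer random-feature network,
    the embeddings are mean-pooled (a permutation-invariant aggregation), and
    the pooled vector is mapped to the mean and standard deviation of z.
    Weights are random and untrained; only the structure is illustrated.
    """

    def __init__(self, x_dim=1, y_dim=1, hidden=32, z_dim=8, seed=0):
        rng = np.random.default_rng(seed)
        self.w1 = rng.normal(0.0, 0.5, size=(x_dim + y_dim, hidden))
        self.b1 = np.zeros(hidden)
        self.w_mu = rng.normal(0.0, 0.5, size=(hidden, z_dim))
        self.w_sigma = rng.normal(0.0, 0.5, size=(hidden, z_dim))

    def __call__(self, x_ctx, y_ctx):
        pairs = np.concatenate([x_ctx, y_ctx], axis=-1)  # (N, x_dim + y_dim)
        h = np.tanh(pairs @ self.w1 + self.b1)  # (N, hidden), one row per context point
        r = h.mean(axis=0)  # (hidden,), independent of the row order
        mu = r @ self.w_mu
        # Squash into (0.1, 1.0) so the standard deviation stays positive and bounded.
        sigma = 0.1 + 0.9 / (1.0 + np.exp(-(r @ self.w_sigma)))
        return mu, sigma
```

Permuting the context rows leaves `mu` and `sigma` unchanged (up to floating-point rounding), which is the property the mean-pooling step is there to guarantee.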

Paper Structure

This paper contains 51 sections, 4 theorems, 41 equations, 23 figures, 10 tables, 1 algorithm.

Key Result

Proposition 1

The proposed meta-learning objective $\mathcal{L}(\vartheta;\vartheta_k)$ in Eq. (surrogate_mlmdns) is a surrogate function with respect to the log-likelihood of the meta-learning dataset.
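Proposition 1 can be read through the standard minorize-maximize (EM) argument. The derivation below is a sketch under two assumptions: that Definition E.1 uses the usual notion of a surrogate (a lower bound on $\mathcal{L}(\vartheta)$ that is tight at $\vartheta=\vartheta_k$), and that the posterior expectation is computed exactly, whereas SI-NP approximates it with self-normalized importance sampling. Writing $\mathcal{D}_\tau=\mathcal{D}_\tau^C\cup\mathcal{D}_\tau^T$, Jensen's inequality gives

$$\mathcal{L}(\vartheta)=\sum_{\tau}\ln p(\mathcal{D}_\tau^T\vert\mathcal{D}_\tau^C;\vartheta)\ \ge\ \sum_{\tau}\mathbb{E}_{p(z\vert\mathcal{D}_\tau;\vartheta_k)}\Big[\ln p(\mathcal{D}_\tau^T\vert z;\vartheta)+\ln p(z\vert\mathcal{D}_\tau^C;\vartheta)-\ln p(z\vert\mathcal{D}_\tau;\vartheta_k)\Big]=:\mathcal{L}(\vartheta;\vartheta_k).$$

The gap between the two sides is $\sum_\tau D_{\mathrm{KL}}\big(p(z\vert\mathcal{D}_\tau;\vartheta_k)\,\Vert\,p(z\vert\mathcal{D}_\tau;\vartheta)\big)$, which vanishes at $\vartheta=\vartheta_k$. Choosing $\vartheta_{k+1}=\arg\max_\vartheta\mathcal{L}(\vartheta;\vartheta_k)$ then yields the improvement guarantee

$$\mathcal{L}(\vartheta_{k+1})\ \ge\ \mathcal{L}(\vartheta_{k+1};\vartheta_k)\ \ge\ \mathcal{L}(\vartheta_k;\vartheta_k)\ =\ \mathcal{L}(\vartheta_k).$$

The last term inside the expectation does not depend on $\vartheta$, so it can be dropped from the M-step; the paper's Eq. (surrogate_mlmdns) may differ from this idealized form in such details.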

Figures (23)

  • Figure 1: Deep Latent Variable Models for Neural Processes. Here $\mathcal{D}^{C}$ and $\mathcal{D}^{T}$ denote, respectively, the context points used for functional prior inference and the target points used for function prediction. The global latent variable $z$ summarizes function properties. The model involves a functional prior distribution $p(z\vert\mathcal{D}^C;\vartheta)$ and a functional generative distribution $p(\mathcal{D}^T\vert z;\vartheta)$. Please refer to the Preliminaries section for detailed notation.
  • Figure 2: Illustration of Expectation Maximization for NPs. Green lines indicate the results after the $\texttt{E}$-steps, while red lines show the results after the $\texttt{M}$-steps in Algorithm 1. At convergence, the performance gap $\mathcal{L}(\vartheta_H)-\mathcal{L}(\vartheta_{H-1})$ is close to zero and the algorithm reaches at least a locally optimal solution. These quantities increase from left to right.
  • Figure 3: Examples of Curve Fitting in RBF Kernel Cases. The plots report predictive mean functions with $\pm 3$ standard deviations.
  • Figure 4: SI-NP Completed Images. Rows, from top to bottom: original images, context points, and the learned predictive means and variances of sampled images.
  • Figure 5: Asymptotic Performance in Image Completion. We report meta-test pixel-averaged log-likelihoods for varying numbers of context points in the image datasets. Context points are randomly selected for each image at test time. For MNIST/FMNIST, the numbers of test-time context pixels are $\{10, 100, 300, 500, 700\}$; for CIFAR10/SVHN, they are $\{10, 100, 300, 500, 800, 1000\}$.
  • ...and 18 more figures
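Figure 2's picture (alternating E- and M-steps, a non-decreasing objective, and a gap that shrinks to zero at convergence) can be reproduced on a toy problem. The sketch below runs textbook EM on a one-dimensional two-component Gaussian mixture with unit variances. It is a stand-in for illustration only, not the SI-NP algorithm, whose E-step approximates the posterior over the functional latent $z$ with self-normalized importance sampling rather than computing it exactly; all function names are hypothetical.

```python
import numpy as np


def log_gauss(x, mean):
    """Log density of N(mean, 1), elementwise."""
    return -0.5 * np.log(2 * np.pi) - 0.5 * (x - mean) ** 2


def component_log_joint(x, pi, means):
    """ln pi_k + ln N(x_n; mean_k, 1), shape (num_points, num_components)."""
    return np.log(pi)[None, :] + log_gauss(x[:, None], means[None, :])


def log_likelihood(x, pi, means):
    """Marginal log-likelihood sum_n ln sum_k pi_k N(x_n; mean_k, 1)."""
    comp = component_log_joint(x, pi, means)
    m = comp.max(axis=1, keepdims=True)
    return float(np.sum(m[:, 0] + np.log(np.exp(comp - m).sum(axis=1))))


def em_gmm(x, pi, means, num_iters=100):
    """Textbook EM for a unit-variance Gaussian mixture; returns the LL trace."""
    history = [log_likelihood(x, pi, means)]
    for _ in range(num_iters):
        # E-step: exact posterior over the latent component (responsibilities).
        comp = component_log_joint(x, pi, means)
        resp = np.exp(comp - comp.max(axis=1, keepdims=True))
        resp /= resp.sum(axis=1, keepdims=True)
        # M-step: maximize the expected complete-data log-likelihood.
        nk = resp.sum(axis=0)
        pi = nk / len(x)
        means = (resp * x[:, None]).sum(axis=0) / nk
        history.append(log_likelihood(x, pi, means))
    return pi, means, history
```

For example, with `x` drawn as 200 points from $\mathcal{N}(-2,1)$ and 300 from $\mathcal{N}(3,1)$ and an initialization of `pi = [0.5, 0.5]`, `means = [-0.5, 0.5]`, the entries of `history` never decrease (up to floating-point rounding) and the final difference `history[-1] - history[-2]` is essentially zero, mirroring the vanishing gap $\mathcal{L}(\vartheta_H)-\mathcal{L}(\vartheta_{H-1})$ in Figure 2.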

Theorems & Definitions (12)

  • Remark 1
  • Definition 3.1: Prior Collapse
  • Proposition 1
  • Proposition 2
  • Proposition 3
  • Definition B.1: Exchangeable Stochastic Processes
  • Definition D.1: Permutation Invariant Functions
  • Proof D.1: Remark 1
  • Definition E.1: Surrogate Function
  • Theorem 1: L'Hôpital's Rule (L'Hôpital, 1696)
  • ...and 2 more