Amortized Probabilistic Conditioning for Optimization, Simulation and Inference

Paul E. Chang, Nasrulloh Loka, Daolang Huang, Ulpu Remes, Samuel Kaski, Luigi Acerbi

TL;DR

ACE presents a unified transformer-based framework that amortizes probabilistic conditioning on both observed data and interpretable latent variables, while allowing runtime specification of priors over these latents. By extending the prediction-map TPM-D formalism to explicitly encode latents and priors, ACE delivers closed-form predictive distributions for data and latents via flexible Gaussian mixture or categorical heads. The model demonstrates strong performance across image completion, Bayesian optimization, and simulation-based inference, often matching or surpassing dedicated baselines and offering substantial runtime advantages. This approach provides a practical, generalizable tool for fast, principled conditioning and prediction in a broad range of probabilistic tasks, with potential for scalable, multi-task extensions and latent discovery.
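The closed-form predictive distributions mentioned above come from output heads that turn raw network outputs into distribution parameters. The sketch below shows one way a 1D Gaussian mixture head and a categorical head can be evaluated. It assumes the network emits, for each target, K mixture logits, K means and K log standard deviations (a parameterization chosen here for illustration, not taken from the paper). All function names are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over raw logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - lse for l in logits]


def mixture_log_density(y: float, logits: Sequence[float],
                        means: Sequence[float], log_stds: Sequence[float]) -> float:
    """log p(y) under a 1D Gaussian mixture head, via log-sum-exp over components."""
    terms = []
    for log_w, mu, log_s in zip(log_softmax(logits), means, log_stds):
        z = (y - mu) / math.exp(log_s)
        log_normal = -0.5 * z * z - log_s - 0.5 * math.log(2 * math.pi)
        terms.append(log_w + log_normal)
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))


def mixture_mean(logits: Sequence[float], means: Sequence[float]) -> float:
    """Closed-form predictive mean of the mixture: sum_k w_k * mu_k."""
    return sum(math.exp(lw) * mu for lw, mu in zip(log_softmax(logits), means))


def categorical_probs(logits: Sequence[float]) -> List[float]:
    """Categorical head for discrete targets (e.g. image classes)."""
    return [math.exp(lw) for lw in log_softmax(logits)]
```

Because every quantity is available in closed form, the negative of `mixture_log_density` summed over targets can serve directly as a training loss, which is the usual neural-process objective. Whether ACE uses exactly this parameterization is an assumption here.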

Abstract

Amortized meta-learning methods based on pre-training have propelled fields like natural language processing and vision. Transformer-based neural processes and their variants are leading models for probabilistic meta-learning with a tractable objective. Often trained on synthetic data, these models implicitly capture essential latent information in the data-generation process. However, existing methods do not allow users to flexibly inject (condition on) and extract (predict) this probabilistic latent information at runtime, which is key to many tasks. We introduce the Amortized Conditioning Engine (ACE), a new transformer-based meta-learning model that explicitly represents latent variables of interest. ACE affords conditioning on both observed data and interpretable latent variables, allows the inclusion of priors at runtime, and outputs predictive distributions for discrete and continuous data and latents. We show ACE's modeling flexibility and performance in diverse tasks such as image completion and classification, Bayesian optimization, and simulation-based inference.
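Conditioning on data, on known latents, and on runtime priors means a single input set has to carry three kinds of information. The sketch below is one hypothetical way to assemble such a set. It assumes each element becomes a "token", that latents are identified by an integer index (e.g. 0 for $x_\text{opt}$, 1 for $y_\text{opt}$), and that a prior over a latent is passed as a normalized histogram on a bounded grid. These are illustrative assumptions, not the paper's exact encoding, and every name is made up for this example.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence, Tuple


@dataclass
class Token:
    """One element of a hypothetical conditioning or query set."""
    kind: str                                  # "data" or "latent"
    x: Optional[Tuple[float, ...]] = None      # input location (data tokens only)
    value: Optional[float] = None              # observed y or latent value; None = to predict
    latent_index: Optional[int] = None         # which latent, e.g. 0 -> x_opt, 1 -> y_opt
    prior_bins: Optional[List[float]] = None   # runtime prior as a histogram (latents only)


def histogram_prior(pdf: Callable[[float], float], low: float, high: float,
                    n_bins: int) -> List[float]:
    """Discretize a 1D prior density on [low, high] into normalized bin masses."""
    width = (high - low) / n_bins
    mass = [pdf(low + (i + 0.5) * width) * width for i in range(n_bins)]  # midpoint rule
    total = sum(mass)
    return [m / total for m in mass]


def build_context(data: Sequence[Tuple[Tuple[float, ...], float]],
                  known_latents: Dict[int, float],
                  latent_priors: Dict[int, List[float]]) -> List[Token]:
    """Mix observed data, observed latent values, and latent priors into one list."""
    tokens = [Token("data", x=tuple(x), value=y) for x, y in data]
    tokens += [Token("latent", value=v, latent_index=i) for i, v in known_latents.items()]
    tokens += [Token("latent", latent_index=i, prior_bins=p) for i, p in latent_priors.items()]
    return tokens


def build_queries(xs: Sequence[Tuple[float, ...]], latent_ids: Sequence[int]) -> List[Token]:
    """Targets to predict: function values at xs and the listed latents."""
    return ([Token("data", x=tuple(x)) for x in xs]
            + [Token("latent", latent_index=i) for i in latent_ids])


# Bayesian-optimization-style usage: two observations, flat prior over x_opt on [0, 1],
# and queries for the function value at x = 0.5 plus both x_opt and y_opt.
flat = histogram_prior(lambda t: 1.0, 0.0, 1.0, n_bins=4)
context = build_context([((0.2,), 1.3), ((0.7,), 0.4)], {}, {0: flat})
queries = build_queries([(0.5,)], latent_ids=[0, 1])
```

Keeping the prior as just another context element is what lets a user change it at runtime without retraining; a model trained on many sampled priors can then amortize over them.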

Paper Structure

This paper contains 76 sections, 15 equations, 26 figures, 6 tables, 2 algorithms.

Figures (26)

  • Figure 1: Probabilistic conditioning and prediction. Many tasks reduce to probabilistic conditioning on data and key latent variables (left) and then predicting data and latents (right). (a) Image completion and classification (data: pixels; latents: classes). Top: class prediction. Bottom: conditional generation. (b) Bayesian optimization (data: function values; latents: optimum location $x_\text{opt}$ and value $y_\text{opt}$). We predict both the function values and $x_\text{opt}$, $y_\text{opt}$ given function observations and a prior over $x_\text{opt}$, $y_\text{opt}$ (here flat). (c) Simulation-based inference (data: observations; latents: model parameters $\theta$). Given data and a prior over $\theta$, we can compute both the posterior over $\theta$ and the predictive distribution over unseen data. Our method fully amortizes probabilistic conditioning and prediction.
  • Figure 2: Prior amortization. Two example posterior distributions for the mean $\mu$ and standard deviation $\sigma$ of a 1D Gaussian. (a) Prior distribution over $\bm{\theta} = (\mu, \sigma)$, set at runtime. (b) Likelihood for the observed data. (c) Ground-truth Bayesian posterior. (d) ACE's predicted posterior closely approximates the true posterior.
  • Figure 3: Image completion. A reference image serves as the ground truth for the problem, of which $10\%$ of the pixels are observed. The subsequent panels, from TNP-D through ACE with latent information, display different models' predictions conditioned on the observed pixels; the last of these additionally incorporates information about the latent variable $\bm{\theta}$ for the ACE model. A final panel shows the different models' performance (negative log predictive density, NLPD) across varying levels of context.
  • Figure 4: Bayesian optimization example. (a) ACE predicts function values ($p(y \mid x, \mathcal{D}_N)$) as well as latents: the optimum location ($p(x_\text{opt} \mid \mathcal{D}_N)$) and the optimum value ($p(y_\text{opt} \mid \mathcal{D}_N)$). (b) Further conditioning on $y_\text{opt}$ (here the true minimum value) leads to updated predictions.
  • Figure 5: Bayesian optimization results. Regret comparison (mean ± standard error) for different methods across benchmark tasks.
  • ...and 21 more figures
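Figure 2 compares ACE's posterior for a 1D Gaussian's $(\mu, \sigma)$ against a ground-truth posterior. For a two-parameter model like this, one simple way to obtain such a reference is brute-force evaluation of prior times likelihood on a grid. The sketch below is an assumption about how a reference could be computed, not the procedure used in the paper; the function names, data, and grid values are hypothetical.

```python
import math
from typing import Callable, List, Sequence, Tuple


def gaussian_log_likelihood(data: Sequence[float], mu: float, sigma: float) -> float:
    """log p(data | mu, sigma) for i.i.d. 1D Gaussian observations."""
    n = len(data)
    ss = sum((x - mu) ** 2 for x in data)
    return -n * math.log(sigma) - 0.5 * n * math.log(2 * math.pi) - ss / (2 * sigma ** 2)


def grid_posterior(data: Sequence[float], log_prior: Callable[[float, float], float],
                   mus: Sequence[float], sigmas: Sequence[float]) -> List[List[float]]:
    """Posterior mass on a (mu, sigma) grid: prior times likelihood, then normalize."""
    log_post = [[log_prior(m, s) + gaussian_log_likelihood(data, m, s) for s in sigmas]
                for m in mus]
    top = max(max(row) for row in log_post)            # subtract the max for stability
    unnorm = [[math.exp(v - top) for v in row] for row in log_post]
    z = sum(sum(row) for row in unnorm)
    return [[v / z for v in row] for row in unnorm]


def map_index(post: List[List[float]]) -> Tuple[int, int]:
    """(mu index, sigma index) of the posterior mode on the grid."""
    cells = ((i, j) for i in range(len(post)) for j in range(len(post[0])))
    return max(cells, key=lambda ij: post[ij[0]][ij[1]])


data = [0.9, 1.1, 1.0]
mus = [0.0, 0.5, 1.0, 1.5, 2.0]
sigmas = [0.25, 0.5, 1.0, 2.0]                          # must be strictly positive
flat = grid_posterior(data, lambda m, s: 0.0, mus, sigmas)
# Swapping in a sharp prior on mu near 0 pulls the posterior mode toward mu = 0.
sharp = grid_posterior(data, lambda m, s: -0.5 * (m / 0.01) ** 2, mus, sigmas)
```

Each new prior requires re-running the whole grid computation; this is the cost that ACE's runtime prior conditioning avoids by amortizing over priors during training. Grid evaluation also scales exponentially with the number of parameters, so it is only practical for low-dimensional examples like this one.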