Amortized Probabilistic Conditioning for Optimization, Simulation and Inference
Paul E. Chang, Nasrulloh Loka, Daolang Huang, Ulpu Remes, Samuel Kaski, Luigi Acerbi
TL;DR
ACE is a unified transformer-based framework that amortizes probabilistic conditioning on both observed data and interpretable latent variables, and lets users specify priors over these latents at runtime. By extending the TPM-D prediction-map formalism to encode latents and priors explicitly, ACE produces closed-form predictive distributions for data and latents through Gaussian mixture or categorical output heads. The model performs well across image completion, Bayesian optimization, and simulation-based inference, often matching or surpassing dedicated baselines while running substantially faster. The approach offers a practical, general tool for fast, principled conditioning and prediction across a broad range of probabilistic tasks, with potential extensions to scalable multi-task settings and latent discovery.
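To illustrate what a closed-form Gaussian mixture predictive head means in practice, the following is a minimal sketch, not ACE's actual implementation. It assumes the network emits, for one scalar target (a data value or a continuous latent), a vector of mixture logits, component means, and log standard deviations; the function names and this parameterization are hypothetical and chosen only for illustration.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - lse for l in logits]


def gaussian_log_pdf(y, mean, std):
    """Log density of a univariate normal N(mean, std^2) at y."""
    z = (y - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def mixture_log_density(y, logits, means, log_stds):
    """Log density of a 1-D Gaussian mixture at y.

    Assumed (hypothetical) head outputs: unnormalized mixture logits,
    component means, and component log standard deviations.
    """
    log_w = log_softmax(logits)
    terms = [
        lw + gaussian_log_pdf(y, mu, math.exp(ls))
        for lw, mu, ls in zip(log_w, means, log_stds)
    ]
    # log-sum-exp over components, again in a stable form
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))


def mixture_mean(logits, means):
    """Predictive mean: the weight-averaged component means."""
    log_w = log_softmax(logits)
    return sum(math.exp(lw) * mu for lw, mu in zip(log_w, means))
```

Because the density is available in closed form, the same quantity can serve as a log-likelihood training target and as a predictive density at test time, with no sampling needed to evaluate it. A categorical head for discrete targets would reduce to the `log_softmax` step alone.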
Abstract
Amortized meta-learning methods based on pre-training have propelled fields like natural language processing and vision. Transformer-based neural processes and their variants are leading models for probabilistic meta-learning with a tractable objective. Often trained on synthetic data, these models implicitly capture essential latent information in the data-generation process. However, existing methods do not allow users to flexibly inject (condition on) and extract (predict) this probabilistic latent information at runtime, which is key to many tasks. We introduce the Amortized Conditioning Engine (ACE), a new transformer-based meta-learning model that explicitly represents latent variables of interest. ACE supports conditioning on both observed data and interpretable latent variables, allows priors to be included at runtime, and outputs predictive distributions for discrete and continuous data and latents. We show ACE's modeling flexibility and performance in diverse tasks such as image completion and classification, Bayesian optimization, and simulation-based inference.
