
Distribution Transformers: Fast Approximate Bayesian Inference With On-The-Fly Prior Adaptation

George Whittle, Juliusz Ziomek, Jacob Rawling, Michael A Osborne

TL;DR

Bayesian inference faces intractable posteriors and limited prior flexibility in many real-time settings. Distribution Transformers (DTs) address this by representing priors and posteriors as Gaussian Mixture Models and learning a transformer-based mapping from priors to posteriors conditioned on observations, enabling on-the-fly prior adaptation with approximate conjugacy. DTs support prior amortization across a family of priors and preserve a structure that facilitates sequential updates in filtering contexts. Empirical results across Gaussian Processes with hyperpriors, quantum-system parameter inference, and sequential sensor fusion show that DTs achieve competitive or superior log-likelihood performance while delivering substantial speedups over SVI, PFN, TabPFN, and ACE, highlighting their practicality for real-time uncertainty quantification with flexible priors.
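The point about sequential updates can be made concrete in the one case where a GMM-to-GMM update is exact. As an illustrative assumption (not taken from the paper), suppose a scalar parameter $\theta$ has a GMM prior and each observation is $y = \theta + \varepsilon$ with $\varepsilon \sim \mathcal{N}(0, s^2)$. Each component then updates like a scalar Kalman filter and is reweighted by its marginal likelihood of $y$. The posterior is again a GMM with the same number of components, so it can be fed back in as the next prior. The minimal Python sketch below implements this closed-form filter. The names `update_gmm` and `filter_gmm` are hypothetical, and this is not the Distribution Transformer. The DT instead learns an approximate GMM-to-GMM mapping intended to cover likelihoods that lack such a closed form.

```python
import math
from typing import List, Tuple

GMM = List[Tuple[float, float, float]]  # (weight, mean, variance) per component


def normal_pdf(x: float, mean: float, var: float) -> float:
    """Density of N(mean, var) at x."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def update_gmm(prior: GMM, y: float, noise_var: float) -> GMM:
    """Exact posterior of a GMM prior on theta after observing y = theta + N(0, noise_var)."""
    updated = []
    for w, mu, var in prior:
        post_var = 1.0 / (1.0 / var + 1.0 / noise_var)     # precisions add
        post_mean = post_var * (mu / var + y / noise_var)  # precision-weighted mean
        evidence = normal_pdf(y, mu, var + noise_var)      # p(y | this component)
        updated.append((w * evidence, post_mean, post_var))
    total = sum(w for w, _, _ in updated)
    return [(w / total, m, v) for w, m, v in updated]


def filter_gmm(prior: GMM, ys: List[float], noise_var: float) -> GMM:
    """Sequential filtering: each posterior is reused as the prior for the next observation."""
    belief = prior
    for y in ys:
        belief = update_gmm(belief, y, noise_var)
    return belief


if __name__ == "__main__":
    prior = [(0.5, -2.0, 1.0), (0.5, 3.0, 0.5)]
    posterior = filter_gmm(prior, [2.5, 2.8, 3.1], noise_var=0.25)
    for w, m, v in posterior:
        print(f"weight={w:.4f} mean={m:.4f} var={v:.4f}")
```

Because the posterior has the same form as the prior, the loop in `filter_gmm` never has to change representation. This is the structural property that the DT preserves approximately ("approximate conjugacy").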

Abstract

While Bayesian inference provides a principled framework for reasoning under uncertainty, its widespread adoption is limited by the intractability of exact posterior computation, which necessitates approximate inference. However, existing methods are often computationally expensive or demand costly retraining when priors change, limiting their utility, particularly in sequential inference problems such as real-time sensor fusion. To address these challenges, we introduce the Distribution Transformer, a novel architecture that can learn arbitrary distribution-to-distribution mappings. Our method can be trained to map a prior to the corresponding posterior, conditioned on a dataset, thus performing approximate Bayesian inference. Our architecture represents a prior distribution as a (universally approximating) Gaussian Mixture Model (GMM) and transforms it into a GMM representation of the posterior. The components of the GMM attend to each other via self-attention and to the datapoints via cross-attention. We demonstrate that Distribution Transformers both maintain the flexibility to vary the prior and significantly reduce computation times, from minutes to milliseconds, while achieving log-likelihood performance on par with or superior to existing approximate inference methods across tasks such as sequential inference, quantum system parameter inference, and Gaussian Process predictive posterior inference with hyperpriors.
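The attention pattern described above can be sketched at the level of shapes and data flow. The minimal Python sketch below makes several assumptions. Each 1-D prior component is given as a row (log-weight, mean, log-std), and each datapoint is given as a row of features. Both are embedded into a shared latent space. The sketch then applies one self-attention step among the component tokens and one cross-attention step from the component tokens to the observation tokens. Finally, it unembeds each token back into posterior component parameters. The weights are random rather than learned. The sketch also omits the learned query/key/value projections, feed-forward layers, normalization, and stacked layers of a real transformer decoder. `dt_block` and `to_gmm` are hypothetical names, not the authors' code.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python product of an (n x k) matrix and a (k x m) matrix."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def softmax(row: List[float]) -> List[float]:
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Single-head scaled dot-product attention without learned projections."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    scores = matmul(queries, transpose(keys))  # (n_queries x n_keys)
    weights = [softmax([s * scale for s in row]) for row in scores]
    return matmul(weights, values)  # (n_queries x d_model)


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


def dt_block(prior: Matrix, observations: Matrix, d_model: int = 8, seed: int = 0) -> Matrix:
    """Map prior GMM component rows to posterior component rows (untrained sketch)."""
    rng = random.Random(seed)
    embed_prior = random_matrix(len(prior[0]), d_model, rng)       # stands in for a learned embedding
    embed_obs = random_matrix(len(observations[0]), d_model, rng)  # embedding for the single data source
    unembed = random_matrix(d_model, 3, rng)                       # token-wise unembedding

    tokens = matmul(prior, embed_prior)           # one latent token per prior component
    obs_tokens = matmul(observations, embed_obs)  # one latent token per datapoint
    tokens = add(tokens, attention(tokens, tokens, tokens))          # components attend to each other
    tokens = add(tokens, attention(tokens, obs_tokens, obs_tokens))  # components attend to the data
    return matmul(tokens, unembed)  # (log-weight, mean, log-std) per posterior component


def to_gmm(rows: Matrix) -> List[Tuple[float, float, float]]:
    """Normalize log-weights with a softmax and exponentiate log-stds."""
    weights = softmax([r[0] for r in rows])
    return [(w, r[1], math.exp(r[2])) for w, r in zip(weights, rows)]


if __name__ == "__main__":
    prior = [[math.log(0.5), -1.0, 0.0], [math.log(0.3), 0.5, -0.5], [math.log(0.2), 2.0, 0.2]]
    observations = [[0.2, 1.1], [0.7, 0.9], [1.5, -0.3]]
    for w, m, s in to_gmm(dt_block(prior, observations)):
        print(f"weight={w:.3f} mean={m:.3f} std={s:.3f}")
```

The sketch has no positional encodings. As a result, its output is permutation-equivariant in the prior components and permutation-invariant in the observations, which suits unordered sets of mixture components and datapoints.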

Paper Structure

This paper contains 27 sections, 1 theorem, 5 equations, 7 figures, 4 tables, and 1 algorithm.

Key Result

Proposition 3.1

The proposed loss $l_\theta$ is equal, up to an additive constant, to the expected KL divergence $\mathbb{E}_{p(\phi,z)}\left[\mathrm{KL}\left[p\mathop{||}q_{\theta}\right]\right]$ between $p(\cdot \mathop{|} z,\phi)$ and $q_{\theta}(\cdot \mathop{|} z,\phi)$. Proof in the Appendix.
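Proposition 3.1 can be checked numerically in a setting where all quantities have closed forms. This page does not define $l_\theta$. The sketch below assumes it is the expected negative log-likelihood $-\mathbb{E}_{p(\phi,z)}\mathbb{E}_{x \sim p(\cdot \mid z,\phi)}[\log q_\theta(x \mid z,\phi)]$ of samples $x$ drawn from $p$. Under that assumption, $\mathrm{KL}[p \mathop{||} q_\theta] = \mathbb{E}_{p}[\log p] - \mathbb{E}_{p}[\log q_\theta]$. The loss therefore equals the expected KL divergence plus the expected entropy of $p(\cdot \mid z,\phi)$, a term that does not depend on $\theta$. The Python sketch fixes a single $(z,\phi)$ and uses a 1-D Gaussian as a stand-in for $p$. It then compares a Monte Carlo estimate of the loss with the closed-form KL for several Gaussian candidates $q$. All function names are hypothetical.

```python
import math
import random


def gaussian_logpdf(x: float, mean: float, std: float) -> float:
    """log N(x; mean, std^2)."""
    return -0.5 * math.log(2.0 * math.pi * std ** 2) - 0.5 * ((x - mean) / std) ** 2


def kl_gaussian(mp: float, sp: float, mq: float, sq: float) -> float:
    """Closed-form KL[N(mp, sp^2) || N(mq, sq^2)]."""
    return math.log(sq / sp) + (sp ** 2 + (mp - mq) ** 2) / (2.0 * sq ** 2) - 0.5


def entropy_gaussian(sp: float) -> float:
    """Differential entropy of a Gaussian with std sp; it does not depend on q."""
    return 0.5 * math.log(2.0 * math.pi * math.e * sp ** 2)


def nll_loss(mp: float, sp: float, mq: float, sq: float, n: int = 100_000, seed: int = 0) -> float:
    """Monte Carlo estimate of E_{x ~ p}[-log q(x)], the assumed NLL training loss."""
    rng = random.Random(seed)
    return -sum(gaussian_logpdf(rng.gauss(mp, sp), mq, sq) for _ in range(n)) / n


if __name__ == "__main__":
    mp, sp = 0.0, 1.0  # stand-in for p(. | z, phi) at one fixed (z, phi)
    for mq, sq in [(0.0, 1.0), (1.0, 2.0), (-0.5, 1.5)]:  # candidate q_theta
        loss = nll_loss(mp, sp, mq, sq)
        gap = loss - kl_gaussian(mp, sp, mq, sq)
        print(f"q=N({mq}, {sq}^2): loss={loss:.4f}  loss - KL={gap:.4f}")
    print(f"entropy of p = {entropy_gaussian(sp):.4f}")
```

Each `loss - KL` line should agree, up to Monte Carlo error, with the entropy of $p$. That entropy is the $\theta$-independent constant, so minimizing the loss over $q$ is equivalent to minimizing the KL divergence.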

Figures (7)

  • Figure 1: Various log-warped GMM approximations to an inverse-gamma prior distribution. Even with only five GMM components, the approximation is visually almost indistinguishable from the target distribution. The same holds for many distributions frequently encountered in Bayesian inference.
  • Figure 2: Architecture diagram for a distribution transformer. Observations, e.g. from a dataset or a sensor measurement, are transformed into a set of tokens in the latent space via a distinct learnable embedding for each data source. Priors are represented as a set of embedded GMM components in the latent space via a learnable embedding acting on their parameters. The distribution transformer itself, a transformer decoder, learns to map the prior to the posterior in the latent space, incorporating information from the embedded observations via cross-attention. A learnable unembedding acts token-wise on both the prior and posterior latent GMM representations to give GMM approximations of both the prior and posterior distributions, with which we estimate our loss function $\ell_\theta'$.
  • Figure 3: Ground-truth, PFN, and 2- and 5-component DT posterior densities for an inverse-gamma prior with (a) a narrow and (b) a wide meta-prior. Both DT variants fit the true posterior well, and in both cases the 5-component DT is almost indistinguishable from the ground truth. The PFN's shape is correct in both cases. For the narrow meta-prior, it fits the ground truth well, up to the limited resolution of its Riemann distribution. As expected, however, it completely fails to fit the ground truth for the wide meta-prior, since it is not given the prior. In every case, for a given number of model outputs, the DT provides a much tighter fit to the ground-truth distribution.
  • Figure 4: Example plot for the GP predictive experiment with hyperpriors and 1-dimensional inputs, using 10 GMM components. Here we place an InverseGamma(1,2) prior on the lengthscale. We plot our model's predictive posterior in blue and that of PFNs in green. In orange, we show the oracle, i.e. the exact GP fit with the true lengthscale value (which is unobserved by the other methods). PFNs produce overly wide confidence intervals because they do not take the prior, in particular the lengthscale prior, into account. The Riemann distribution of the PFN uses 30 buckets, matching the number of outputs of the DT.
  • Figure 5: Expected negative log-likelihood against batch inference time. Even given orders of magnitude more computation time, SVI cannot match the performance of our method, again demonstrating the power of our GMM approximation. This problem is particularly challenging for VI because the likelihood must be marginalized over the uncertainty in the initial state and the measurement time. This marginalization is intractable and must be estimated stochastically, which increases the time per iteration.
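The idea behind Figure 1 can be illustrated with a short sketch. It assumes that "log-warped" means the GMM is placed on $\log x$, so that the density on $x > 0$ is $p(x) = \frac{1}{x}\sum_k w_k \,\mathcal{N}(\log x;\, \mu_k, \sigma_k^2)$, i.e. a mixture of log-normals. The target is an inverse-gamma distribution with shape $\alpha = 1$ and scale $\beta = 2$, which is one reading of the InverseGamma(1,2) lengthscale prior in Figure 4. The components are fitted with expectation-maximization (EM) on the log of samples from that distribution. EM is an illustrative choice, since this page does not say how the GMM approximations are obtained. All function names are hypothetical.

```python
import math
import random
from typing import List, Tuple

Component = Tuple[float, float, float]  # (weight, mean, std) of a Gaussian on log(x)


def normal_logpdf(y: float, mean: float, std: float) -> float:
    return -0.5 * math.log(2.0 * math.pi * std ** 2) - 0.5 * ((y - mean) / std) ** 2


def logsumexp(values: List[float]) -> float:
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def logwarped_gmm_logpdf(x: float, comps: List[Component]) -> float:
    """Log density at x > 0 of exp(Y) with Y ~ GMM; the -log(x) term is the Jacobian."""
    y = math.log(x)
    return logsumexp([math.log(w) + normal_logpdf(y, m, s) for w, m, s in comps]) - y


def inverse_gamma_logpdf(x: float, alpha: float, beta: float) -> float:
    """Log density of InverseGamma(shape=alpha, scale=beta)."""
    return alpha * math.log(beta) - math.lgamma(alpha) - (alpha + 1.0) * math.log(x) - beta / x


def fit_logwarped_gmm(xs: List[float], k: int, iters: int = 100) -> List[Component]:
    """Fit a k-component GMM to log(xs) by expectation-maximization."""
    ys = sorted(math.log(x) for x in xs)
    n = len(ys)
    mean_y = sum(ys) / n
    std_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys) / n)
    # Initialize means at evenly spaced quantiles, with equal weights and a shared std.
    comps = [(1.0 / k, ys[int((j + 0.5) * n / k)], std_y) for j in range(k)]
    for _ in range(iters):
        # E-step: responsibility of each component for each point.
        resp = []
        for y in ys:
            logs = [math.log(w) + normal_logpdf(y, m, s) for w, m, s in comps]
            norm = logsumexp(logs)
            resp.append([math.exp(v - norm) for v in logs])
        # M-step: weighted re-estimates; small floors guard against degenerate components.
        comps = []
        for j in range(k):
            nj = max(sum(r[j] for r in resp), 1e-12)
            mj = sum(r[j] * y for r, y in zip(resp, ys)) / nj
            vj = sum(r[j] * (y - mj) ** 2 for r, y in zip(resp, ys)) / nj
            comps.append((nj / n, mj, math.sqrt(max(vj, 1e-4))))
    return comps


if __name__ == "__main__":
    rng = random.Random(0)
    alpha, beta = 1.0, 2.0
    # If G ~ Gamma(shape=alpha, rate=beta), then 1/G ~ InverseGamma(alpha, beta).
    # random.gammavariate takes a scale parameter, hence 1/beta.
    xs = [1.0 / rng.gammavariate(alpha, 1.0 / beta) for _ in range(1000)]
    true_ll = sum(inverse_gamma_logpdf(x, alpha, beta) for x in xs) / len(xs)
    for k in (1, 5):
        comps = fit_logwarped_gmm(xs, k)
        gmm_ll = sum(logwarped_gmm_logpdf(x, comps) for x in xs) / len(xs)
        print(f"k={k}: mean log-lik gap (true - GMM) = {true_ll - gmm_ll:.4f}")
```

The mean log-likelihood gap between the target and the fitted approximation, measured on the samples, indicates how close the fit is. Comparing $k=1$ with $k=5$ shows the gap shrinking as components are added.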

Theorems & Definitions (3)

  • Proposition 3.1
  • proof
  • proof