Estimating Probability Densities with Transformer and Denoising Diffusion
Henry W. Leung, Jo Bovy, Joshua S. Speagle
TL;DR
Regression models in science usually return a single scalar prediction. This work addresses that limitation by attaching a denoising diffusion probabilistic model head to an encoder-only Transformer, so that the model estimates conditional probability densities. It can generate samples and densities conditioned on arbitrary input combinations, which allows non-Gaussian and multimodal outputs. Demonstrations on Galactic stellar data and the California Housing dataset show that the method recovers training densities, produces sensible conditional densities, and can construct multi-dimensional distributions through sequential conditioning, offering a flexible and scalable density emulator for scientific foundation models. The approach improves uncertainty quantification and makes large-scale foundation models more applicable to complex, high-dimensional scientific inference tasks.
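The sequential conditioning mentioned above can be read as an application of the chain rule of probability. As a sketch, with $y_1$ and $y_2$ standing for two hypothetical output labels and $x$ for the conditioning inputs (this notation is an assumption, not taken from the paper):

```latex
p(y_1, y_2 \mid x) = p(y_1 \mid x)\, p(y_2 \mid y_1, x)
```

Under this reading, a joint sample would be drawn by sampling $y_1$ from the first factor and then supplying it as an additional input when sampling $y_2$.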
Abstract
Transformers are often the go-to architecture for building foundation models that ingest large amounts of training data. However, when trained on regression problems, these models do not estimate a probability density, even though full probabilistic outputs are crucial in many fields of science, where the distribution of the answer can be non-Gaussian and multimodal. In this work, we demonstrate that training a probabilistic model using a denoising diffusion head on top of the Transformer provides reasonable probability density estimation even for high-dimensional inputs. The combined Transformer+Denoising Diffusion model allows the output probability density to be conditioned on arbitrary combinations of inputs, and it is thus a highly flexible density function emulator of all possible input/output combinations. We illustrate our Transformer+Denoising Diffusion model by training it on a large dataset of astronomical observations and measured labels of stars within our Galaxy, and we apply it to a variety of inference tasks to show that the model can infer labels accurately with reasonable distributions.
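To make the role of a denoising diffusion head more concrete, the following is a minimal NumPy sketch of the standard DDPM forward-noising step and noise-prediction loss. It is not the authors' implementation: the linear noise schedule, the step count, and all function names (`make_schedule`, `forward_noise`, `noise_prediction_loss`) are assumptions chosen for illustration, and the Transformer that would supply the conditioning context is replaced by a placeholder.

```python
import numpy as np

def make_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear variance schedule (an assumed, commonly used default)."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    alpha_bars = np.cumprod(1.0 - betas)  # cumulative signal retention
    return betas, alpha_bars

def forward_noise(y0, t, alpha_bars, rng):
    """Draw y_t ~ q(y_t | y_0) = N(sqrt(abar_t) y_0, (1 - abar_t) I)."""
    eps = rng.standard_normal(y0.shape)
    a = alpha_bars[t][:, None]  # shape (batch, 1) for broadcasting
    y_t = np.sqrt(a) * y0 + np.sqrt(1.0 - a) * eps
    return y_t, eps

def noise_prediction_loss(eps_pred, eps):
    """Mean squared error between predicted and true noise."""
    return float(np.mean((eps_pred - eps) ** 2))

rng = np.random.default_rng(0)
betas, alpha_bars = make_schedule()

# A batch of 8 standardized scalar labels (synthetic, for illustration only).
y0 = rng.standard_normal((8, 1))
t = rng.integers(0, len(betas), size=8)
y_t, eps = forward_noise(y0, t, alpha_bars, rng)

# Placeholder: a real head would predict eps from (y_t, t, context),
# where context is the Transformer's embedding of the conditioning inputs.
eps_pred = np.zeros_like(eps)
loss = noise_prediction_loss(eps_pred, eps)
```

In a full model, the head would be trained to minimize this loss, and sampling would run the learned denoising steps in reverse from pure noise, with the conditioning context held fixed.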
