Transformers can do Bayesian Clustering
Prajit Bhaskaran, Tom Viering
TL;DR
Cluster-PFN is a Transformer-based extension of Prior-Data Fitted Networks (PFNs) for unsupervised Bayesian clustering. Trained entirely on synthetic data drawn from a finite Gaussian Mixture Model (GMM) prior, it learns to estimate the posterior over the number of clusters $k$ and the per-point cluster responsibilities. This enables fast, uncertainty-aware clustering that scales to larger datasets. It predicts the number of clusters more accurately than handcrafted criteria (AIC, BIC, Silhouette) and Variational Inference (VI), and it achieves clustering quality competitive with VI with significantly faster inference. When trained on priors that include missing values, it is robust to missing data, showing practical potential for scalable Bayesian clustering in real-world settings such as genomic datasets.
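The TL;DR says the model is trained entirely on synthetic datasets drawn from a finite GMM prior. Below is a minimal sketch, in Python with NumPy, of what drawing one such dataset could look like. The function name `sample_gmm_dataset` and every hyperparameter (a uniform prior on $k$ up to `k_max`, Dirichlet(1) mixture weights, $\mathcal{N}(0, 5^2 I)$ cluster means, isotropic standard deviations in $[0.5, 1.5]$) are illustrative assumptions, not the prior actually used by Cluster-PFN.

```python
import numpy as np


def sample_gmm_dataset(n_points, n_features, k_max=5, rng=None):
    """Draw one synthetic dataset from a hypothetical finite-GMM prior.

    All prior choices below are illustrative assumptions.
    Returns the data X, the number of clusters k and the true assignments z.
    """
    rng = np.random.default_rng(rng)
    # Number of clusters: uniform over 1..k_max (assumed prior on k).
    k = int(rng.integers(1, k_max + 1))
    # Mixture weights from a symmetric Dirichlet(1, ..., 1).
    weights = rng.dirichlet(np.ones(k))
    # Cluster means from N(0, 5^2 I); one isotropic std per cluster from U(0.5, 1.5).
    means = rng.normal(0.0, 5.0, size=(k, n_features))
    stds = rng.uniform(0.5, 1.5, size=k)
    # Assign each point to a cluster, then sample it from that cluster's Gaussian.
    z = rng.choice(k, size=n_points, p=weights)
    X = means[z] + stds[z, None] * rng.standard_normal((n_points, n_features))
    return X, k, z


X, k, z = sample_gmm_dataset(200, 2, rng=0)
print(X.shape, k, np.bincount(z, minlength=k))
```

In PFN-style training, many such datasets would be drawn and the network trained to predict $k$ and the assignments `z` from `X` alone. Because cluster labels are exchangeable, a loss on the assignments would need to handle label permutations, for example through a permutation-invariant loss or a canonical ordering of the labels.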
Abstract
Bayesian clustering accounts for uncertainty but is computationally demanding at scale. Furthermore, real-world datasets often contain missing values, and simple imputation ignores the associated uncertainty, leading to suboptimal clustering. We present Cluster-PFN, a Transformer-based model that extends Prior-Data Fitted Networks (PFNs) to unsupervised Bayesian clustering. Trained entirely on synthetic datasets generated from a finite Gaussian Mixture Model (GMM) prior, Cluster-PFN learns to estimate the posterior distribution over both the number of clusters and the cluster assignments. Our method estimates the number of clusters more accurately than handcrafted model selection procedures such as AIC and BIC, as well as Variational Inference (VI), and achieves clustering quality competitive with VI while being orders of magnitude faster. When trained on complex priors that include missing data, Cluster-PFN outperforms imputation-based baselines on real-world genomic datasets at high levels of missingness. These results show that Cluster-PFN can provide scalable and flexible Bayesian clustering.
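The abstract contrasts simple imputation with approaches that keep the uncertainty of missing values. As a point of reference, the sketch below computes exact cluster responsibilities for a GMM whose parameters are known, marginalising missing features (encoded as NaN) instead of imputing them. It assumes isotropic components and values that are missing at random; the function name `responsibilities_with_missing` is hypothetical, and this is a classical computation for comparison, not Cluster-PFN, which learns to approximate such posteriors from a prior that includes missingness.

```python
import numpy as np


def responsibilities_with_missing(X, weights, means, stds):
    """Posterior p(z_i = c | observed part of x_i) for an isotropic GMM.

    Missing entries are NaN and are marginalised out; for a Gaussian with
    diagonal covariance this amounts to skipping those coordinates.
    """
    observed = ~np.isnan(X)                      # (n, d) mask of observed features
    X0 = np.where(observed, X, 0.0)              # placeholder for missing values, masked below
    # Squared distance to each mean over observed coordinates only: (n, k).
    diff = X0[:, None, :] - means[None, :, :]
    sq = np.where(observed[:, None, :], diff ** 2, 0.0).sum(axis=2)
    n_obs = observed.sum(axis=1, keepdims=True)  # (n, 1) observed features per point
    var = stds[None, :] ** 2                     # (1, k) per-component variance
    # Log-density of the observed coordinates under each component.
    log_lik = -0.5 * sq / var - 0.5 * n_obs * np.log(2 * np.pi * var)
    log_post = np.log(weights)[None, :] + log_lik
    # Normalise in a numerically stable way (log-sum-exp trick).
    log_post -= log_post.max(axis=1, keepdims=True)
    post = np.exp(log_post)
    return post / post.sum(axis=1, keepdims=True)


weights = np.array([0.5, 0.5])
means = np.array([[0.0, 0.0], [10.0, 10.0]])
stds = np.array([1.0, 1.0])
X = np.array([[0.1, np.nan], [np.nan, 9.9], [np.nan, np.nan]])
print(responsibilities_with_missing(X, weights, means, stds).round(3))
```

A point whose features are all missing receives the mixture weights as its responsibilities, so its assignment stays uncertain; imputing a single value would instead produce a confident assignment that the data do not support.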
