Transformers can do Bayesian Clustering

Prajit Bhaskaran, Tom Viering

TL;DR

Cluster-PFN introduces a Transformer-based extension of Prior-Data Fitted Networks for unsupervised Bayesian clustering. Trained entirely on synthetic data from a finite Gaussian Mixture Model prior, it learns to estimate the posterior over the number of clusters $k$ and the per-point cluster responsibilities, enabling fast, uncertainty-aware clustering that scales to larger datasets. It predicts the number of clusters more accurately than handcrafted criteria (AIC, BIC, Silhouette) and Variational Inference (VI), while achieving clustering quality competitive with VI at significantly faster inference. Because it can be trained on priors that include missing data, it is robust to missing values, and it shows practical potential for scalable Bayesian clustering on real-world genomic datasets and beyond.
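
The generative process behind this kind of training data can be sketched as follows. This is a minimal illustration under assumed settings: the function name `sample_gmm_dataset` and all hyperparameters (a uniform prior on $k$, Gaussian cluster means, log-uniform diagonal scales, Dirichlet mixing weights) are hypothetical and not taken from the paper's prior. Each call returns data $X$ together with the latent assignments and the true $k$, i.e. the quantities whose posterior the network learns to estimate.

```python
import numpy as np


def sample_gmm_dataset(rng, n_points=200, n_features=2, k_max=5):
    """Draw one synthetic clustering task from a finite GMM prior.

    All hyperparameters below are illustrative assumptions,
    not the paper's exact prior.
    """
    # Number of clusters: uniform over {1, ..., k_max} (assumption).
    k = int(rng.integers(1, k_max + 1))
    # Mixing weights from a symmetric Dirichlet(1) (assumption).
    weights = rng.dirichlet(np.ones(k))
    # Cluster means, spread wide relative to the cluster scales (assumption).
    means = rng.normal(0.0, 5.0, size=(k, n_features))
    # Per-cluster diagonal standard deviations, log-uniform in [0.3, 1.5] (assumption).
    scales = np.exp(rng.uniform(np.log(0.3), np.log(1.5), size=(k, n_features)))
    # Latent assignments z_i ~ Categorical(weights),
    # then x_i ~ N(mean_{z_i}, diag(scale_{z_i}^2)).
    z = rng.choice(k, size=n_points, p=weights)
    X = means[z] + scales[z] * rng.standard_normal((n_points, n_features))
    return X, z, k


rng = np.random.default_rng(0)
X, z, k = sample_gmm_dataset(rng)
print(X.shape, k)
```

Repeating such draws yields an unlimited stream of labelled synthetic tasks; because the labels come directly from the generator, no real labelled data is needed for training.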

Abstract

Bayesian clustering accounts for uncertainty but is computationally demanding at scale. Furthermore, real-world datasets often contain missing values, and simple imputation ignores the associated uncertainty, leading to suboptimal results. We present Cluster-PFN, a Transformer-based model that extends Prior-Data Fitted Networks (PFNs) to unsupervised Bayesian clustering. Trained entirely on synthetic datasets generated from a finite Gaussian Mixture Model (GMM) prior, Cluster-PFN learns to estimate the posterior distribution over both the number of clusters and the cluster assignments. Our method estimates the number of clusters more accurately than both handcrafted model selection criteria (AIC, BIC) and Variational Inference (VI), and achieves clustering quality competitive with VI while being orders of magnitude faster. Cluster-PFN can be trained on complex priors that include missing data, and it outperforms imputation-based baselines on real-world genomic datasets at high missingness rates. These results show that Cluster-PFN can provide scalable and flexible Bayesian clustering.
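
One simple way to extend such a prior to missing data is to hide entries at random after sampling each dataset. The sketch below assumes a missing-completely-at-random (MCAR) mechanism with a single global rate and uses the hypothetical helper `apply_mcar_missingness`; the paper's actual missingness prior may use a different mechanism.

```python
import numpy as np


def apply_mcar_missingness(X, missing_rate, rng):
    """Hide entries of X completely at random (MCAR), marking them as NaN.

    A single global missing rate is an illustrative assumption.
    """
    # Each entry is independently missing with probability `missing_rate`.
    mask = rng.random(X.shape) < missing_rate
    # Work on a float copy so the original array is left untouched.
    X_missing = X.astype(float).copy()
    X_missing[mask] = np.nan
    return X_missing, mask


rng = np.random.default_rng(1)
X = rng.normal(size=(500, 4))
X_missing, mask = apply_mcar_missingness(X, missing_rate=0.5, rng=rng)
print(f"fraction missing: {np.isnan(X_missing).mean():.2f}")
```

Training on masked datasets directly, rather than imputing first, is intended to let the model account for the uncertainty that missing entries introduce, which the abstract identifies as the weakness of simple imputation.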

Paper Structure

This paper contains 46 sections, 9 equations, 21 figures, and 7 tables.

Figures (21)

  • Figure 1: Cluster-PFN usage. In (a), we provide $k=0$ to the Cluster-PFN, which signals that it should predict the posterior over the number of clusters, $P(k \mid X)$. In (b), we provide $k=3$, meaning it should group the data into 3 clusters. In this case, it estimates, for each object, the probability of belonging to each cluster; these probabilities are called the cluster responsibilities.
  • Figure 2: Full training pipeline of the Cluster-PFN
  • Figure 3: Cluster-PFN prediction on the Old Faithful dataset (Cluster-PFN trained on Easy prior).
  • Figure 4: Cluster-PFN predictions on prior-data for different conditions (Easy prior).
  • Figure 5: External metrics and NLL on 2D Hard datasets with model selection: (a)–(b) show metric distributions for small and large datasets, (c)–(d) show corresponding NLL histograms.
  • ...and 16 more figures
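
The cluster responsibilities in Figure 1(b) have a standard definition for a GMM: $r_{ij} = \pi_j \mathcal{N}(x_i \mid \mu_j, \Sigma_j) / \sum_{l} \pi_l \mathcal{N}(x_i \mid \mu_l, \Sigma_l)$. The sketch below computes them for a diagonal-covariance GMM whose parameters are assumed known; the helper name `gmm_responsibilities` and the example parameters are hypothetical. A Bayesian estimate such as Cluster-PFN's would additionally account for uncertainty in the parameters rather than condition on a single fixed set.

```python
import numpy as np


def gmm_responsibilities(X, weights, means, scales):
    """r[i, j] = P(z_i = j | x_i) for a diagonal-covariance GMM with known parameters."""
    # Standardised differences between every point and every mean: (n, k, d).
    diff = (X[:, None, :] - means[None, :, :]) / scales[None, :, :]
    # Diagonal Gaussian log density, summed over features: (n, k).
    log_pdf = -0.5 * np.sum(
        diff**2 + np.log(2.0 * np.pi * scales[None, :, :] ** 2), axis=2
    )
    # Add log mixing weights, then normalise over components via log-sum-exp.
    log_joint = log_pdf + np.log(weights)[None, :]
    log_max = np.max(log_joint, axis=1, keepdims=True)
    log_norm = log_max + np.log(np.sum(np.exp(log_joint - log_max), axis=1, keepdims=True))
    return np.exp(log_joint - log_norm)


weights = np.array([0.5, 0.3, 0.2])
means = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])
scales = np.ones((3, 2))
X = np.array([[0.1, -0.2], [3.9, 0.3], [2.0, 2.0]])
R = gmm_responsibilities(X, weights, means, scales)
print(R.round(3))
```

The third point is equidistant from all three means, so its likelihood is identical under every component and its responsibilities reduce exactly to the mixing weights $[0.5, 0.3, 0.2]$.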