Solving Empirical Bayes via Transformers
Anzo Teh, Mark Jabbour, Yury Polyanskiy
TL;DR
This work shows that encoder-only transformers pretrained on synthetic Poisson-EB data can perform in-context learning to estimate high-dimensional Poisson means under an unknown prior. The authors prove universal-approximation-type results showing that transformers can approximate the classical EB estimators (Robbins and NPMLE). They also demonstrate that small models can outperform NPMLE in both runtime and validation loss on synthetic and real data (NHL, MLB, and BookCorpusOpen) while exhibiting length generalization. Linear probes further reveal that the transformer estimators implement Bayes-like shrinkage that differs from that of the classical EB estimators. Practically, this approach offers a fast, scalable alternative for empirical Bayes inference in large-scale problems, with potential extensions to multi-dimensional settings and posterior uncertainty quantification.
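Robbins' estimator, one of the two classical EB baselines named above, rests on the identity $E[\theta \mid X = x] = (x+1) f_\pi(x+1)/f_\pi(x)$, where $f_\pi$ is the marginal pmf of $X$, and simply replaces $f_\pi$ with empirical frequencies. The following is a minimal pure-Python sketch of that textbook rule, not the paper's code. The function name `robbins_estimate` is our own, hypothetical choice.

```python
from collections import Counter
from typing import List, Sequence


def robbins_estimate(xs: Sequence[int]) -> List[float]:
    """Robbins' estimator: theta_hat(x) = (x + 1) * N(x + 1) / N(x),
    where N(k) is the number of observations equal to k.
    Hypothetical helper name; textbook formula, not the paper's code."""
    counts = Counter(xs)
    # Every observed x has counts[x] >= 1, so the division is safe;
    # Counter returns 0 for values never observed (e.g. x + 1).
    return [(x + 1) * counts[x + 1] / counts[x] for x in xs]


if __name__ == "__main__":
    print(robbins_estimate([0, 0, 1, 1, 1, 2]))
```

On the toy input above, the rule returns 1.5 for $x = 0$, 2/3 for $x = 1$, and 0 for $x = 2$. The estimate is not monotone in $x$ and collapses to 0 at the largest observed count. This instability under sparse counts is one reason smoother plug-in rules such as NPMLE are often preferred.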
Abstract
This work applies modern AI tools (transformers) to one of the oldest statistical problems: estimating Poisson means in the empirical Bayes (Poisson-EB) setting. In Poisson-EB, a high-dimensional mean vector $\theta$ (with iid coordinates sampled from an unknown prior $\pi$) is estimated from observations $X \sim \mathrm{Poisson}(\theta)$, i.e., $X_i \sim \mathrm{Poisson}(\theta_i)$ independently for each coordinate. A transformer model is pretrained on synthetically generated pairs $(X, \theta)$ and learns to perform in-context learning (ICL), adapting to the unknown $\pi$. Theoretically, we show that a sufficiently wide transformer can achieve vanishing regret with respect to an oracle estimator that knows $\pi$, as the dimension grows to infinity. Practically, we find that even very small models (100k parameters) outperform the best classical algorithm, non-parametric maximum likelihood (NPMLE), in both runtime and validation loss, which we compute on out-of-distribution synthetic data as well as real-world datasets (NHL hockey, MLB baseball, BookCorpusOpen). Finally, using linear probes, we find that the transformer's EB estimator appears to work internally differently from both the NPMLE and Robbins' estimator.
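The Poisson-EB setup and the regret criterion above can be made concrete with a small simulation. The sketch below makes several assumptions. It uses a discrete prior $\pi$ on a few points, chosen by us for illustration, since the paper's prior-sampling scheme for pretraining is not described here. It draws synthetic pairs $(X, \theta)$ and computes the oracle Bayes estimator $E[\theta \mid X = x]$ that knows $\pi$. It then measures regret as the per-coordinate excess mean squared error over that oracle; this normalization is also our assumption. All function names are hypothetical.

```python
import math
import random
from typing import List, Sequence, Tuple


def sample_poisson(rng: random.Random, lam: float) -> int:
    """Draw one Poisson(lam) variate with Knuth's multiplication method
    (fine for the small means used here; slow for large lam)."""
    threshold = math.exp(-lam)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1


def sample_eb_pairs(rng: random.Random, support: Sequence[float],
                    weights: Sequence[float], n: int) -> Tuple[List[float], List[int]]:
    """Draw theta_i iid from the (assumed) discrete prior pi, then X_i ~ Poisson(theta_i)."""
    thetas = rng.choices(support, weights=weights, k=n)
    xs = [sample_poisson(rng, t) for t in thetas]
    return thetas, xs


def oracle_posterior_mean(x: int, support: Sequence[float],
                          weights: Sequence[float]) -> float:
    """Bayes estimator E[theta | X = x] for a known discrete prior.
    The common factor 1/x! cancels between numerator and denominator."""
    num = sum(w * a ** (x + 1) * math.exp(-a) for a, w in zip(support, weights))
    den = sum(w * a ** x * math.exp(-a) for a, w in zip(support, weights))
    return num / den


def mse(estimates: Sequence[float], thetas: Sequence[float]) -> float:
    """Per-coordinate mean squared error."""
    return sum((e - t) ** 2 for e, t in zip(estimates, thetas)) / len(thetas)


if __name__ == "__main__":
    rng = random.Random(0)
    support, weights = [1.0, 5.0], [0.5, 0.5]  # illustrative prior, not from the paper
    thetas, xs = sample_eb_pairs(rng, support, weights, n=2000)
    oracle = [oracle_posterior_mean(x, support, weights) for x in xs]
    naive = [float(x) for x in xs]  # theta_hat = X ignores pi entirely
    print("oracle MSE:", mse(oracle, thetas))
    print("regret of theta_hat = X:", mse(naive, thetas) - mse(oracle, thetas))
```

An EB estimator, whether Robbins, NPMLE, or a pretrained transformer, sees only `xs`. A small regret against the oracle, which also sees `support` and `weights`, is the sense in which it "adapts to the unknown $\pi$."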
