Solving Empirical Bayes via Transformers

Anzo Teh, Mark Jabbour, Yury Polyanskiy

TL;DR

This work shows that encoder-only transformers pretrained on synthetic Poisson-EB data can perform in-context learning to estimate high-dimensional Poisson means with unknown priors. The authors prove universal-approximation-type results for EB estimators (Robbins and NPMLE) and demonstrate that small models can outperform NPMLE in both runtime and validation loss on synthetic and real data, including NHL, MLB, and BookCorpusOpen, while exhibiting length generalization. They further reveal through linear probes that the transformer estimators implement Bayes-like shrinkage behavior distinct from classical EB estimators. Practically, this approach offers a fast, scalable alternative for empirical Bayes inference in large-scale problems, with potential extensions to multi-dimensional settings and posterior uncertainty quantification.

Abstract

This work applies modern AI tools (transformers) to one of the oldest statistical problems: estimating Poisson means in the empirical Bayes (Poisson-EB) setting. In Poisson-EB, a high-dimensional mean vector $\theta$ (with iid coordinates sampled from an unknown prior $\pi$) is estimated on the basis of $X \sim \mathrm{Poisson}(\theta)$. A transformer model is pre-trained on a set of synthetically generated pairs $(X,\theta)$ and learns to do in-context learning (ICL) by adapting to the unknown $\pi$. Theoretically, we show that a sufficiently wide transformer can achieve vanishing regret with respect to an oracle estimator that knows $\pi$ as the dimension grows to infinity. Practically, we find that even very small models (100k parameters) are able to outperform the best classical algorithm (non-parametric maximum likelihood, or NPMLE) both in runtime and validation loss, which we compute on out-of-distribution synthetic data as well as real-world datasets (NHL hockey, MLB baseball, BookCorpusOpen). Finally, by using linear probes, we confirm that the transformer's EB estimator appears to work internally differently from both the NPMLE and Robbins' estimators.
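
To make the Poisson-EB setting concrete, the following is a minimal sketch, not the authors' code. It assumes a Gamma prior (chosen here only because its Bayes estimator has a closed form), reads "regret" as the excess mean squared error over the oracle that knows $\pi$, and uses hypothetical function names throughout.

```python
import numpy as np

def simulate_poisson_eb(n, shape, scale, seed=0):
    """Draw theta_i iid from a Gamma prior (an assumption for this sketch),
    then X_i ~ Poisson(theta_i)."""
    rng = np.random.default_rng(seed)
    theta = rng.gamma(shape, scale, size=n)
    x = rng.poisson(theta)
    return x, theta

def bayes_gamma(x, shape, scale):
    """Oracle posterior mean under a Gamma(shape, scale) prior:
    E[theta | X = x] = (x + shape) * scale / (1 + scale)."""
    return (x + shape) * scale / (1.0 + scale)

def mse(estimate, theta):
    """Mean squared error of an estimate against the true means."""
    return float(np.mean((estimate - theta) ** 2))

x, theta = simulate_poisson_eb(n=20000, shape=2.0, scale=3.0)
mle_mse = mse(x, theta)                             # MLE: estimate theta_i by X_i
oracle_mse = mse(bayes_gamma(x, 2.0, 3.0), theta)   # oracle that knows the prior
regret_mle = mle_mse - oracle_mse                   # excess MSE over the oracle
print(f"MLE MSE={mle_mse:.3f}, oracle MSE={oracle_mse:.3f}, regret={regret_mle:.3f}")
```

An empirical Bayes estimator sees only `x`, not the prior, and aims to drive this regret toward zero as `n` grows.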

Paper Structure

This paper contains 32 sections, 6 theorems, 32 equations, 6 figures, 12 tables, 1 algorithm.

Key Result

Theorem 4.1

Fix a positive integer $d$ and a positive real number $M$. Then for any $\epsilon > 0$, there exists a transformer that learns the clipped Robbins estimator $\hat{\theta}_{\mathsf{Rob}, d, M}$ up to precision $\epsilon$. In addition, this transformer has embedding dimension $d+1$ at the encoding stage
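
For orientation, Robbins' classical estimator is $\hat{\theta}(x) = (x+1)\,N(x+1)/N(x)$, where $N(x)$ counts how often the value $x$ appears in the sample. The sketch below is a hedged illustration only: the clipping rule (inputs capped at $d$, outputs capped at $M$) is an assumption made for this example and may differ from the paper's definition of $\hat{\theta}_{\mathsf{Rob}, d, M}$; the function name is hypothetical.

```python
import numpy as np

def clipped_robbins(x, d, M):
    """Robbins' estimate (x + 1) * N(x + 1) / N(x) with illustrative clipping.

    Assumed clipping (not confirmed by the source): observations are capped
    at d before counting, and the final estimate is capped at M.
    """
    x = np.minimum(np.asarray(x, dtype=np.int64), d)
    # N(v) for v = 0..d+1; minlength guarantees index x + 1 is always valid.
    counts = np.bincount(x, minlength=d + 2)
    # counts[x] >= 1 for every observed value, so the division is safe.
    estimate = (x + 1) * counts[x + 1] / counts[x]
    return np.minimum(estimate, M)

sample = np.array([0, 0, 1, 2])
print(clipped_robbins(sample, d=5, M=10.0))
```

Because the estimate depends on the data only through the counts $N(\cdot)$, it can jump erratically when neighbouring counts are small, which is one reason a cap such as $M$ is useful.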

Figures (6)

  • Figure 1: (a), (b): $R^2$ scores of linear probes against $N(x)$, $f_{\hat{\pi}}(x)$, and $x$ for T24r. We see that while $x$ itself is easily recoverable from any layer, "knowledge" about the former two quantities appears to either decrease (in (a)) or plateau (in (b)) with depth. (c) In the multinomial prior case, T24r does not seem to use any information on the atom weight $\text{PMF}_{\pi}(\theta)$.
  • Figure 2: Average regret of NPMLE vs transformers on various priors in $\mathcal{P}([0, 50])$ and at various sequence lengths. The regrets of T24r and L24r decrease with sequence length on all priors and beat NPMLE on most instances. The red vertical lines at 512 denote the sequence length that the transformers are trained on. (a) T24r outperforms L24r at longer sequence lengths. ERM-monotone's, MLE's, and Robbins' regrets are 3.26, 11.73, and 85.68 even at sequence length 2048. (b) L24r outperforms T24r throughout, although the gap narrows at $n=2048$. ERM-monotone's, MLE's, GS's, and Robbins' regrets are 2.57, 5.87, 5.63, and 64.43 even at sequence length 2048. (c) NPMLE generalizes better at longer sequence lengths. At $n=4096$ (not shown) NPMLE beats the best-performing T24f (regret 0.104 vs 0.153). T24r has around 40% extra regret compared to T24f throughout; L24r has 10% extra regret compared to T24r at $n=2048$. ERM-monotone's, MLE's, GS's, and Robbins' regrets are 2.36, 14.82, 14.66, and 69.40 even at sequence length 2048.
  • Figure 3: Average time (in seconds) per batch vs sequence length, showing that the inference time of T24r is comparable with that of ERM-monotone and 100x faster than NPMLE. Also shown is L24r, which scales better at $n=2048$.
  • Figure 4: Violin plots of the RMSE ratio of ERM-monotone (blue), NPMLE (orange), T24r (green), and L24r (red) relative to MLE across various datasets. The horizontal line at 1.0 marks MLE's RMSE. Robbins is not shown due to its wide variance. Both T24r and L24r show a general improvement over MLE, and on the MLB dataset their advantage over NPMLE is visible.
  • Figure 5: Discussion on Worst Prior
  • ...and 1 more figure

Theorems & Definitions (14)

  • Theorem 4.1
  • Theorem 4.2
  • Corollary 4.3
  • Definition A.1: Worst-case prior
  • Lemma A.2
  • Corollary A.3
  • proof
  • proof : Proof of thm:robbins-transformers
  • Remark B.1
  • Lemma B.2
  • ...and 4 more