BayesFormer: Transformer with Uncertainty Estimation
Karthik Abinav Sankararaman, Sinong Wang, Han Fang
TL;DR
BayesFormer addresses the need for principled uncertainty quantification in Transformer models by formulating dropout as approximate variational inference, which yields an approximate Bayesian posterior over the weights. It defines a variational family $q_{\boldsymbol{M}}(\boldsymbol{W})$ and an unbiased Monte Carlo estimator based on stochastic forward passes, with dropout applied to the inputs, positional embeddings, and attention components. This produces a tractable predictive distribution $p(\boldsymbol{y}^\star \mid \boldsymbol{X}^\star) \approx \frac{1}{T}\sum_{t=1}^{T} \boldsymbol{f}_{\boldsymbol{y},\tilde{\boldsymbol{W}}_t}(\boldsymbol{X}^\star)$, where each $\tilde{\boldsymbol{W}}_t$ is a weight sample drawn from $q_{\boldsymbol{M}}(\boldsymbol{W})$. The approach improves performance and robustness across language modeling, GLUE classification, long-range sequence tasks, machine translation, and active learning, while providing practical uncertainty estimates that can drive acquisition and robust decision making. It is also compatible with efficient Transformer variants (x-formers), suggesting broad applicability and a path toward integrating principled uncertainty into large-scale pretrained models.
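The Monte Carlo estimator of the predictive distribution can be illustrated with a minimal sketch. It assumes a hypothetical single-layer linear classifier in place of a Transformer, with Bernoulli dropout on its input units. Each stochastic pass samples a mask z, which gives one weight sample W_t = M diag(z) / (1 - p). The sketch then averages the T resulting softmax outputs. The function names dropout_forward and mc_dropout_predict, the inverted-dropout scaling by 1 / (1 - p), and the toy weights are illustrative assumptions, not BayesFormer's implementation.

```python
import math
import random


def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]


def dropout_forward(M, x, p, rng):
    # One stochastic forward pass of a hypothetical linear classifier.
    # The Bernoulli mask z keeps each input unit with probability 1 - p, so the
    # effective weights are W_t = M diag(z) / (1 - p): one draw from a
    # dropout-style variational distribution q_M(W). Requires 0 <= p < 1.
    z = [1.0 if rng.random() >= p else 0.0 for _ in x]
    logits = [
        sum(m_ij * z_j * x_j for m_ij, z_j, x_j in zip(row, z, x)) / (1.0 - p)
        for row in M
    ]
    return softmax(logits)


def mc_dropout_predict(M, x, p=0.1, T=100, seed=0):
    # Monte Carlo estimate p(y* | x*) ~= (1/T) * sum_t f^{W_t}(x*),
    # keeping dropout active at prediction time.
    rng = random.Random(seed)
    samples = [dropout_forward(M, x, p, rng) for _ in range(T)]
    n_classes = len(M)
    mean = [sum(s[c] for s in samples) / T for c in range(n_classes)]
    return mean, samples


# Usage with hypothetical toy weights (2 classes, 3 input features).
M = [[1.0, -0.5, 0.3], [-0.2, 0.8, 0.1]]
x = [0.5, 1.0, -1.0]
mean, samples = mc_dropout_predict(M, x, p=0.2, T=200)
```

With p = 0 every pass is identical and the average reduces to the deterministic softmax. With p > 0, the spread across the T samples is what downstream uncertainty measures consume.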
Abstract
The Transformer has become ubiquitous due to its dominant performance in various NLP and image processing tasks. However, there is little understanding of how to generate mathematically grounded uncertainty estimates for Transformer architectures. Models equipped with such uncertainty estimates can typically improve predictive performance, make networks more robust, avoid over-fitting, and serve as acquisition functions in active learning. In this paper, we introduce BayesFormer, a Transformer model whose dropout layers are designed using Bayesian theory. We propose a new theoretical framework that extends approximate variational inference-based dropout to Transformer-based architectures. Through extensive experiments, we validate the proposed architecture in four paradigms and show improvements across the board: language modeling and classification, long-sequence understanding, machine translation, and acquisition functions for active learning.
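The specific acquisition function is not stated here. As an assumption, the following sketch computes two common uncertainty-based choices from T stochastic forward passes of a dropout model. The first is predictive entropy, the entropy of the averaged class probabilities. The second is BALD, the mutual information, computed as predictive entropy minus the average per-sample entropy. The sketch then selects the k highest-scoring unlabeled examples. The helper names acquisition_scores and select_batch and the toy probability samples are hypothetical.

```python
import math


def entropy(probs):
    # Shannon entropy (in nats) of a probability vector; 0 * log 0 is treated as 0.
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def acquisition_scores(samples):
    # samples: T lists of class probabilities, one per stochastic forward pass.
    T = len(samples)
    n_classes = len(samples[0])
    mean = [sum(s[c] for s in samples) / T for c in range(n_classes)]
    predictive_entropy = entropy(mean)
    expected_entropy = sum(entropy(s) for s in samples) / T
    # BALD: high when individual passes are confident but disagree with each other.
    bald = predictive_entropy - expected_entropy
    return predictive_entropy, bald


def select_batch(pool, k, key="bald"):
    # pool: dict mapping an example id to its list of MC samples.
    if key not in ("entropy", "bald"):
        raise ValueError("key must be 'entropy' or 'bald'")
    idx = 0 if key == "entropy" else 1
    ranked = sorted(
        pool, key=lambda item_id: acquisition_scores(pool[item_id])[idx], reverse=True
    )
    return ranked[:k]


# Hypothetical pool: "a" is uniformly unsure, "b" has confident but conflicting
# passes, "c" is consistently fairly confident.
pool = {
    "a": [[0.5, 0.5], [0.5, 0.5]],
    "b": [[1.0, 0.0], [0.0, 1.0]],
    "c": [[0.9, 0.1], [0.9, 0.1]],
}
chosen = select_batch(pool, k=1, key="bald")
```

Under BALD, example "b" ranks first. Its passes disagree, which signals model (epistemic) uncertainty. Example "a" has equally high predictive entropy, but its passes agree, so its BALD score is zero.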
