Towards Probabilistically-Sound Beam Search with Masked Language Models
Creston Brooks, Robert Calef, Charlie Cowen-Breen, Anna Sappington
TL;DR
The paper tackles how to perform beam search with masked language models by addressing the absence of a readily available joint distribution $p(\mathbf{x})$ and proposing a probabilistically-sound framework based on the Hammersley-Clifford-Besag (HCB) construction. It derives conditions under which the standard MLM infilling approximation is valid and introduces an adjustment term that preserves probabilistic soundness without additional forward passes, yielding the HCB beam search method. Empirically, HCB beam search demonstrates improvements over standard beam search across multiple models and domains, while ablations reveal the importance of context, pivots, and the context signal captured by $p([\mathrm{M}] \mid \mathbf{x})$. These results enable more reliable text infilling for applications like ancient text restoration and protein engineering and offer practical guidance on pivot design and when HCB is advantageous.
Abstract
Beam search with masked language models (MLMs) is challenging in part because joint probability distributions over sequences are not readily available, unlike for autoregressive models. However, estimating such distributions has important domain-specific applications such as ancient text restoration and protein engineering. Here we present probabilistically-sound methods for beam search with MLMs. First, we clarify the conditions under which it is theoretically sound to perform text infilling with MLMs using standard beam search. When these conditions fail, we provide a probabilistically-sound inference-time modification with no additional computational complexity and demonstrate that it is superior to the aforementioned beam search under the expected conditions. We then present empirical results comparing several infilling approaches with MLMs across multiple domains. Notably, our method probes the inductive biases of MLMs and explores the surprising contextual sensitivity of mask tokens for text infilling.
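To make the Hammersley-Clifford-Besag idea concrete, the sketch below checks the underlying identity on a toy problem: for a strictly positive joint, the ratio $p(\mathbf{x})/p(\mathbf{x}')$ against a pivot sequence $\mathbf{x}'$ can be recovered from single-position conditionals alone. This is only an illustration under stated assumptions, not the paper's beam search. The vocabulary size, sequence length, random joint, and the helper names `make_joint`, `conditional`, and `hcb_log_ratio` are all hypothetical, and the exact conditionals stand in for an MLM's masked predictions.

```python
import itertools
import math
import random

VOCAB = 3   # toy vocabulary size (assumption)
LENGTH = 4  # toy sequence length (assumption)


def make_joint(seed=0):
    """Random strictly positive joint p(x) over all VOCAB**LENGTH sequences."""
    rng = random.Random(seed)
    seqs = list(itertools.product(range(VOCAB), repeat=LENGTH))
    weights = [rng.uniform(0.1, 1.0) for _ in seqs]
    total = sum(weights)
    return {s: w / total for s, w in zip(seqs, weights)}


def conditional(joint, context, i, token):
    """Exact p(x_i = token | all other positions of context).

    Stands in for an MLM's prediction at a masked position i.
    """
    def with_token(t):
        return context[:i] + (t,) + context[i + 1:]

    denom = sum(joint[with_token(t)] for t in range(VOCAB))
    return joint[with_token(token)] / denom


def hcb_log_ratio(joint, x, pivot):
    """log p(x) - log p(pivot), using single-position conditionals only."""
    log_ratio = 0.0
    for i in range(len(x)):
        # Positions before i are taken from x, positions after i from the pivot.
        context = x[:i] + (pivot[i],) + pivot[i + 1:]
        log_ratio += math.log(conditional(joint, context, i, x[i]))
        log_ratio -= math.log(conditional(joint, context, i, pivot[i]))
    return log_ratio


joint = make_joint()
x = (2, 0, 1, 1)
pivot = (0, 0, 0, 0)
exact = math.log(joint[x]) - math.log(joint[pivot])
estimate = hcb_log_ratio(joint, x, pivot)
print(f"exact log-ratio: {exact:.6f}, from conditionals: {estimate:.6f}")
```

In this toy setting the two numbers agree up to floating-point error, because the conditionals are derived from a single joint. With a real MLM the learned conditionals need not be consistent with any one joint, so the result would generally depend on the choice of pivot, which is one plausible reason pivot design matters in practice.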
