Belief Net: A Filter-Based Framework for Learning Hidden Markov Models from Observations
Reginald Zhiyan Chen, Heng-Sheng Chang, Prashant G. Mehta
TL;DR
Belief Net introduces a differentiable forward filter for learning HMM parameters by gradient descent. It recasts the filter's forward pass as a structured neural network whose weights are the logits of the initial distribution $\mu$, the transition matrix $A$, and the emission matrix $C$. Trained with a standard autoregressive (next-observation prediction) loss, it yields interpretable, explicitly parameterized HMMs with competitive predictive performance. Empirically, it converges faster than Baum-Welch and recovers parameters robustly on synthetic data, including overcomplete regimes where spectral methods fail. It also extracts meaningful latent structure from real-world text. The approach offers a principled alternative to EM and to black-box Transformers, with potential extensions to broader state-space models and decision-making settings.
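For concreteness, the forward filter described above can be written out as follows. This is a sketch in assumed notation, not necessarily the paper's. Here $\theta$ collects the logits, $A_{ij} = P(X_{t+1}=j \mid X_t=i)$, $C_{iy} = P(Y_t=y \mid X_t=i)$, and $\pi_t$ is the predicted belief $P(X_t \mid y_0,\dots,y_{t-1})$. The softmax is applied row-wise, so $\mu$ and every row of $A$ and $C$ are valid probability vectors.

$$
\mu = \operatorname{softmax}(\theta_\mu), \qquad A_{i\cdot} = \operatorname{softmax}(\theta_{A,i\cdot}), \qquad C_{i\cdot} = \operatorname{softmax}(\theta_{C,i\cdot})
$$

$$
\pi_0 = \mu, \qquad \hat p_t(y) = \sum_i \pi_t(i)\, C_{iy}, \qquad \pi_t^{+}(i) = \frac{\pi_t(i)\, C_{i y_t}}{\hat p_t(y_t)}, \qquad \pi_{t+1}(j) = \sum_i \pi_t^{+}(i)\, A_{ij}
$$

$$
\mathcal{L}(\theta) = -\frac{1}{T}\sum_{t=0}^{T-1} \log \hat p_t(y_t) = -\frac{1}{T}\log P_\theta(y_0,\dots,y_{T-1})
$$

In this reading, each time step acts as one layer: predict the next observation, condition on $y_t$, then propagate through $A$. Because $\hat p_t(y_t) = P(y_t \mid y_0,\dots,y_{t-1})$, the chain rule makes the autoregressive loss equal to the per-step negative log-likelihood of the sequence. Under these assumptions, minimizing it over $\theta$ is maximum-likelihood estimation.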
Abstract
Hidden Markov Models (HMMs) are fundamental for modeling sequential data, yet learning their parameters from observations remains challenging. Classical methods such as the Baum-Welch (EM) algorithm are computationally intensive and prone to local optima. Modern spectral algorithms offer provable guarantees but may produce probability estimates outside the valid range. This work introduces Belief Net, a novel framework that learns HMM parameters through gradient-based optimization by formulating the HMM's forward filter as a structured neural network. Unlike black-box Transformer models, Belief Net's learnable weights are explicitly the logits of the initial distribution, transition matrix, and emission matrix, ensuring full interpretability. The model processes observation sequences with a decoder-only architecture and is trained end-to-end with a standard autoregressive next-observation prediction loss. On synthetic HMM data, Belief Net converges faster than Baum-Welch. It also recovers parameters in both undercomplete and overcomplete settings, including the overcomplete regime where spectral methods fail. We also compare Belief Net with Transformer-based models on real-world language data.
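The training scheme in the abstract can be sketched in dependency-free Python. This is an illustrative toy, not the authors' implementation. The function names `unpack`, `avg_nll`, `numerical_grad`, and `sample_hmm` are hypothetical, as are the 2-state, 3-symbol HMM, the learning rate, and the step count. Central finite differences stand in for the automatic differentiation a real implementation would use. As in the abstract, the only learnable weights are the logits of $\mu$, $A$, and $C$, and the loss is the mean next-observation negative log-likelihood produced by the forward filter.

```python
import math
import random


def softmax(logits):
    """Map a list of real logits to a probability vector (numerically stable)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def unpack(theta, n, d):
    """Split a flat logit vector into mu (n), A (n x n), C (n x d); each row is softmaxed."""
    mu = softmax(theta[:n])
    A = [softmax(theta[n + i * n: n + (i + 1) * n]) for i in range(n)]
    off = n + n * n
    C = [softmax(theta[off + i * d: off + (i + 1) * d]) for i in range(n)]
    return mu, A, C


def avg_nll(theta, ys, n, d):
    """Run the forward filter and return the mean of -log p(y_t | y_0..y_{t-1})."""
    mu, A, C = unpack(theta, n, d)
    prior = mu  # predicted belief over the hidden state before seeing y_t
    total = 0.0
    for y in ys:
        p_y = sum(prior[i] * C[i][y] for i in range(n))  # predicted prob of y_t
        total -= math.log(p_y)
        post = [prior[i] * C[i][y] / p_y for i in range(n)]  # Bayes update on y_t
        prior = [sum(post[i] * A[i][j] for i in range(n)) for j in range(n)]  # propagate
    return total / len(ys)


def numerical_grad(theta, ys, n, d, eps=1e-5):
    """Central finite differences; a stand-in for autodiff in this sketch."""
    grad = []
    for k in range(len(theta)):
        plus, minus = theta[:], theta[:]
        plus[k] += eps
        minus[k] -= eps
        grad.append((avg_nll(plus, ys, n, d) - avg_nll(minus, ys, n, d)) / (2 * eps))
    return grad


def sample_hmm(mu, A, C, T, rng):
    """Draw T observations: emit from the current state, then transition."""
    x = rng.choices(range(len(mu)), weights=mu)[0]
    ys = []
    for _ in range(T):
        ys.append(rng.choices(range(len(C[x])), weights=C[x])[0])
        x = rng.choices(range(len(A[x])), weights=A[x])[0]
    return ys


# Hypothetical toy setup (not from the paper): a 2-state, 3-symbol HMM.
rng = random.Random(0)
true_mu = [0.6, 0.4]
true_A = [[0.9, 0.1], [0.2, 0.8]]
true_C = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
ys = sample_hmm(true_mu, true_A, true_C, T=120, rng=rng)

n, d = 2, 3
theta = [rng.gauss(0.0, 0.1) for _ in range(n + n * n + n * d)]  # small random logits
initial_loss = avg_nll(theta, ys, n, d)
history = []
for step in range(40):  # plain gradient descent on the logits
    g = numerical_grad(theta, ys, n, d)
    theta = [w - 0.5 * gk for w, gk in zip(theta, g)]
    history.append(avg_nll(theta, ys, n, d))
mu_hat, A_hat, C_hat = unpack(theta, n, d)
```

Because every row passes through a softmax, the fitted $\hat\mu$, $\hat A$, and $\hat C$ are always valid probability distributions. This is the property the abstract contrasts with spectral estimates. The short run above only shows the loss decreasing and says nothing about recovery quality or speed. Finite differences also cost two loss evaluations per parameter, so automatic differentiation is the practical choice beyond toy sizes.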
