Belief Net: A Filter-Based Framework for Learning Hidden Markov Models from Observations

Reginald Zhiyan Chen, Heng-Sheng Chang, Prashant G. Mehta

TL;DR

Belief Net introduces a gradient-based, differentiable forward-filter for learning HMM parameters, recasting the forward pass as a structured neural network with logits for $\mu$, $A$, and $C$. Trained via standard autoregressive loss, it yields interpretable, parameterizable HMMs and competitive predictive performance. Empirical results show faster convergence and robust parameter recovery on synthetic data, including overcomplete regimes where spectral methods fail, and meaningful latent structure extraction from real-world text. The approach provides a principled alternative to EM and black-box Transformers, with potential extensions to broader state-space and decision-making settings.

Abstract

Hidden Markov Models (HMMs) are fundamental for modeling sequential data, yet learning their parameters from observations remains challenging. Classical methods like the Baum-Welch (EM) algorithm are computationally intensive and prone to local optima, while modern spectral algorithms offer provable guarantees but may produce probability outputs outside valid ranges. This work introduces Belief Net, a novel framework that learns HMM parameters through gradient-based optimization by formulating the HMM's forward filter as a structured neural network. Unlike black-box Transformer models, Belief Net's learnable weights are explicitly the logits of the initial distribution, transition matrix, and emission matrix, ensuring full interpretability. The model processes observation sequences using a decoder-only architecture and is trained end-to-end with standard autoregressive next-observation prediction loss. On synthetic HMM data, Belief Net achieves superior convergence speed compared to Baum-Welch, successfully recovering parameters in both undercomplete and overcomplete settings where spectral methods fail. Comparisons with Transformer-based models are also presented on real-world language data.
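The idea in the abstract, a forward filter whose only learnable weights are the logits of the initial distribution, transition matrix, and emission matrix, can be illustrated with a minimal sketch. It is not the authors' implementation. The function names, the row-stochastic convention for $A$ and $C$, and the choice to score every observation (including the first) are assumptions made here for illustration. The sketch only evaluates the autoregressive loss; in practice an automatic-differentiation framework would supply the gradients with respect to the logits.

```python
import math

def softmax(logits):
    """Map a list of logits to a probability vector (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def belief_filter_loss(mu_logits, A_logits, C_logits, observations):
    """Average next-observation negative log-likelihood (in nats).

    Assumed conventions (not taken from the paper):
      A[i][j] = P(next state j | current state i)
      C[i][z] = P(observation z | state i)
    """
    mu = softmax(mu_logits)                 # initial distribution
    A = [softmax(row) for row in A_logits]  # transition matrix, rows sum to 1
    C = [softmax(row) for row in C_logits]  # emission matrix, rows sum to 1
    d = len(mu)

    prior = mu        # belief about the state before seeing the observation
    total_nll = 0.0
    for z in observations:
        # Estimation: predicted probability of the observation under the prior.
        p = sum(prior[i] * C[i][z] for i in range(d))
        total_nll -= math.log(p)
        # Correction: Bayes update of the belief using the observation.
        posterior = [prior[i] * C[i][z] / p for i in range(d)]
        # Transition: propagate the belief one step to get the next prior.
        prior = [sum(posterior[i] * A[i][j] for i in range(d)) for j in range(d)]
    return total_nll / len(observations)

# Hypothetical toy model: 2 hidden states, 3 observation symbols, all logits zero.
uniform_loss = belief_filter_loss(
    [0.0, 0.0],
    [[0.0, 0.0], [0.0, 0.0]],
    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    [0, 1, 2, 1],
)
```

With all logits set to zero every distribution is uniform, so each of the three symbols is predicted with probability 1/3 and `uniform_loss` is approximately $\log 3 \approx 1.0986$. Because the parameters pass through a softmax, the learned $\mu$, $A$, and $C$ are valid probability distributions by construction, which is the interpretability property the abstract emphasizes.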

Paper Structure

This paper contains 40 sections, 9 equations, 8 figures, 3 tables, 2 algorithms.

Figures (8)

  • Figure 1: Belief Net architecture. The model is initialized at $t=0$ with $\mu$ as the prior and maintains a belief state $\mu_t$ over time steps $t\in\left\{0,1,\dots,T-1\right\}$ by recursively applying a transition step (using the transition matrix $A$) and a correction step (using the previous prior $\mu_{t|t-1}$ and the term $e_t$, computed in the emission step from the observation $Z_t$ and the emission matrix $C$). The estimation step predicts the probabilities $p_{t+1}$ of the next observation from the next prior $\mu_{t+1|t}$ and the emission matrix $C$. The detailed computation of each step is described in Algorithm \ref{alg:belief_net}.
  • Figure 2: Parameter recovery on synthetic data. The validation loss $\mathsf{J}$ is plotted against the candidate state dimension $\hat{\mathsf{d}}$ for each method. The true state dimension is $\mathsf{d} = 64$, and the gray area indicates the $\hat{\mathsf{d}} \geq \mathsf{d}$ regime. Curves are colored by model: Baum-Welch (blue), Spectral (green), two nanoGPT models (shades of orange), and Belief Net (red). The dashed lines for random guess (gray) and the HMM filter (black) represent the worst-case and best-case scenarios, respectively.
  • Figure 3: Language modeling results on the Federalist Papers. Belief Net's training (light red) and validation (red) loss $\mathsf{J}_l$ over iterations $l$ are shown as solid curves. The validation loss is evaluated every fifty iterations. For comparison, horizontal dashed lines show the final validation losses achieved by other methods, including random guess (gray), Baum-Welch (blue), Spectral (green), and two nanoGPT models (shades of orange). The corresponding $\mathsf{Perplexity}$ is shown on the right.
  • Figure 4: Undercomplete results on synthetic data. The Belief Net's training (light red) and validation (red) loss $\mathsf{J}_l$ over iterations $l$ are shown in solid curves. The validation loss is evaluated every fifty iterations. For comparison, horizontal dashed lines show the final validation losses achieved by other methods, including random guess (gray), Baum-Welch (blue), Spectral (green), and nanoGPT (orange).
  • Figure 5: Overcomplete results on synthetic data. The Belief Net's training (light red) and validation (red) loss $\mathsf{J}_l$ over iterations $l$ are shown in solid curves. The validation loss is evaluated every fifty iterations. For comparison, horizontal dashed lines show the final validation losses achieved by other methods, including random guess (gray), Baum-Welch (blue), Spectral (green), and nanoGPT (orange).
  • ...and 3 more figures
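
The captions above report both a loss and a perplexity, and use a random-guess line as the worst-case reference. A minimal sketch of how these quantities usually relate is given below, assuming the loss is an average cross-entropy measured in nats and that "random guess" means a uniform prediction over the observation alphabet. Neither assumption is confirmed by the captions, and the function names are hypothetical.

```python
import math

def perplexity(loss_nats):
    """Perplexity for an average cross-entropy loss given in nats."""
    return math.exp(loss_nats)

def random_guess_loss(num_symbols):
    """Loss of a predictor that assigns probability 1/num_symbols to every symbol."""
    return math.log(num_symbols)

# Hypothetical alphabet size, chosen only for illustration.
baseline_loss = random_guess_loss(64)
baseline_perplexity = perplexity(baseline_loss)
```

Under these assumptions, a uniform predictor over 64 symbols has a perplexity of 64 (up to floating-point rounding), so any model whose validation perplexity falls below the alphabet size is doing better than random guessing.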

Theorems & Definitions (3)

  • remark 1
  • remark 2
  • remark 3