Targeted stochastic gradient Markov chain Monte Carlo for hidden Markov models with rare latent states

Rihui Ou, Deborshee Sen, Alexander L Young, David B Dunson

TL;DR

This work tackles scalable Bayesian inference for discrete-state HMMs in the presence of rare latent states, where traditional forward-backward MCMC is expensive on long sequences. The authors introduce Targeted Sub-Sampling for HMMs (TASS-HMM), a stochastic gradient MCMC framework that over-samples subsequences containing rare-state data via component-specific importance weights derived from clustering-based approximations. By choosing these weights to minimize gradient variance and sampling separately for each parameter component, TASS-HMM estimates rare-state parameters more accurately and improves predictive performance, as demonstrated on synthetic experiments and on real solar flare and sleep cycle data. The approach offers a practical route to reliable uncertainty quantification and faster convergence when information about some latent states is scarce, and could potentially extend to other stochastic-gradient inference methods.
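To make the stochastic gradient MCMC mechanics concrete, here is a minimal Python sketch of one stochastic gradient Langevin dynamics (SGLD) step driven by an importance-sampled subsequence. It assumes, as a simplification, that the log-likelihood is a sum of independent per-subsequence terms (the HMM dependence across subsequence boundaries is ignored). The function names `weighted_grad_estimate` and `sgld_step` and the toy Gaussian-mean model are hypothetical illustrations, not the authors' implementation.

```python
import numpy as np

def weighted_grad_estimate(theta, n, grad_log_prior, grad_loglik_subseq, probs):
    """Importance-weighted estimate of the full log-posterior gradient.

    Assumes the log-likelihood is a sum over subsequences (a simplification).
    Dividing subsequence n's gradient by its sampling probability makes the
    estimate unbiased for the full-data gradient.
    """
    return grad_log_prior(theta) + grad_loglik_subseq(theta, n) / probs[n]

def sgld_step(theta, grad_log_prior, grad_loglik_subseq, probs, step_size, rng):
    """One SGLD update using a single subsequence drawn with probabilities `probs`."""
    n = rng.choice(len(probs), p=probs)  # draw one subsequence index
    grad = weighted_grad_estimate(theta, n, grad_log_prior, grad_loglik_subseq, probs)
    noise = rng.normal(scale=np.sqrt(step_size))  # injected Langevin noise
    return theta + 0.5 * step_size * grad + noise

# Hypothetical toy setting: Gaussian mean with unit variance, data split into blocks.
blocks = [np.array([0.1, -0.3]), np.array([2.5, 3.0]), np.array([0.0, 0.2])]
grad_prior = lambda th: -th / 10.0             # N(0, 10) prior on the mean
grad_block = lambda th, n: np.sum(blocks[n] - th)
probs = np.array([0.25, 0.5, 0.25])            # over-sample the block with large values

rng = np.random.default_rng(0)
theta = 0.0
for _ in range(1000):
    theta = sgld_step(theta, grad_prior, grad_block, probs, 1e-3, rng)
```

Dividing by `probs[n]` is what lets a non-uniform, rare-state-targeted sampling distribution keep the gradient estimate unbiased; the choice of `probs` then affects only its variance.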

Abstract

Markov chain Monte Carlo (MCMC) algorithms for hidden Markov models often rely on the forward-backward sampler. This makes them computationally slow as the length of the time series increases, motivating the development of sub-sampling-based approaches. These approximate the full posterior by using small random subsequences of the data at each MCMC iteration within stochastic gradient MCMC. In the presence of imbalanced data resulting from rare latent states, subsequences often exclude rare latent state data, leading to inaccurate inference and prediction/detection of rare events. We propose a targeted sub-sampling (TASS) approach that over-samples observations corresponding to rare latent states when calculating the stochastic gradient of parameters associated with them. TASS uses an initial clustering of the data to construct subsequence weights that reduce the variance in gradient estimation. This leads to improved sampling efficiency, in particular in settings where the rare latent states correspond to extreme observations. We demonstrate substantial gains in predictive and inferential accuracy on real and synthetic examples.
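The abstract's idea of using an initial clustering to target subsequences can be illustrated with a deliberately simple Python sketch. Everything here is an assumption for illustration: 1-D k-means stands in for the clustering step, cluster labels stand in for latent states, and the hypothetical helper `targeted_subsequence_probs` weights each non-overlapping subsequence by its count of points in the rare cluster, mixed with a uniform distribution (`floor`). The paper constructs its weights to reduce gradient variance, so its actual weights will differ.

```python
import numpy as np

def kmeans_1d(y, k, n_iter=50):
    """Plain 1-D k-means (Lloyd's algorithm) with quantile initialization."""
    centers = np.quantile(y, np.linspace(0.0, 1.0, k))
    for _ in range(n_iter):
        # Assign each point to its nearest center.
        labels = np.argmin(np.abs(y[:, None] - centers[None, :]), axis=1)
        # Move each non-empty cluster's center to the mean of its points.
        for j in range(k):
            if np.any(labels == j):
                centers[j] = y[labels == j].mean()
    return labels, centers

def targeted_subsequence_probs(labels, cluster, subseq_len, floor=0.1):
    """Hypothetical sampling probabilities over non-overlapping subsequences.

    Each subsequence gets mass proportional to how many of its points were
    assigned to `cluster`, mixed with a uniform distribution (weight `floor`)
    so that every subsequence keeps a positive probability.
    """
    n_sub = len(labels) // subseq_len
    blocks = labels[: n_sub * subseq_len].reshape(n_sub, subseq_len)
    counts = (blocks == cluster).sum(axis=1).astype(float)
    uniform = np.full(n_sub, 1.0 / n_sub)
    if counts.sum() == 0:
        return uniform
    return (1.0 - floor) * counts / counts.sum() + floor * uniform

# Synthetic example: a short burst of extreme observations from a rare state.
rng = np.random.default_rng(1)
y = rng.normal(0.0, 1.0, size=1000)
y[500:510] += 8.0
labels, centers = kmeans_1d(y, k=3)
rare = int(np.argmax(centers))  # treat the highest-mean cluster as the rare state
probs = targeted_subsequence_probs(labels, rare, subseq_len=20)
```

The uniform mixture keeps every subsequence reachable, which matters because the cluster labels only approximate the latent states; without it, subsequences the clustering misses could never be drawn and the importance-weighted gradient would no longer be unbiased.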

Paper Structure

This paper contains 23 sections, 3 theorems, 41 equations, 11 figures, 5 tables, and 1 algorithm.

Key Result

Proposition 1

The optimal importance weights when the latent states are known are given by for $n=1,\dots,N$, where $\theta = (\theta_1,\dots,\theta_d)$.
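For intuition about why gradient-proportional weights appear in results of this kind, the Python sketch below checks the standard importance sampling fact under an assumed additive structure: if the full gradient is $\sum_n g_n$ and one term is drawn with probability $p_n$ and rescaled by $1/p_n$, the estimator is unbiased and its second moment $\sum_n \|g_n\|^2/p_n$ is minimized by $p_n \propto \|g_n\|$, with minimum $(\sum_n \|g_n\|)^2$. This is a generic result, not a reconstruction of the exact expression in Proposition 1, and the helper names are hypothetical.

```python
import numpy as np

def second_moment(grads, probs):
    """E||g_n / p_n||^2 when index n is drawn with probability probs[n].

    grads has shape (N, d): one gradient vector per subsequence.
    All probs must be positive.
    """
    sq_norms = np.sum(grads**2, axis=1)
    return float(np.sum(sq_norms / probs))

def optimal_probs(grads):
    """Sampling probabilities proportional to per-subsequence gradient norms."""
    norms = np.linalg.norm(grads, axis=1)
    return norms / norms.sum()

rng = np.random.default_rng(0)
grads = rng.normal(size=(100, 3))
grads[:5] *= 20.0  # a few subsequences dominate, e.g. those with rare-state data
uniform = np.full(100, 1.0 / 100)
p_opt = optimal_probs(grads)
print(second_moment(grads, uniform), second_moment(grads, p_opt))
```

Because every choice of $p_n$ gives the same mean, a smaller second moment means a smaller variance; the inflated rows play the role of the few subsequences that carry rare-state information.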

Figures (11)

  • Figure 1: Distance between the true and sampled parameter values at different iterations. $|\sigma_3^2-\sigma^2_{3, \text{true}}|$ versus iteration number (left); $|\mu_3-\mu_{3, \text{true}}|$ versus iteration number (right).
  • Figure 2: Log predictive density of held-out data for one rare latent state example.
  • Figure 3: Distance between the true and sampled parameter values at different iterations when two rare latent states are present. $|\mu_2-\mu_{2, \text{true}}|$ versus iteration number (left); $|\mu_3-\mu_{3, \text{true}}|$ versus iteration number (right).
  • Figure 4: Log predictive density of held-out data for state two (left) and state three (right) when two rare latent states are present.
  • Figure 5: Log predictive density of held-out data in the non-rare setting.
  • ...and 6 more figures

Theorems & Definitions (6)

  • Proposition 1: Optimal weights for known latent states
  • Proof of Proposition 1
  • Corollary 1: Component-specific optimal importance weights for known latent states
  • Proof
  • Proposition 2
  • Proof
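Corollary 1's title points to choosing importance weights separately for each parameter component. One hedged reading, under the same assumption that the full gradient is a sum of per-subsequence terms $g_n \in \mathbb{R}^d$: component $\theta_i$ draws its own subsequence with probability $p_{n,i} \propto |g_{n,i}|$. The Python sketch below (hypothetical helper names) checks numerically that this gives a total second moment of $\sum_i (\sum_n |g_{n,i}|)^2$, which by the triangle inequality never exceeds the best shared-weight value $(\sum_n \|g_n\|)^2$.

```python
import numpy as np

def componentwise_probs(grads):
    """Column i holds sampling probabilities for component i, proportional to |g_{n,i}|."""
    a = np.abs(grads)
    return a / a.sum(axis=0, keepdims=True)

def componentwise_second_moment(grads, probs):
    """Sum over components i of E[(g_{n,i} / p_{n,i})^2], with n drawn from probs[:, i]."""
    return float(np.sum(grads**2 / probs))

def shared_optimal_second_moment(grads):
    """Best achievable value with one shared distribution: (sum_n ||g_n||)^2."""
    return float(np.linalg.norm(grads, axis=1).sum() ** 2)

rng = np.random.default_rng(0)
grads = rng.normal(size=(200, 4))
grads[:3, 0] *= 50.0  # component 0 is driven by a handful of rare-state subsequences
P = componentwise_probs(grads)
print(componentwise_second_moment(grads, P), shared_optimal_second_moment(grads))
```

The gain is largest when different components depend on different subsequences, as when a rare state's parameters are informed by only a handful of them.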