Hidden Markov Neural Networks
Lorenzo Rimella, Nick Whiteley
TL;DR
Hidden Markov Neural Networks (HMNNs) are a time-evolving Bayesian framework that treats neural network weights as the hidden states of a factorial hidden Markov model, enabling continual adaptation with principled forgetting. Inference is performed sequentially via variational filtering: each time step applies a forward prediction step followed by a correction step, with gradients estimated by a sequential reparameterization trick. A Gaussian mixture variational family, combined with variational DropConnect, provides robust regularization and scalable inference, with closed-form updates in the Gaussian case. Empirically, HMNNs demonstrate strong predictive performance and meaningful uncertainty quantification on time-varying tasks, including MNIST-like drift scenarios and one-step-ahead video frame prediction, while outperforming several continual-learning baselines in dynamic settings.
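For orientation, the prediction and correction steps mentioned above follow the standard filtering recursion for hidden Markov models. The notation here is an assumption made for illustration and is not taken from the paper: w_t denotes the weights at time t, D_t the data observed at time t, and p(w_t | w_{t-1}) the transition kernel.

```latex
\begin{align*}
\text{prediction:}\quad & p(w_t \mid \mathcal{D}_{1:t-1})
  = \int p(w_t \mid w_{t-1})\, p(w_{t-1} \mid \mathcal{D}_{1:t-1})\, \mathrm{d}w_{t-1}, \\
\text{correction:}\quad & p(w_t \mid \mathcal{D}_{1:t})
  \propto p(\mathcal{D}_t \mid w_t)\, p(w_t \mid \mathcal{D}_{1:t-1}).
\end{align*}
```

In a variational filter, the intractable filtering distribution from the previous step is replaced by its variational approximation inside the prediction integral, and the new approximation is fitted to the resulting correction target.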
Abstract
We define an evolving-in-time Bayesian neural network called a Hidden Markov Neural Network, which addresses a crucial challenge in time-series forecasting and continual learning: striking a balance between adapting to new data and appropriately forgetting outdated information. This is achieved by modelling the weights of a neural network as the hidden states of a hidden Markov model, with the observed process defined by the available data. A filtering algorithm is employed to learn a variational approximation of the evolving-in-time posterior distribution over the weights. By leveraging a sequential variant of Bayes by Backprop, enriched with a stronger regularization technique called variational DropConnect, Hidden Markov Neural Networks achieve robust regularization and scalable inference. Experiments on MNIST, dynamic classification tasks, and next-frame forecasting in videos demonstrate that Hidden Markov Neural Networks provide strong predictive performance while enabling effective uncertainty quantification.
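The filtering idea can be illustrated with a deliberately small sketch. It is not the authors' implementation, and it rests on several simplifying assumptions: a single scalar weight, a linear-Gaussian likelihood, a Gaussian random-walk transition kernel, and a single Gaussian variational family rather than a mixture with variational DropConnect. The function names (`predict`, `correct`, `run_filter`) and all numeric settings are hypothetical.

```python
import numpy as np


def predict(mean, std, drift_std):
    """Prediction step: push q_{t-1} = N(mean, std^2) through an ASSUMED
    Gaussian random-walk kernel w_t = w_{t-1} + drift_std * noise.
    The result is again Gaussian, so this step is closed form."""
    return mean, float(np.sqrt(std**2 + drift_std**2))


def correct(prior_mean, prior_std, x, y, noise_std, rng,
            steps=500, lr=0.002, n_samples=16):
    """Correction step: fit q_t = N(m, s^2) by stochastic gradient descent on
    KL(q_t || prior) - E_{q_t}[log p(y | x, w)], using reparameterized
    samples w = m + s * eps (in the spirit of Bayes by Backprop)."""
    m, rho = float(prior_mean), float(np.log(prior_std))  # s = exp(rho) > 0
    for _ in range(steps):
        s = np.exp(rho)
        eps = rng.standard_normal(n_samples)
        w = m + s * eps  # reparameterization trick
        # Gradient of the negative log-likelihood w.r.t. w, one per sample,
        # for the ASSUMED model y = w * x + N(0, noise_std^2).
        resid = y[None, :] - w[:, None] * x[None, :]
        dnll_dw = -(resid * x[None, :]).sum(axis=1) / noise_std**2
        # Chain rule through w = m + exp(rho) * eps, plus the exact gradient
        # of the Gaussian-to-Gaussian KL term.
        grad_m = dnll_dw.mean() + (m - prior_mean) / prior_std**2
        grad_rho = (dnll_dw * eps).mean() * s + (s**2 / prior_std**2 - 1.0)
        m -= lr * grad_m
        rho -= lr * grad_rho
    return float(m), float(np.exp(rho))


def run_filter(batches, noise_std, drift_std, rng, init_mean=0.0, init_std=1.0):
    """Alternate prediction and correction over a sequence of (x, y) batches."""
    mean, std = init_mean, init_std
    history = []
    for x, y in batches:
        prior_mean, prior_std = predict(mean, std, drift_std)
        mean, std = correct(prior_mean, prior_std, x, y, noise_std, rng)
        history.append((mean, std))
    return history


# Toy usage: a weight that drifts over five time steps.
rng = np.random.default_rng(0)
true_w, noise_std, drift_std = 1.0, 0.5, 0.3
batches, true_path = [], []
for _ in range(5):
    true_w += drift_std * rng.standard_normal()
    x = rng.standard_normal(50)
    y = true_w * x + noise_std * rng.standard_normal(50)
    batches.append((x, y))
    true_path.append(true_w)

history = run_filter(batches, noise_std, drift_std, rng)
for t, ((m, s), w) in enumerate(zip(history, true_path)):
    print(f"t={t}: true w = {w:+.3f}, q_t mean = {m:+.3f}, q_t std = {s:.3f}")
```

The prediction step inflates the variance before each new batch arrives, which is what lets the approximation discount older data. A larger `drift_std` forgets faster, while `drift_std = 0` reduces the recursion to ordinary sequential Bayesian updating. The fixed learning rate is tuned to this toy problem's scale and would need adjusting for other data.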
