Latent State Models of Training Dynamics
Michael Y. Hu, Angelica Chen, Naomi Saphra, Kyunghyun Cho
TL;DR
The paper addresses the unclear role of randomness in neural network training by modeling training trajectories with a Gaussian hidden Markov model (HMM) fit to a rich set of metrics collected during training. From the fitted HMM it derives a low-dimensional training map of latent states and the transitions between them, which supports semantic interpretation of training dynamics and of phase transitions such as grokking. A regression-based procedure assigns meaning to latent states, including "detour" states, and links them to convergence time. Experiments across modular arithmetic, sparse parities, language modeling, and image classification demonstrate the method’s ability to predict convergence and reveal how hyperparameters and architectures modulate sensitivity to the random seed. The work provides a principled, automated framework for understanding variability in training and suggests practical ways to stabilize learning by avoiding detour states through targeted architectural and optimization choices.
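To make the training map and the detour-state regression more concrete, here is a minimal sketch that assumes each seed's run has already been decoded into a sequence of latent-state indices. The hypothetical `training_map` estimates empirical probabilities of moving between distinct states (a graph in the spirit of the training map, although the paper builds its map from the fitted HMM itself), and the hypothetical `visit_slope` regresses convergence time on whether a run visited a given state. Both functions and this single-indicator regression are illustrative assumptions, not the paper's exact procedure.

```python
from collections import Counter, defaultdict
from typing import Dict, Optional, Sequence, Tuple


def training_map(paths: Sequence[Sequence[int]]) -> Dict[Tuple[int, int], float]:
    """Hypothetical helper: empirical probability of moving from state a to a
    *different* state b, pooled over all runs.

    Self-transitions are dropped so the result reads as a graph of state
    changes; outgoing probabilities from each state sum to 1.
    """
    counts: Counter = Counter()
    for path in paths:
        for a, b in zip(path, path[1:]):
            if a != b:
                counts[(a, b)] += 1
    out_totals: Dict[int, int] = defaultdict(int)
    for (a, _), c in counts.items():
        out_totals[a] += c
    return {(a, b): c / out_totals[a] for (a, b), c in counts.items()}


def visit_slope(paths: Sequence[Sequence[int]],
                conv_times: Sequence[float],
                state: int) -> Optional[float]:
    """Hypothetical helper: least-squares slope of convergence time on a 0/1
    indicator of "this run visited `state`".

    Returns None when every run (or no run) visits the state, because the
    slope is undefined in that case.
    """
    x = [1.0 if state in path else 0.0 for path in paths]
    y = [float(t) for t in conv_times]
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxx = sum((xi - mx) ** 2 for xi in x)
    if sxx == 0.0:
        return None
    sxy = sum((xi - mx) * (yi - my) for xi, yi in zip(x, y))
    return sxy / sxx


if __name__ == "__main__":
    # Toy, made-up decoded paths for four seeds; two runs pass through state 2.
    paths = [[0, 1, 3], [0, 1, 3], [0, 2, 1, 3], [0, 2, 1, 3]]
    conv_times = [100, 110, 200, 210]
    print(training_map(paths))
    print(visit_slope(paths, conv_times, 2))  # → 100.0
```

Because the regressor is a 0/1 indicator, the least-squares slope equals the difference in mean convergence time between runs that visit the state and runs that do not, so a clearly positive slope flags a candidate detour state.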
Abstract
The impact of randomness on model training is poorly understood. How do differences in data order and initialization actually manifest in the model, such that some training runs outperform others or converge faster? Furthermore, how can we interpret the resulting training dynamics and the phase transitions that characterize different trajectories? To understand the effect of randomness on the dynamics and outcomes of neural network training, we train models multiple times with different random seeds and compute a variety of metrics throughout training, such as the $L_2$ norm, mean, and variance of the neural network's weights. We then fit a hidden Markov model (HMM) over the resulting sequences of metrics. The HMM represents training as a stochastic process of transitions between latent states, providing an intuitive overview of significant changes during training. Using our method, we produce a low-dimensional, discrete representation of training dynamics on grokking tasks, image classification, and masked language modeling. We use the HMM representation to study phase transitions and identify latent "detour" states that slow down convergence.
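The pipeline in the abstract has two mechanical pieces: summarizing each checkpoint as a vector of metrics, and assigning each checkpoint a latent state under a fitted HMM. The sketch below illustrates both without dependencies, assuming a Gaussian HMM with diagonal covariance whose parameters have already been estimated (for example with EM/Baum-Welch, as a library such as hmmlearn's `GaussianHMM` does). The names `weight_metrics`, `log_gauss_diag`, and `viterbi` are hypothetical and not taken from the paper's code, and only the three weight statistics named above are computed, whereas the paper collects a wider set of metrics.

```python
import math
from typing import Dict, List, Sequence


def weight_metrics(weights: Sequence[float]) -> Dict[str, float]:
    """Hypothetical helper: L2 norm, mean, and population variance of a
    flattened weight vector at one checkpoint."""
    n = len(weights)
    mean = sum(weights) / n
    var = sum((w - mean) ** 2 for w in weights) / n
    l2 = math.sqrt(sum(w * w for w in weights))
    return {"l2": l2, "mean": mean, "var": var}


def log_gauss_diag(x: Sequence[float], mu: Sequence[float],
                   var: Sequence[float]) -> float:
    """Log density of x under a Gaussian with diagonal covariance."""
    return sum(-0.5 * (math.log(2 * math.pi * v) + (xi - m) ** 2 / v)
               for xi, m, v in zip(x, mu, var))


def viterbi(obs: Sequence[Sequence[float]],
            log_start: Sequence[float],
            log_trans: Sequence[Sequence[float]],
            means: Sequence[Sequence[float]],
            variances: Sequence[Sequence[float]]) -> List[int]:
    """Most likely latent-state path for one run's sequence of metric vectors.

    Assumes the HMM parameters (start and transition log-probabilities,
    per-state means and variances) were already fit across all seeds.
    """
    k = len(log_start)
    # score[s]: best log-probability of any state path ending in s at time t.
    score = [log_start[s] + log_gauss_diag(obs[0], means[s], variances[s])
             for s in range(k)]
    back: List[List[int]] = []  # back[t-1][s]: best predecessor of s at time t
    for x in obs[1:]:
        prev, ptr, score = score, [], []
        for s in range(k):
            best = max(range(k), key=lambda r: prev[r] + log_trans[r][s])
            ptr.append(best)
            score.append(prev[best] + log_trans[best][s]
                         + log_gauss_diag(x, means[s], variances[s]))
        back.append(ptr)
    # Trace back from the highest-scoring final state.
    state = max(range(k), key=lambda s: score[s])
    path = [state]
    for ptr in reversed(back):
        state = ptr[state]
        path.append(state)
    return path[::-1]
```

In use, one would compute `weight_metrics` at every checkpoint of every seed, fit the HMM on all seeds jointly, and call `viterbi` on each run to obtain the state sequence from which transitions and detours are read off; standardizing each metric before fitting is a further assumption not stated in the abstract.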
