
Latent State Models of Training Dynamics

Michael Y. Hu, Angelica Chen, Naomi Saphra, Kyunghyun Cho

TL;DR

The paper addresses the unclear role of randomness in neural network training by modeling training trajectories with a Gaussian hidden Markov model over a rich set of metrics collected during training. It derives a low-dimensional training map that represents latent states and transitions, enabling semantic interpretation of training dynamics and phase transitions such as grokking. A regression-based procedure assigns meaning to latent states (including detour states) and links them to convergence time, while experiments across modular arithmetic, sparse parities, language modeling, and image classification demonstrate the method’s ability to predict convergence and reveal how hyperparameters and architectures modulate seed sensitivity. The work provides a principled, automated framework for understanding variability in training and suggests practical pathways to stabilize learning by avoiding detour states through targeted architectural and optimization choices.
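Once a Gaussian HMM has been fit, each training run's metric sequence can be mapped to a path of latent states, which is the raw material for a training map. The sketch below shows one standard way to do this, Viterbi decoding. It assumes diagonal covariances and parameters that were already estimated (for example with Baum-Welch, not shown). The function names and the toy two-state parameters are hypothetical illustrations, not the authors' code, and the paper may assign states differently.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def safe_log(p: float) -> float:
    """log(p), returning -inf for zero-probability entries."""
    return math.log(p) if p > 0 else float("-inf")


def gaussian_log_pdf(x: Vector, mean: Vector, var: Vector) -> float:
    """Log density of x under a Gaussian with diagonal covariance (an assumption)."""
    return sum(
        -0.5 * (math.log(2 * math.pi * v) + (xi - m) ** 2 / v)
        for xi, m, v in zip(x, mean, var)
    )


def viterbi(
    obs: Sequence[Vector],
    start_prob: Vector,
    trans_prob: Sequence[Vector],
    means: Sequence[Vector],
    variances: Sequence[Vector],
) -> List[int]:
    """Most likely latent-state path for one run's sequence of metric vectors.

    Assumes HMM parameters were already fit; only the decoding step is shown.
    """
    n_states = len(start_prob)
    # log_delta[s]: best log-probability of any state path ending in s at the current step
    log_delta = [
        safe_log(start_prob[s]) + gaussian_log_pdf(obs[0], means[s], variances[s])
        for s in range(n_states)
    ]
    backpointers: List[List[int]] = []
    for x in obs[1:]:
        new_delta, pointers = [], []
        for s in range(n_states):
            # Score every predecessor p for landing in state s, keep the best one.
            scores = [log_delta[p] + safe_log(trans_prob[p][s]) for p in range(n_states)]
            best_prev = max(range(n_states), key=scores.__getitem__)
            pointers.append(best_prev)
            new_delta.append(scores[best_prev] + gaussian_log_pdf(x, means[s], variances[s]))
        log_delta = new_delta
        backpointers.append(pointers)
    # Trace the best path backwards from the most likely final state.
    state = max(range(n_states), key=log_delta.__getitem__)
    path = [state]
    for pointers in reversed(backpointers):
        state = pointers[state]
        path.append(state)
    return path[::-1]


# Toy two-state model over one standardized metric (e.g. the L2 norm); values are illustrative.
path = viterbi(
    obs=[[0.1], [-0.2], [2.9], [3.1], [3.0]],
    start_prob=[1.0, 0.0],
    trans_prob=[[0.9, 0.1], [0.0, 1.0]],
    means=[[0.0], [3.0]],
    variances=[[1.0], [1.0]],
)
print(path)  # [0, 0, 1, 1, 1]
```

In practice, a library such as hmmlearn's `GaussianHMM` handles both fitting and decoding; the hand-written version only makes the decoding step explicit.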

Abstract

The impact of randomness on model training is poorly understood. How do differences in data order and initialization actually manifest in the model, such that some training runs outperform others or converge faster? Furthermore, how can we interpret the resulting training dynamics and the phase transitions that characterize different trajectories? To understand the effect of randomness on the dynamics and outcomes of neural network training, we train models multiple times with different random seeds and compute a variety of metrics throughout training, such as the $L_2$ norm, mean, and variance of the neural network's weights. We then fit a hidden Markov model (HMM) over the resulting sequences of metrics. The HMM represents training as a stochastic process of transitions between latent states, providing an intuitive overview of significant changes during training. Using our method, we produce a low-dimensional, discrete representation of training dynamics on grokking tasks, image classification, and masked language modeling. We use the HMM representation to study phase transitions and identify latent "detour" states that slow down convergence.
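The metrics named above are simple functions of the weights. The following is a minimal sketch that computes only the three metrics mentioned in the abstract ($L_2$ norm, mean, and variance); the paper's full metric set is larger. It assumes each checkpoint is given as a list of flattened parameter tensors, a hypothetical layout, and the function names are illustrative. The z-scoring step is also an assumption, suggested by Figure 2 reporting feature changes in standard deviations.

```python
import math
from typing import List, Sequence

Checkpoint = Sequence[Sequence[float]]  # hypothetical: list of flattened parameter tensors


def weight_metrics(checkpoint: Checkpoint) -> List[float]:
    """[L2 norm, mean, variance] over all parameters of one checkpoint."""
    flat = [w for tensor in checkpoint for w in tensor]
    n = len(flat)
    mean = sum(flat) / n
    variance = sum((w - mean) ** 2 for w in flat) / n  # population variance
    l2_norm = math.sqrt(sum(w * w for w in flat))
    return [l2_norm, mean, variance]


def standardize(runs: List[List[List[float]]]) -> List[List[List[float]]]:
    """Z-score each metric with statistics pooled over every checkpoint of every run.

    Assumption: pooled standardization before HMM fitting.
    """
    rows = [row for run in runs for row in run]
    n_rows, n_feats = len(rows), len(rows[0])
    mu = [sum(r[j] for r in rows) / n_rows for j in range(n_feats)]
    # Fall back to 1.0 for constant metrics to avoid dividing by zero.
    sd = [
        math.sqrt(sum((r[j] - mu[j]) ** 2 for r in rows) / n_rows) or 1.0
        for j in range(n_feats)
    ]
    return [[[(r[j] - mu[j]) / sd[j] for j in range(n_feats)] for r in run] for run in runs]


# One run = one list of metric vectors, one per saved checkpoint.
print(weight_metrics([[3.0, 4.0], [0.0]]))  # L2 norm 5.0, mean ≈ 7/3, variance ≈ 26/9
```

Each run then becomes a sequence of standardized metric vectors, which is the kind of input the HMM is fit over.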

Paper Structure (31 sections, 7 equations, 11 figures, 10 tables)

Figures (11)

  • Figure 1: From training runs we collect metrics, which are functions of the neural networks' weights. We then train a hidden Markov model to predict the sequences of metrics generated from the training runs. The hidden Markov model learns a discrete latent state over the sequence, which we use to cluster and analyze the training trajectory.
  • Figure 2: One-layer transformer trained on modular addition. Edges exiting the initialization state 1 all have different mean convergence epochs. Feature changes are ordered by importance, from most to least. For example, "$L_2 \downarrow 0.59$" means that state 2 has a learned $L_2$ norm that is 0.59 standard deviations lower than that of state 1, and that the $L_2$ norm is the most important feature for state 2. See the appendix for a glossary of metrics and the methods section for how we identify important features.
  • Figure 3: ResNet18 trained on CIFAR-100. All 40 training runs we collected from CIFAR-100 follow the same path, although individual runs can spend slightly different amounts of time in each state. As shown by the table, the training dynamics of CIFAR-100 are similar between states.
  • Figure 4: Without residual connections and batch normalization, ResNet training becomes unstable, causing convergence times to differ significantly. Slow-generalizing runs take the state transition $(3 \to 1)$, while fast-generalizing runs take the state transition $(3 \to 2)$. (Runs can take the path $(3 \to 1 \to 3 \to 2)$, so transition frequencies do not sum to 40.) The variability induced by removing residual connections and batch normalization occurs at the beginning of training.
  • Figure 5: With layer normalization and a lower learning rate, the one-layer transformer quickly learns the modular arithmetic task, with a convergence time that is stable across random seeds. This stability is captured by the linear training map. Critically, the map still reflects the grokking phase transitions: memorization, which occurs in state 0, and generalization, which occurs in state 2.
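Figures 2 and 4 read two quantities off the training map: per-edge run counts with mean convergence epochs, and per-state feature changes in standard-deviation units. The sketch below computes both. It assumes decoded state paths (one per run), each run's convergence epoch, and learned state means over standardized metrics. All function names are hypothetical. Counting each run at most once per edge is an assumption consistent with Figure 4's note that transition frequencies do not sum to 40. Ranking features by absolute change is a simplification, not the paper's importance procedure.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, List, Sequence, Tuple

Edge = Tuple[int, int]


def training_map(
    paths: Sequence[Sequence[int]], convergence_epochs: Sequence[float]
) -> Dict[Edge, Tuple[int, float]]:
    """Map each state transition to (number of runs taking it, their mean convergence epoch)."""
    epochs_per_edge: Dict[Edge, List[float]] = defaultdict(list)
    for path, epoch in zip(paths, convergence_epochs):
        # Assumption: a run counts at most once per edge, so a run following
        # 3 -> 1 -> 3 -> 2 contributes to both (3, 1) and (3, 2).
        edges = {(a, b) for a, b in zip(path, path[1:]) if a != b}
        for edge in edges:
            epochs_per_edge[edge].append(epoch)
    return {edge: (len(eps), mean(eps)) for edge, eps in epochs_per_edge.items()}


def feature_changes(
    means_from: Sequence[float], means_to: Sequence[float], names: Sequence[str]
) -> List[Tuple[str, float]]:
    """Change in each learned (standardized) state mean, largest magnitude first.

    Simplification: ranks by |difference| rather than the paper's importance measure.
    """
    diffs = [(name, b - a) for name, a, b in zip(names, means_from, means_to)]
    return sorted(diffs, key=lambda d: abs(d[1]), reverse=True)


# Toy example: three runs starting in state 3 (values are illustrative only).
paths = [[3, 3, 1, 3, 2], [3, 2, 2], [3, 1, 1]]
edges = training_map(paths, [90.0, 30.0, 80.0])
print(edges[(3, 1)], edges[(3, 2)])  # (2, 85.0) (2, 60.0)
print(feature_changes([0.0, 0.0, 0.0], [-0.59, 0.1, 0.3], ["l2", "mean", "variance"]))
```

In the toy data, the edges leaving state 3 are taken by 4 runs in total even though there are only 3 runs, because the first run visits both $(3 \to 1)$ and $(3 \to 2)$.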