Enhancing Neural Training via a Correlated Dynamics Model
Jonathan Brokman, Roy Betser, Rotem Turjeman, Tom Berkov, Ido Cohen, Guy Gilboa
TL;DR
The paper observes that neural network weight trajectories during SGD are strongly correlated, so a small set of correlated modes can capture them. It introduces Correlation Mode Decomposition (CMD), a data-driven approach that clusters weight trajectories into modes. CMD represents each weight as w_i(t) ≈ a_i w_m(t) + b_i, where w_m(t) is the reference trajectory of the mode m that weight i belongs to, and a_i, b_i are per-weight scalar coefficients. The paper also provides online CMD and embedded CMD variants that integrate the modeling into training. Compared with state-of-the-art methods for low-dimensional modeling of training dynamics, such as P-BFGS, CMD achieves superior or competitive accuracy across diverse architectures (e.g., ResNet18, WideResNet, ViT-b-16) and tasks. It also substantially reduces communication in Federated Learning, because only the embedded coefficients and a few mode references need to be transmitted. In addition, CMD aids visualization by yielding smooth loss/accuracy landscapes in reduced-dimensional spaces, and it acts as an implicit regularizer comparable to EMA/SWA. In practice, this means improved training efficiency, lower communication overhead, and broader applicability to real-world distributed learning.
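The affine mode model w_i(t) ≈ a_i w_m(t) + b_i can be illustrated with a small sketch. It works under stated assumptions and is not the authors' algorithm. Trajectories are grouped by a simple greedy rule: a weight joins the existing reference with the highest absolute Pearson correlation if that correlation exceeds a threshold. Otherwise it becomes a new reference. Then a_i and b_i are fitted by ordinary least squares against the mode's reference. The greedy clustering, the choice of the first weight in each cluster as its reference, and the names `pearson`, `fit_affine`, `cmd_sketch`, `reconstruct` and `threshold` are all hypothetical. The paper's clustering and reference selection may differ.

```python
import math
from typing import List, Tuple

Trajectory = List[float]  # values of one weight at T checkpoints


def pearson(x: Trajectory, y: Trajectory) -> float:
    """Pearson correlation of two equal-length sequences (0.0 if either is constant)."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    vx = sum((a - mx) ** 2 for a in x)
    vy = sum((b - my) ** 2 for b in y)
    if vx == 0.0 or vy == 0.0:
        return 0.0
    return cov / math.sqrt(vx * vy)


def fit_affine(ref: Trajectory, w: Trajectory) -> Tuple[float, float]:
    """Least-squares (a, b) minimizing sum_t (w(t) - (a * ref(t) + b))^2."""
    n = len(ref)
    mr, mw = sum(ref) / n, sum(w) / n
    var = sum((r - mr) ** 2 for r in ref)
    if var == 0.0:
        return 0.0, mw  # constant reference: only the offset is identifiable
    a = sum((r - mr) * (v - mw) for r, v in zip(ref, w)) / var
    return a, mw - a * mr


def cmd_sketch(trajs: List[Trajectory], threshold: float = 0.9):
    """Hypothetical greedy correlation clustering followed by an affine fit per weight.

    Returns (refs, assign, coeffs): refs[m] is the index of mode m's reference weight,
    assign[i] is the mode of weight i, and coeffs[i] = (a_i, b_i).
    """
    refs: List[int] = []
    assign: List[int] = []
    for i, w in enumerate(trajs):
        best_m, best_c = -1, 0.0
        for m, r in enumerate(refs):
            c = abs(pearson(trajs[r], w))
            if c > best_c:
                best_m, best_c = m, c
        if best_c < threshold:  # not correlated enough with any mode: open a new one
            refs.append(i)
            best_m = len(refs) - 1
        assign.append(best_m)
    coeffs = [fit_affine(trajs[refs[m]], w) for m, w in zip(assign, trajs)]
    return refs, assign, coeffs


def reconstruct(trajs, refs, assign, coeffs) -> List[Trajectory]:
    """Rebuild every trajectory as a_i * w_m(t) + b_i from its mode reference."""
    return [[a * x + b for x in trajs[refs[m]]] for m, (a, b) in zip(assign, coeffs)]


if __name__ == "__main__":
    toy = [
        [0.0, 1.0, 4.0, 9.0, 16.0, 25.0],      # follows t^2
        [1.0, 3.0, 9.0, 19.0, 33.0, 51.0],     # 2 * t^2 + 1
        [1.0, -1.0, 1.0, -1.0, 1.0, -1.0],     # oscillating weight
        [-2.5, 3.5, -2.5, 3.5, -2.5, 3.5],     # -3 * oscillation + 0.5
    ]
    refs, assign, coeffs = cmd_sketch(toy)
    print(refs, assign)  # refs == [0, 2], assign == [0, 0, 1, 1]
```

In the toy run, four weights collapse into two modes. The stored model is therefore 2 reference trajectories plus 2 scalars per weight, instead of 4 full trajectories. The saving grows as more weights share a mode.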
Abstract
As neural networks grow in scale, their training becomes both computationally demanding and rich in dynamics. Amid the flourishing interest in these training dynamics, we present a novel observation: parameters exhibit intrinsic correlations over time during training. Capitalizing on this, we introduce Correlation Mode Decomposition (CMD). This algorithm clusters the parameter space into groups, termed modes, that display synchronized behavior across epochs. This enables CMD to efficiently represent the training dynamics of complex networks, such as ResNets and Transformers, using only a few modes. Moreover, CMD enhances test-set generalization. We introduce an efficient CMD variant designed to run concurrently with training. Our experiments indicate that CMD surpasses the state-of-the-art method for compactly modeled dynamics on image classification. Our modeling can improve training efficiency and lower communication overhead, as shown by our preliminary experiments in the context of federated learning.
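One way a CMD variant could run concurrently with training is to avoid storing whole trajectories. Once mode assignments are fixed, each weight's coefficients can be re-estimated from running sums as new checkpoints arrive. The sketch below shows this for a single weight. It assumes the reference value w_m(t) is available at each step. It is an illustrative assumption, not the paper's online CMD update rule, and the class name `OnlineAffineFit` is hypothetical.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class OnlineAffineFit:
    """Running least-squares fit of w(t) ≈ a * r(t) + b for one weight.

    Keeps five running sums instead of the full trajectory, so memory per
    weight is constant regardless of how many checkpoints have been seen.
    """
    n: int = 0
    sr: float = 0.0   # sum of reference values r(t)
    sw: float = 0.0   # sum of weight values w(t)
    srr: float = 0.0  # sum of r(t)^2
    srw: float = 0.0  # sum of r(t) * w(t)

    def update(self, r: float, w: float) -> None:
        """Absorb one new checkpoint (reference value r, weight value w)."""
        self.n += 1
        self.sr += r
        self.sw += w
        self.srr += r * r
        self.srw += r * w

    def coeffs(self) -> Tuple[float, float]:
        """Closed-form least-squares (a, b) from the accumulated sums."""
        if self.n == 0:
            raise ValueError("no checkpoints observed yet")
        denom = self.n * self.srr - self.sr ** 2
        if denom == 0.0:
            return 0.0, self.sw / self.n  # constant reference: fit the mean only
        a = (self.n * self.srw - self.sr * self.sw) / denom
        b = (self.sw - a * self.sr) / self.n
        return a, b


if __name__ == "__main__":
    fit = OnlineAffineFit()
    for t in range(6):
        ref = float(t * t)              # hypothetical reference trajectory value
        fit.update(ref, 2.0 * ref + 1.0)  # a weight that tracks it exactly
    print(fit.coeffs())
```

The raw-sum formula is compact, but it can lose precision when values are large relative to their spread. A centered, Welford-style accumulation would be the more robust choice for long training runs.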
