Enhancing Neural Training via a Correlated Dynamics Model

Jonathan Brokman, Roy Betser, Rotem Turjeman, Tom Berkov, Ido Cohen, Guy Gilboa

TL;DR

The paper observes that neural network weight trajectories during SGD training are strongly correlated and can be captured by a small set of correlated modes. It introduces Correlation Mode Decomposition (CMD), a data-driven approach that clusters weight trajectories into modes and represents each weight as w_i(t) ≈ a_i w_m(t) + b_i, where w_m(t) is the reference trajectory of the weight's mode and a_i, b_i are per-weight scalars. It also provides online and embedded CMD variants that integrate the modeling into training. Compared with state-of-the-art low-dimensional dynamics methods such as P-BFGS, CMD achieves superior or competitive accuracy across diverse architectures (e.g., ResNet18, WideResNet, ViT-b-16) and tasks. In federated learning, it substantially reduces communication by transmitting only the embedded coefficients and a few mode references. CMD also aids visualization, since loss and accuracy landscapes can be plotted in the low-dimensional space spanned by the mode references, and it acts as an implicit regularizer comparable to EMA/SWA. The practical impact includes improved training efficiency, reduced communication overhead, and broader applicability to real-world distributed learning scenarios.
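The per-weight model w_i(t) ≈ a_i w_m(t) + b_i can be illustrated with a minimal NumPy sketch. It assumes each trajectory is recorded as one value per epoch and that the mode's reference trajectory w_m(t) is already known; estimating a_i and b_i by ordinary least squares is our assumption about how the two scalars could be obtained, not necessarily the paper's exact procedure. The function names are hypothetical.

```python
import numpy as np


def fit_affine_to_reference(w_i: np.ndarray, w_ref: np.ndarray) -> tuple[float, float]:
    """Least-squares estimate of (a_i, b_i) in w_i(t) ≈ a_i * w_ref(t) + b_i.

    Both arguments are 1-D arrays holding one value per recorded epoch.
    """
    # Design matrix with columns [w_ref(t), 1], so the solution vector is (a_i, b_i).
    A = np.stack([w_ref, np.ones_like(w_ref)], axis=1)
    (a_i, b_i), *_ = np.linalg.lstsq(A, w_i, rcond=None)
    return float(a_i), float(b_i)


def reconstruct(a_i: float, b_i: float, w_ref: np.ndarray) -> np.ndarray:
    """Rebuild a weight's trajectory from its two scalars and the shared reference."""
    return a_i * w_ref + b_i


# Usage with synthetic data: a hypothetical reference and a weight that moves in sync with it.
t = np.linspace(0.0, 1.0, 50)
w_ref = np.exp(-3.0 * t) + 0.05 * np.sin(20.0 * t)
w_i = 2.5 * w_ref - 0.3
a_i, b_i = fit_affine_to_reference(w_i, w_ref)
w_rec = reconstruct(a_i, b_i, w_ref)
```

Once a_i and b_i are fitted, the whole trajectory of that weight is stored as two numbers plus a pointer to its mode, which is where the compactness of the representation comes from.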

Abstract

As neural networks grow in scale, their training becomes both computationally demanding and rich in dynamics. Amid the flourishing interest in these training dynamics, we present a novel observation: parameters exhibit intrinsic correlations over time during training. Capitalizing on this, we introduce Correlation Mode Decomposition (CMD). This algorithm clusters the parameter space into groups, termed modes, that display synchronized behavior across epochs. This enables CMD to efficiently represent the training dynamics of complex networks, such as ResNets and Transformers, using only a few modes. Moreover, CMD enhances test-set generalization. We also introduce an efficient CMD variant designed to run concurrently with training. Our experiments indicate that CMD surpasses the state-of-the-art method for compactly modeled dynamics on image classification. Our modeling can improve training efficiency and lower communication overhead, as shown by our preliminary experiments in the context of federated learning.
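The clustering of parameters into synchronized modes can be sketched in NumPy under several assumptions: trajectories are stored as rows of an array of shape (n_weights, T); a simple greedy threshold on absolute Pearson correlation stands in for the paper's actual clustering algorithm, which this summary does not describe; and each mode's reference is taken to be its most central member. The function names and the 0.9 threshold are hypothetical.

```python
import numpy as np


def abs_correlation(W: np.ndarray) -> np.ndarray:
    """|Pearson correlation| between every pair of rows of W (shape: n_weights x T).

    The absolute value is used because a negative a_i still gives an exact affine fit.
    """
    return np.abs(np.corrcoef(W))


def cluster_by_correlation(W: np.ndarray, threshold: float = 0.9) -> np.ndarray:
    """Greedy grouping of weight trajectories into modes; returns one label per weight."""
    C = abs_correlation(W)
    labels = np.full(W.shape[0], -1, dtype=int)
    mode = 0
    for seed in range(W.shape[0]):
        if labels[seed] != -1:
            continue  # already absorbed by an earlier mode
        # Unassigned weights strongly correlated with the seed join its mode;
        # the seed itself always joins, even if its correlations are NaN (constant row).
        members = (labels == -1) & (C[seed] >= threshold)
        members[seed] = True
        labels[members] = mode
        mode += 1
    return labels


def pick_references(W: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """For each mode, return the index of the member most correlated with the others."""
    C = abs_correlation(W)
    refs = []
    for m in range(labels.max() + 1):
        idx = np.flatnonzero(labels == m)
        scores = C[np.ix_(idx, idx)].mean(axis=1)
        refs.append(int(idx[np.argmax(scores)]))
    return np.array(refs)


# Usage with synthetic data: 12 weights driven by two hypothetical underlying signals.
t = np.linspace(0.0, 1.0, 40)
bases = np.stack([np.exp(-4.0 * t), np.sin(3.0 * t)])
true_mode = np.arange(12) % 2
scales = np.linspace(-2.0, 2.0, 12)  # includes negative a_i, none exactly zero
offsets = np.linspace(0.0, 1.0, 12)
W = scales[:, None] * bases[true_mode] + offsets[:, None]
labels = cluster_by_correlation(W)
refs = pick_references(W, labels)
```

In this toy case every weight is an exact affine copy of one of two signals, so within-mode correlations are 1 in absolute value while the two signals correlate only moderately, and the greedy pass recovers the two groups. Sorting the correlation matrix by `labels` would produce the block-diagonal structure described for Figure 1.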

Paper Structure (47 sections, 27 equations, 22 figures, 10 tables, 4 algorithms)

Figures (22)

  • Figure 1: Correlated dynamics. Left: Correlation matrix of weight trajectories, clustered into modes, for a simple network (see the Appendix) trained on MNIST (lecun1998gradient). Middle: Clustered correlation matrix of weight trajectories for ResNet18 (he2016deep) trained on CIFAR10 (krizhevsky2009learning). The block-diagonal structure indicates high correlations of the parameter dynamics within each mode. Right: Three weight trajectories from SGD training, reconstructed via dynamic mode decomposition (DMD), which uses exponentials (top), and via CMD, which uses reference trajectories (bottom). These highlight the effectiveness of CMD in capturing complex training dynamics. See the Appendix and its accompanying figures for extended findings.
  • Figure 2: Post-hoc CMD examples. The CMD representation is highly general and models diverse architectures operating on various tasks well. Its success is evident in its ability to follow the performance dynamics while even providing a performance boost. Left: CIFAR10 performance of ViT-b-16 (dosovitskiy2020image), pre-trained on JFT-300M. Middle: StarGAN-v2 (choi2020stargan) style transfer; a qualitative result of the original network compared with that of the CMD model of its dynamics. Right: Segmentation results using the PSPNet architecture (zhao2017pyramid). See the Appendix and its accompanying figures for more details and results. 'GD' in the legend stands for SGD.
  • Figure 3: Embedding CMD in training. Left: Naively assigning CMD-modeled weights to the model every 20 epochs and performing SGD in between. Dots represent the CMD model performance. Performance is repeatedly drawn back to that of the CMD-less (regular SGD) case. Middle: Online CMD and Embedded CMD compared with regular SGD along the training epochs. An additional naive approach, which trains only the reference weights after warm-up, is also shown (green). CMD runs use $M=10$ modes and $F=20$ warm-up epochs. We use an embedding rate of $P=10\%$ and $L=10$ epochs (Right), which leads to an $\approx 50\%$ reduction in the total number of trained parameters over the whole training cycle. Both online and embedded CMD considerably improve performance compared with regular SGD; neither naive approach does.
  • Figure 4: Visualization of the accuracy landscape. A grid is created based on the two reference weights and colored by accuracy. The accuracy landscape of the training set (left) is very smooth compared with that of the validation set (second from left). The landscape of the Automobile class (third from left) is rather regular, whereas that of the Dog class (right) is much more irregular and suboptimal. The original CMD model is marked with a black dot, and the optimal model with a red dot.
  • Figure 5: Federated learning (FL), ResNet18 on CIFAR10. Test loss and accuracy throughout training. CMD with two modes is used for FL and compared with regular SGD (Baseline) and Aggressive APF (A-APF). Each trend line is the average of 10 runs.
  • ...and 17 more figures
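The landscape grid of Figure 4 follows from the affine form of the model: once the two mode references take fixed values, every weight w_i = a_i w_m + b_i is determined, so each grid point corresponds to one full network. The sketch below assumes a two-mode CMD model with known mode labels and coefficients; `cmd_weights`, `landscape`, and the toy quadratic score are hypothetical stand-ins, and a real `evaluate` would load the flat weight vector into a network and measure its accuracy on a dataset.

```python
import numpy as np


def cmd_weights(r: np.ndarray, labels: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Flat weight vector implied by one value per mode reference: w_i = a_i * r[m(i)] + b_i."""
    return a * r[labels] + b


def landscape(labels, a, b, r1_values, r2_values, evaluate) -> np.ndarray:
    """Score every model on a 2-D grid spanned by the values of the two mode references."""
    grid = np.empty((len(r1_values), len(r2_values)))
    for i, r1 in enumerate(r1_values):
        for j, r2 in enumerate(r2_values):
            grid[i, j] = evaluate(cmd_weights(np.array([r1, r2]), labels, a, b))
    return grid


# Toy usage: 5 weights in 2 modes; a hypothetical score peaks at references (0.5, -1.0).
labels = np.array([0, 0, 1, 1, 1])
a = np.array([1.0, -0.5, 2.0, 1.0, 0.3])
b = np.zeros(5)
target = cmd_weights(np.array([0.5, -1.0]), labels, a, b)


def toy_score(w: np.ndarray) -> float:
    # Stand-in for accuracy: higher is better, maximal when w equals the target weights.
    return -float(np.sum((w - target) ** 2))


grid = landscape(labels, a, b, np.linspace(-1.0, 1.0, 5), np.linspace(-2.0, 0.0, 5), toy_score)
best = np.unravel_index(np.argmax(grid), grid.shape)
```

Here `best` is the grid cell at references (0.5, -1.0), the analogue of the red "optimal model" dot; evaluating a real network at each cell costs one forward pass over the dataset per grid point, regardless of how many weights the network has.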