Precise Dynamics of Diagonal Linear Networks: A Unifying Analysis by Dynamical Mean-Field Theory

Sota Nishiyama, Masaaki Imaizumi

TL;DR

This work addresses the challenge of understanding gradient-flow dynamics in diagonal linear networks (DLNs) by developing a unified dynamical mean-field theory (DMFT) framework that reduces high-dimensional learning dynamics to a low-dimensional, self-consistent process. The authors identify distinct dynamical regimes governed by initialization scale, derive fixed-point characterizations that interpolate between minimum-norm and $\ell_1$-biased solutions, and establish a precise trade-off between generalization and convergence speed. They analyze learning timescales via singular perturbation theory, revealing lazy and rich phases for large initialization and search and descent phases for small initialization, with connections to grokking. A rigorous theory for truncated DLNs is provided, alongside extensive numerical validation on Gaussian and real data, demonstrating the robustness and universality of the DMFT predictions. Overall, the work demonstrates DMFT as a powerful tool for predicting implicit biases and temporal structure in high-dimensional neural dynamics, with implications for initialization and training efficiency across architectures.

Abstract

Diagonal linear networks (DLNs) are a tractable model that captures several nontrivial behaviors in neural network training, such as initialization-dependent solutions and incremental learning. These phenomena are typically studied in isolation, leaving the overall dynamics insufficiently understood. In this work, we present a unified analysis of various phenomena in the gradient flow dynamics of DLNs. Using Dynamical Mean-Field Theory (DMFT), we derive a low-dimensional effective process that captures the asymptotic gradient flow dynamics in high dimensions. Analyzing this effective process yields new insights into DLN dynamics, including loss convergence rates and their trade-off with generalization, and systematically reproduces many of the previously observed phenomena. These findings deepen our understanding of DLNs and demonstrate the effectiveness of the DMFT approach in analyzing high-dimensional learning dynamics of neural networks.
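
The initialization-dependent behavior described in the abstract can be illustrated with a small simulation. The sketch below is not the authors' code. It assumes the common DLN parameterization $\beta = u \odot u - v \odot v$ with $u(0) = v(0) = \alpha$, an unregularized squared loss, a hypothetical $k$-sparse ground truth, noiseless labels, and a plain Euler discretization of gradient flow. The function name `simulate_dln` and all default values are illustrative choices.

```python
import numpy as np

def simulate_dln(alpha, n=100, d=200, k=10, lr=0.005, steps=20000, seed=0):
    """Euler-discretized gradient flow for an ASSUMED DLN, beta = u*u - v*v.

    Every default (n, d, k, lr, steps) is an illustrative assumption.
    """
    rng = np.random.default_rng(seed)
    # Gaussian design with i.i.d. entries of mean 0 and variance 1/d.
    X = rng.normal(scale=1.0 / np.sqrt(d), size=(n, d))
    beta_star = np.zeros(d)
    beta_star[:k] = 1.0              # hypothetical k-sparse ground truth
    y = X @ beta_star                # noiseless labels (assumption)
    u = np.full(d, float(alpha))
    v = np.full(d, float(alpha))     # u = v, so beta(0) = 0
    for _ in range(steps):
        residual = X @ (u * u - v * v) - y
        grad_beta = X.T @ residual   # gradient of 0.5*||X beta - y||^2 in beta
        # Chain rule: dL/du = 2*u*grad_beta and dL/dv = -2*v*grad_beta.
        u, v = u - lr * 2 * u * grad_beta, v + lr * 2 * v * grad_beta
    beta = u * u - v * v
    train_loss = 0.5 * np.sum((X @ beta - y) ** 2)
    test_err = np.mean((beta - beta_star) ** 2)  # excess risk for isotropic inputs
    return train_loss, test_err

# Same data (same seed), two initialization scales.
results = {alpha: simulate_dln(alpha) for alpha in (0.01, 2.0)}
for alpha, (train_loss, test_err) in results.items():
    print(f"alpha={alpha}: train loss={train_loss:.3e}, test error={test_err:.3e}")
```

With $n/d = 0.5$ and a sparse target, the small-$\alpha$ run is expected to end closer to the ground truth than the large-$\alpha$ run, at the cost of a longer initial plateau. This is only a qualitative finite-size illustration, not a check of the DMFT predictions.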

Paper Structure

This paper contains 82 sections, 9 theorems, 202 equations, 16 figures, 1 table.

Key Result

Theorem 1

Assume that the entries of $\bm{X}=(x_{ij})_{i\in[n],j\in[d]}$ are independent and satisfy $\mathbb{E}\,x_{ij}=0$, $\mathbb{E}\,x_{ij}^2=1/d$, and $\lVert x_{ij}\rVert_{\psi_2}\leq C/\sqrt{d}$, where $\lVert\cdot\rVert_{\psi_2}$ is the sub-Gaussian norm and $C > {}$ [...] almost surely as $n,d \to \infty$, where $\xrightarrow{W_2}$ denotes convergence in the Wasserstein [...]

Figures (16)

  • Figure 1: Schematic illustrations of the timescale structures of gradient flow dynamics in DLNs.
  • Figure 2: Long-time behaviors of DLNs for $\lambda = 0$ and $\delta = 0.5$. (a): Smaller initialization $\alpha$ leads to better generalization at the fixed point (Result \ref{res:fixedpoint}, Case (iii)). Simulations are run $10$ times on independent data, and the error bars indicate one standard deviation. (b), (c): Smaller initialization $\alpha$ leads to slower convergence (Result \ref{res:convergence_rates}), thus showing a trade-off with the generalization performance. The slopes of the dashed lines represent theoretical predictions for the convergence rate. Experimental details are discussed in Section \ref{sec:numerical}.
  • Figure 3: (a): Training error dynamics for various initialization scales $\alpha$. Plots are simulations of DLNs with $d = 200$. We observe the qualitatively different dynamics depicted in Figure \ref{fig:timescales}. The monotonicity of the training error changes at around $\alpha \approx 0.3$. (b): Test error dynamics for large $\alpha$. Once the time is rescaled by $\log(\alpha)$, the transition times to the second dynamical regime collapse, showing that it is the correct scaling for the transition time. (c): Training error dynamics for small $\alpha$. Once the time is rescaled by $\log(1/\alpha)$, descent phases start and proceed on the same timescales.
  • Figure 4: Training and test error dynamics for large $\alpha$ simulated with $d=200$.
  • Figure 5: Training and test error dynamics for large $\alpha$, with time rescaled by $\alpha^2$. The initial descent of training and test errors collapses onto the limiting solution \ref{eq:lazy_sol}.
  • ...and 11 more figures

Theorems & Definitions (14)

  • Theorem: Informal version of Corollary \ref{cor:dmft_dln}
  • Proposition C.1
  • proof
  • Theorem 3
  • Theorem 4
  • Corollary 5
  • Definition E.1: Function triplet spaces $\mathcal{S}$ and $\mathcal{S}_{\mathrm{cont}}$
  • Definition E.2: Function pair spaces $\bar{\mathcal{S}}$ and $\bar{\mathcal{S}}_{\mathrm{cont}}$
  • Lemma E.3
  • proof: Proof of Lemma \ref{lem:mapping}
  • ...and 4 more