
Training Diagonal Linear Networks with Stochastic Sharpness-Aware Minimization

Gabriel Clara, Sophie Langer, Johannes Schmidt-Hieber

TL;DR

This work analyzes stochastic sharpness-aware minimization (S-SAM) for diagonal linear networks in a linear regression setting. By marginalizing over Gaussian perturbations, the authors derive a regularized loss $\mathcal{L}_R$ where the average-sharpness penalty enforces a layerwise balancing constraint $W_\ell^2 = W_{\ell+1}^2$ and induces shrinkage of the true parameter through shrinkage factors that solve a depth- and noise-dependent equation. They prove that gradient flow and gradient descent on $\mathcal{L}_R$ drive the network toward balanced, sharpness-minimizing stationary points and establish convergence for the projected stochastic recursion, connecting the dynamics to generalization via PAC-Bayes-like bounds. The results provide a principled theoretical lens on how algorithmic noise regularizes training in a tractable toy model and offer avenues for extending the insights to more complex linear and nonlinear networks.
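The marginalization step can be checked numerically in the simplest setting. The sketch below assumes a single coordinate of a depth-$L$ diagonal network with target $w_*$ and squared loss $(w_* - \prod_\ell w_\ell)^2$, and perturbs every weight by independent $\mathcal{N}(0, \eta^2)$ noise. This is a toy simplification rather than the paper's full regression setting, and the function names are hypothetical. Because the factors are independent, $\mathbb{E}[\prod_\ell (w_\ell + \eta \xi_\ell)] = \prod_\ell w_\ell$ and $\mathbb{E}[\prod_\ell (w_\ell + \eta \xi_\ell)^2] = \prod_\ell (w_\ell^2 + \eta^2)$, which gives a closed form to compare against a Monte Carlo estimate.

```python
import numpy as np


def perturbed_loss_mc(w, w_star, eta, n_samples=200_000, seed=0):
    """Monte Carlo estimate of E[(w_star - prod_l (w_l + eta * xi_l))**2], xi_l iid N(0, 1)."""
    rng = np.random.default_rng(seed)
    w = np.asarray(w, dtype=float)
    xi = rng.standard_normal((n_samples, w.size))
    prods = np.prod(w + eta * xi, axis=1)  # one perturbed end-to-end weight per sample
    return float(np.mean((w_star - prods) ** 2))


def marginalized_loss(w, w_star, eta):
    """Closed form of the same expectation (toy single-coordinate model).

    By independence of the perturbations:
      E[prod_l (w_l + eta * xi_l)]    = prod_l w_l
      E[prod_l (w_l + eta * xi_l)**2] = prod_l (w_l**2 + eta**2)
    """
    w = np.asarray(w, dtype=float)
    return float(w_star**2 - 2.0 * w_star * np.prod(w) + np.prod(w**2 + eta**2))


w_star, eta = 3.14159, 0.1
w = [1.2, 0.8, 1.5]  # hypothetical depth-3 weights for one coordinate
print(perturbed_loss_mc(w, w_star, eta), marginalized_loss(w, w_star, eta))
```

For $L = 2$ the closed form equals $(w_* - w_2 w_1)^2 + \eta^2 (w_1^2 + w_2^2) + \eta^4$, i.e. the loss plotted in Figure 2 up to the additive constant $\eta^4$.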

Abstract

We analyze the landscape and training dynamics of diagonal linear networks in a linear regression task, with the network parameters being perturbed by small isotropic normal noise. The addition of such noise may be interpreted as a stochastic form of sharpness-aware minimization (SAM), and we prove several results that relate its action on the underlying landscape and training dynamics to the sharpness of the loss. In particular, the noise changes the expected gradient to force balancing of the weight matrices at a fast rate along the descent trajectory. In the diagonal linear model, we show that this equates to minimizing the average sharpness, as well as the trace of the Hessian matrix, among all possible factorizations of the same matrix. Further, the noise forces the gradient descent iterates towards a shrinkage-thresholding of the underlying true parameter, with the noise level explicitly regulating both the shrinkage factor and the threshold.
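To illustrate balancing and shrinkage-thresholding, the following minimal sketch runs gradient descent on the two-layer scalar loss from Figure 2, $\mathcal{L}_R(w_1, w_2) = (w_* - w_2 w_1)^2 + \eta^2 (w_1^2 + w_2^2)$. The step size, iteration count, initial point and function names are illustrative assumptions, not the paper's. For this loss, one gradient step with step size $\alpha$ multiplies the imbalance $w_1^2 - w_2^2$ by $(1 - 2\alpha\eta^2)^2 - 4\alpha^2 r^2$ with $r = w_* - w_2 w_1$, so the imbalance contracts geometrically; near the balanced minimizer $r = \eta^2$ and the factor becomes $1 - 4\alpha\eta^2$. In this two-layer case the limiting end-to-end weight can be derived by hand as the soft-thresholded value $\operatorname{sign}(w_*) \max(|w_*| - \eta^2, 0)$.

```python
import numpy as np

W_STAR = 3.14159  # true parameter used in Figures 2 and 3


def grad_reg_loss(w1, w2, eta, w_star=W_STAR):
    """Gradient of L_R(w1, w2) = (w_star - w2 * w1)**2 + eta**2 * (w1**2 + w2**2)."""
    r = w_star - w2 * w1
    return -2.0 * r * w2 + 2.0 * eta**2 * w1, -2.0 * r * w1 + 2.0 * eta**2 * w2


def gd_regularized(w1, w2, eta, w_star=W_STAR, step=0.01, n_steps=20_000):
    """Gradient descent on L_R; records the imbalance w1**2 - w2**2 after every step."""
    imbalance = [w1**2 - w2**2]
    for _ in range(n_steps):
        g1, g2 = grad_reg_loss(w1, w2, eta, w_star)
        w1, w2 = w1 - step * g1, w2 - step * g2
        imbalance.append(w1**2 - w2**2)
    return w1, w2, np.array(imbalance)


def soft_threshold(w_star, tau):
    """Soft-thresholding: sign(w_star) * max(|w_star| - tau, 0)."""
    return float(np.sign(w_star) * max(abs(w_star) - tau, 0.0))


eta = 0.5
w1, w2, imb = gd_regularized(2.0, 0.5, eta)
print(w1 * w2, soft_threshold(W_STAR, eta**2))  # both approximately W_STAR - eta**2
print(imb[:5])  # the imbalance shrinks geometrically from its initial value 3.75
```

Rerunning with `w_star=0.2`, below the threshold $\eta^2 = 0.25$, drives the product $w_1 w_2$ to zero, which is the thresholding part of the behaviour.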

Paper Structure

This paper contains 29 sections, 17 theorems, 138 equations, 3 figures.

Key Result

Lemma 2.1

In the linear diagonal model, for every $\ell = 1, \ldots, L$

Figures (3)

  • Figure 1: A diagonal linear network with $L$ layers, implementing the function $f(x_1, \ldots, x_d) = \sum_{i = 1}^d w_{L, i} \cdots w_{1, i} x_i$.
  • Figure 2: Contour plots of the loss $(w_{*} - w_2 w_1)^2 + \eta^2 (w_1^2 + w_2^2)$ with $w_{*} = 3.14159$ and different values of $\eta$. Large values of $\eta$ cause increasing shrinkage of the critical points towards zero.
  • Figure 3: Gradient descent trajectories overlaid onto a contour plot of the loss $(w_{*} - w_2 w_1)^2$ with $w_{*} = 3.14159$, started from the same initial point and run with different values of $\eta$ and $\alpha_k$. Shown are standard gradient descent following $- \nabla \mathcal{L}$ (yellow), explicitly regularized gradient descent following $- \nabla \mathcal{L}_R$ (green), and gradient descent with S-SAM following $- \nabla \widetilde{\mathcal{L}}_k$ (blue).
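Figure 3 compares three descent directions. The sketch below reproduces that comparison numerically for the two-layer scalar model under stated assumptions: the S-SAM gradient $\nabla \widetilde{\mathcal{L}}_k$ is taken to be $\nabla \mathcal{L}$ evaluated at $w + \eta \xi_k$ with a fresh $\xi_k \sim \mathcal{N}(0, I_2)$ at every step, and the step sizes follow the hypothetical schedule $\alpha_k = \alpha_0 / (1 + k / k_0)$. Under this assumption the perturbed loss averages to $\mathcal{L}_R$ plus a constant, so the S-SAM gradient is an unbiased estimate of $\nabla \mathcal{L}_R$ and the S-SAM iterates should hover around the explicitly regularized solution rather than the unregularized one.

```python
import numpy as np

W_STAR = 3.14159  # true parameter, as in Figures 2 and 3
N_STEPS = 50_000


def grad_loss(w, w_star=W_STAR):
    """Gradient of L(w) = (w_star - w[1] * w[0])**2 for w = (w1, w2)."""
    r = w_star - w[1] * w[0]
    return np.array([-2.0 * r * w[1], -2.0 * r * w[0]])


def grad_reg_loss(w, eta, w_star=W_STAR):
    """Gradient of L_R(w) = L(w) + eta**2 * (w1**2 + w2**2)."""
    return grad_loss(w, w_star) + 2.0 * eta**2 * w


def run(grad_fn, w0, n_steps=N_STEPS, alpha0=0.005, k0=5_000):
    """Descent with the hypothetical schedule alpha_k = alpha0 / (1 + k / k0); returns all iterates."""
    w = np.array(w0, dtype=float)
    path = np.empty((n_steps + 1, 2))
    path[0] = w
    for k in range(n_steps):
        alpha_k = alpha0 / (1.0 + k / k0)
        w = w - alpha_k * grad_fn(w, k)
        path[k + 1] = w
    return path


eta = 0.5
rng = np.random.default_rng(0)
xi = rng.standard_normal((N_STEPS, 2))  # fresh perturbation direction xi_k for every step k
w0 = [2.0, 0.5]

plain = run(lambda w, k: grad_loss(w), w0)               # follows -grad L
explicit = run(lambda w, k: grad_reg_loss(w, eta), w0)   # follows -grad L_R
ssam = run(lambda w, k: grad_loss(w + eta * xi[k]), w0)  # assumed S-SAM: grad L at w + eta * xi_k

print(plain[-1].prod())                    # end-to-end weight close to W_STAR
print(explicit[-1].prod())                 # close to W_STAR - eta**2
print(ssam[-10_000:].prod(axis=1).mean())  # fluctuates around W_STAR - eta**2
```

With these settings, plain gradient descent ends on the hyperbola $w_2 w_1 = w_*$ at an unbalanced point, explicit regularization reaches the balanced point with $w_2 w_1 = w_* - \eta^2$, and the averaged S-SAM product should land near the latter. Averaging over the last iterates is used only because the individual S-SAM iterates remain noisy.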

Theorems & Definitions (24)

  • Lemma 2.1
  • Lemma 2.2
  • Lemma 2.3
  • Theorem 3.1
  • Lemma 3.2
  • Theorem 4.1
  • Theorem 4.2
  • Theorem 4.3
  • Theorem 4.4
  • Lemma C.1
  • ...and 14 more