On the Learning Dynamics of Two-layer Linear Networks with Label Noise SGD

Tongcheng Zhang, Zhanpeng Zhou, Mingze Wang, Andi Han, Wei Huang, Taiji Suzuki, Junchi Yan

TL;DR

This work analyzes the mechanisms underlying stochastic gradient descent (SGD) with label noise and extends these insights to Sharpness-Aware Minimization (SAM), showing that the principles governing label noise SGD also apply to broader optimization algorithms.

Abstract

One crucial factor behind the success of deep learning is the implicit bias induced by the noise inherent in gradient-based training algorithms. Motivated by empirical observations that training with noisy labels improves model generalization, we investigate the mechanisms underlying stochastic gradient descent (SGD) with label noise. Focusing on a two-layer over-parameterized linear network, we analyze the learning dynamics of label noise SGD and uncover a two-phase learning behavior. In Phase I, the magnitudes of the model weights progressively diminish, and the model escapes the lazy regime and enters the rich regime. In Phase II, the alignment between the model weights and the ground-truth interpolator increases, and the model eventually converges. Our analysis highlights the critical role of label noise in driving the transition from the lazy to the rich regime and provides a minimal explanation of its empirical success. Furthermore, we extend these insights to Sharpness-Aware Minimization (SAM), showing that the principles governing label noise SGD also apply to broader optimization algorithms. Extensive experiments, conducted in both synthetic and real-world setups, strongly support our theory. Our code is released at https://github.com/a-usually/Label-Noise-SGD.
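To make the two update rules compared above concrete, here is a minimal pure-Python sketch on a single training pair. It is not the released code: it assumes a hypothetical parameterization $f(\boldsymbol{x}) = \sum_{i=1}^m a_i \langle \boldsymbol{w}_i, \boldsymbol{x} \rangle$ with both layers trained, the squared loss, Gaussian label noise of standard deviation $\sigma$ redrawn at every step, and the usual SAM step that evaluates the gradient at $\boldsymbol{\theta} + \rho\, \nabla L / \Vert \nabla L \Vert$. The helper names (`predict`, `grads`, `label_noise_sgd_step`, `sam_step`) are illustrative.

```python
import math
import random

# Hypothetical parameterization (assumption): f(x) = sum_i a_i * <w_i, x>,
# where W holds the first-layer neurons w_i and a the second-layer weights.

def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))

def predict(W, a, x):
    return sum(a_i * dot(w_i, x) for a_i, w_i in zip(a, W))

def grads(W, a, x, y):
    """Gradients of the squared loss 0.5 * (f(x) - y)^2 w.r.t. W and a."""
    r = predict(W, a, x) - y                      # residual
    gW = [[r * a_i * x_j for x_j in x] for a_i in a]
    ga = [r * dot(w_i, x) for w_i in W]
    return gW, ga

def step(W, a, gW, ga, eta):
    """Move (W, a) by -eta times the given gradients."""
    W = [[w - eta * g for w, g in zip(w_i, g_i)] for w_i, g_i in zip(W, gW)]
    a = [a_i - eta * g for a_i, g in zip(a, ga)]
    return W, a

def label_noise_sgd_step(W, a, x, y, eta, sigma, rng):
    """Label noise SGD: perturb the label with fresh Gaussian noise, then descend."""
    y_noisy = y + sigma * rng.gauss(0.0, 1.0)
    gW, ga = grads(W, a, x, y_noisy)
    return step(W, a, gW, ga, eta)

def sam_step(W, a, x, y, eta, rho):
    """SAM: descend using the gradient taken at the normalized ascent point."""
    gW, ga = grads(W, a, x, y)
    norm = math.sqrt(sum(g * g for g_i in gW for g in g_i) + sum(g * g for g in ga))
    if norm == 0.0:
        return W, a  # zero gradient on this sample: no update
    # A negative step size turns `step` into an ascent of length rho.
    W_adv, a_adv = step(W, a, gW, ga, -rho / norm)
    gW_adv, ga_adv = grads(W_adv, a_adv, x, y)
    return step(W, a, gW_adv, ga_adv, eta)
```

For example, starting from `W = [[0.5, 0.5], [0.5, -0.5]]` and `a = [1.0, 1.0]` with `x = [1.0, 0.0]` and `y = 2.0`, one noise-free step (`sigma=0.0`, `eta=0.1`) moves the prediction from 1.0 to 1.26, and `sam_step` with `rho=0` reduces to that same plain gradient step.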

Paper Structure (24 sections, 19 theorems, 123 equations, 6 figures, 1 table, 1 algorithm)

Key Result

Theorem 4.2

Suppose Assumptions A1-A2 and A4-A6 hold, and consider the update rule for $\boldsymbol{\theta}$ (label noise SGD). With probability at least $1-O(\frac{1}{m})$, all neurons $\boldsymbol{w}_{i}$ ($i\in[m]$) escape from the lazy regime at time $T_1 = \frac{384\sqrt{\log m}}{\sigma^2\eta^2\sqrt{m}}$.
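The escape time in Theorem 4.2 is an explicit formula, so its scaling is easy to check numerically. The sketch below simply evaluates it, assuming $\log$ denotes the natural logarithm and that $T_1$ is measured in the time units of the theorem's update rule; `escape_time` is a hypothetical helper name.

```python
import math

def escape_time(m, sigma, eta):
    """Evaluate T_1 = 384 * sqrt(log m) / (sigma^2 * eta^2 * sqrt(m)) from Theorem 4.2.

    Assumes log is the natural logarithm; m is the number of neurons,
    sigma the label noise level, and eta the learning rate.
    """
    if m < 2:
        raise ValueError("m must be at least 2 so that log m > 0")
    return 384.0 * math.sqrt(math.log(m)) / (sigma**2 * eta**2 * math.sqrt(m))
```

For $m = 100$ and $\sigma = \eta = 1$ this gives $T_1 \approx 82.4$. The formula scales as $1/(\sigma^2\eta^2)$, so halving $\eta$ quadruples $T_1$; and because $\log m / m$ decreases for $m \ge 3$, the bound shrinks as the width $m$ grows.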

Figures (6)

  • Figure 1: (Left) Label noise SGD leads to better generalization: test loss $\mathcal{L}(\boldsymbol{\theta}(t))$ and test accuracy ${\operatorname{Acc}}(\boldsymbol{\theta}(t))$ vs. training epochs $t$. (Right) Label noise SGD leads to sparser solutions: test accuracy of the pruned model ${\operatorname{Acc}}(\boldsymbol{\theta}(T; \alpha))$ vs. the percentage $\alpha$ of remaining parameters. Here, $\boldsymbol{\theta}(T; \alpha)$ denotes the pruned model derived from the pretrained model $\boldsymbol{\theta}(T)$, with $\alpha$% of its parameters remaining. We train the models with both vanilla SGD and label noise SGD, without weight decay or momentum. The learning rate is $0.1$, and training runs for $160$ epochs. An exponential moving average is used to smooth the test accuracy curves. Results are for ResNet-18 [kaiming2016residual] trained on CIFAR-10 [krizhevsky2009learning] with label flipping probabilities $\tau \in \{0.05, 0.1, 0.2\}$. As shown in (a) and (b), label noise SGD consistently outperforms vanilla SGD in both test loss and test accuracy for every value of $\tau$, improving test accuracy by about $1.5$%. As shown in (c), models trained with label noise SGD maintain higher accuracy than those trained with vanilla SGD at the same sparsity level.
  • Figure 2: Two-phase dynamics of label noise SGD under the synthetic setup. We replicate the synthetic problem setup from the experimental validation section. (a) Loss curves: training loss $\mathcal{L}_{\mathcal{D}_{\rm train}}(\boldsymbol{\theta}(t))$ and test loss $\mathcal{L}_{\mathcal{D}_{\rm test}}(\boldsymbol{\theta}(t))$ vs. training iteration $t$. (b) Average learning dynamics: the mean neuron norm $\text{Avg}_{i\in [m]}(\Vert \boldsymbol{w}_i (t) \Vert_2)$ and the mean neuron alignment $\text{Avg}_{i\in [m]}(\langle \boldsymbol{w}_i(t), \boldsymbol{\theta}^{\star} \rangle)$ vs. training iteration $t$. (c) Learning dynamics of the $i$-th neuron: its alignment $\langle \boldsymbol{w}_i(t), \boldsymbol{\theta}^{\star} \rangle$ vs. its weight norm $\Vert \boldsymbol{w}_i (t) \Vert_2$, with darker points indicating larger $t$. (Bottom) Complete view of the dynamics of every neuron: similar to (c), but instead of focusing on a single neuron, it plots the status of all neurons over iterations.
  • Figure 3: Label noise SGD induces the rich regime. (a, b) Training loss $\mathcal{L}_{\mathcal{D}_{\rm train}}(\boldsymbol{\theta}(t))$ and test loss $\mathcal{L}_{\mathcal{D}_{\rm test}}(\boldsymbol{\theta}(t))$ vs. training epochs $t$. Label noise SGD induces the progressively diminishing phenomenon. (c) The first-layer weight norm $\Vert \boldsymbol{W}(t) \Vert_F$ vs. training epochs $t$. We train the models with GD under the NTK parameterization [jacot2018NTK], both with and without label noise, and also train a linearized model with GD as a baseline. Results are for WideResNets trained on a random subset of 64 CIFAR-10 images, due to the $O(n^2)$ computational complexity of the NTK.
  • Figure 4: Two-phase dynamics of SAM under the synthetic setup. We replicate the synthetic problem setup from the experimental validation section, replacing label noise SGD with SAM. (a) Loss curves: training loss $\mathcal{L}_{\mathcal{D}_{\rm train}}(\boldsymbol{\theta}(t))$ and test loss $\mathcal{L}_{\mathcal{D}_{\rm test}}(\boldsymbol{\theta}(t))$ vs. training iteration $t$. (b) Average learning dynamics: the mean neuron norm $\text{Avg}_{i\in [m]}(\Vert \boldsymbol{w}_i (t) \Vert_2)$ and the mean neuron alignment $\text{Avg}_{i\in [m]}(\langle \boldsymbol{w}_i(t), \boldsymbol{\theta}^{\star} \rangle)$ vs. training iteration $t$. (c) Learning dynamics of the $i$-th neuron: its alignment $\langle \boldsymbol{w}_i(t), \boldsymbol{\theta}^{\star} \rangle$ vs. its weight norm $\Vert \boldsymbol{w}_i (t) \Vert_2$, with darker points indicating larger iteration $t$.
  • Figure 5: SAM induces the rich regime. Training loss $\mathcal{L}_{\mathcal{D}_{\rm train}}(\boldsymbol{\theta}(t))$ and test loss $\mathcal{L}_{\mathcal{D}_{\rm test}}(\boldsymbol{\theta}(t))$ vs. training epochs $t$. We train the models with both GD and full-batch SAM under the NTK parameterization [jacot2018NTK], and also train a linearized model with GD as a baseline. Results are for WideResNets trained on a random subset of 64 CIFAR-10 images, due to the $O(n^2)$ computational complexity of the NTK.
  • ...and 1 more figure
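Figure 1 specifies label noise on CIFAR-10 through a label flipping probability $\tau$. One plausible scheme, sketched below as an assumption rather than the paper's exact procedure, replaces each label with probability $\tau$ by a different class drawn uniformly at random; `flip_labels` is a hypothetical helper.

```python
import random

def flip_labels(labels, tau, num_classes, rng):
    """Return a copy of `labels` in which each entry is, with probability tau,
    replaced by a different class chosen uniformly at random (assumed scheme)."""
    noisy = []
    for y in labels:
        if rng.random() < tau:
            # Draw uniformly from the num_classes - 1 classes other than y.
            y_new = rng.randrange(num_classes - 1)
            noisy.append(y_new if y_new < y else y_new + 1)
        else:
            noisy.append(y)
    return noisy
```

In a label noise SGD loop, the assumed usage is to call `flip_labels` on each minibatch's labels with fresh randomness at every step, so the corruption is resampled throughout training rather than fixed once for the whole dataset.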

Theorems & Definitions (21)

  • Definition 4.1: The lazy regime
  • Theorem 4.2: Escaping the lazy regime
  • Lemma 4.3: Progressively diminishing at each step
  • Lemma 4.4: Progressively diminishing under simulation setup
  • Lemma 4.5: Alignment
  • Lemma 4.6: Convergence
  • Definition A.1
  • Lemma A.2
  • Corollary A.3
  • Lemma A.4
  • ...and 11 more