Stochastic Resetting Mitigates Latent Gradient Bias of SGD from Label Noise

Youngkyoung Bae, Yeongwoo Song, Hawoong Jeong

TL;DR

The paper tackles the problem of DNN memorization of noisy labels during SGD by introducing stochastic resetting to a checkpoint. By mapping SGD to overdamped Langevin dynamics, it identifies a latent gradient bias toward corrupted labels and shows that resetting can counteract this bias, with theoretical conditions linking stochasticity, drift, and the noise rate. Empirical results across synthetic and real-world noisy datasets show consistent generalization gains, which are especially pronounced at higher noise levels and when resetting targets the later layers of the network. The method is simple to implement, compatible with existing noisy-label defenses, and offers a physics-inspired lens for understanding training dynamics under label noise.
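
The theoretical conditions come from the stochastic-resetting literature, where the benefit of resetting is read off the mean first passage time (MFPT) of a drift-diffusion process. The snippet below is a minimal sketch under stated assumptions, not the paper's exact model: a one-dimensional overdamped particle $dx = v\,dt + \sqrt{2D}\,dW$ starts at distance $L$ from an absorbing target at $x=0$ and is reset to its start at Poisson rate $\gamma$. We assume $v>0$ points away from the target (a stand-in for the latent gradient bias); this sign convention is our choice. The standard renewal formula $\langle T(\gamma)\rangle = (1-\tilde f(\gamma))/(\gamma \tilde f(\gamma))$, with $\tilde f$ the Laplace transform of the reset-free first-passage density, then gives $\langle T(\gamma)\rangle = (e^{kL}-1)/\gamma$ with $k = (v+\sqrt{v^2+4D\gamma})/(2D)$. The function names are hypothetical.

```python
import math


def mfpt_with_reset(gamma, D=1.0, v=1.0, L=1.0):
    """MFPT to x = 0 for dx = v dt + sqrt(2D) dW, started at x = L and
    reset to x = L at Poisson rate gamma.

    Assumption: v > 0 is a drift pointing AWAY from the target.
    """
    if gamma <= 0:
        raise ValueError("gamma must be positive")
    # Reset-free first-passage density in Laplace space: f(L, s) = exp(-k L).
    k = (v + math.sqrt(v * v + 4.0 * D * gamma)) / (2.0 * D)
    # Renewal formula <T> = (1 - f) / (gamma * f) = (exp(k L) - 1) / gamma.
    return math.expm1(k * L) / gamma


def optimal_reset_rate(D=1.0, v=1.0, L=1.0, gammas=None):
    """Grid search for the reset rate gamma* that minimises the MFPT."""
    if gammas is None:
        gammas = [0.01 * i for i in range(1, 2001)]  # 0.01 ... 20.0
    best = min(gammas, key=lambda g: mfpt_with_reset(g, D, v, L))
    return best, mfpt_with_reset(best, D, v, L)


if __name__ == "__main__":
    # D = 1 and L = 1 mirror the settings quoted in the Figure 3 caption.
    for v in (0.5, 1.0, 2.0):
        g_star, t_star = optimal_reset_rate(D=1.0, v=v, L=1.0)
        print(f"v={v}: gamma*={g_star:.2f}, <T(gamma*)>={t_star:.3f}")
```

For $v \ge 0$ this MFPT diverges both as $\gamma \to 0$ (the particle may drift away forever) and as $\gamma \to \infty$ (it never gets far from the start), so a finite optimum $\gamma^*$ exists, and a stronger drift away from the target raises the minimum MFPT.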

Abstract

Giving up and starting over may seem wasteful in many situations, such as searching for a target or training deep neural networks (DNNs). Our study, however, demonstrates that resetting to a checkpoint can significantly improve generalization performance when training DNNs with noisy labels. In the presence of noisy labels, DNNs initially learn the general patterns of the data but then gradually memorize the corrupted data, leading to overfitting. By deconstructing the dynamics of stochastic gradient descent (SGD), we identify a latent gradient bias induced by noisy labels, which harms generalization. To mitigate this negative effect, we apply stochastic resetting to SGD, inspired by recent developments in statistical physics showing that resetting can make target searches more efficient. We first theoretically identify the conditions under which resetting becomes beneficial, and then empirically validate our theory, confirming the significant improvements achieved by resetting. We further demonstrate that our method is both easy to implement and compatible with other methods for handling noisy labels. Additionally, this work offers insights into the learning dynamics of DNNs from an interpretability perspective, expanding the potential to analyze training methods through the lens of statistical physics.
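
The core procedure, SGD with stochastic resetting to a checkpoint, fits in a few lines. The snippet below is a minimal, framework-free sketch, not the authors' Algorithm 1: we assume that after every SGD update the selected parameter groups are overwritten by a fixed checkpoint with probability $r$, and that parameters are stored as a dict of float lists. The names `sgd_with_reset`, `grad_fn`, and `reset_layers` are hypothetical; `reset_layers` mimics the observation that resetting the later layers works well. How the checkpoint is chosen (for example, near the iteration where validation loss starts to rise) is left to the caller.

```python
import copy
import random


def sgd_with_reset(params, grad_fn, checkpoint, lr=0.1, r=0.01, steps=1000,
                   reset_layers=None, seed=0):
    """SGD with stochastic resetting to a checkpoint (illustrative sketch).

    params / checkpoint: dict mapping a parameter-group name to a list of floats.
    grad_fn(params, t): returns a dict of (mini-batch) gradients, same layout.
    reset_layers: group names that are reset; None means every group.
    Returns the final parameters and the number of resets performed.
    """
    rng = random.Random(seed)
    params = copy.deepcopy(params)  # never mutate the caller's parameters
    names = list(params) if reset_layers is None else list(reset_layers)
    n_resets = 0
    for t in range(steps):
        # 1) Ordinary SGD update on every parameter group.
        grads = grad_fn(params, t)
        for name, g in grads.items():
            params[name] = [p - lr * gi for p, gi in zip(params[name], g)]
        # 2) With probability r, overwrite the selected groups with the checkpoint.
        if rng.random() < r:
            for name in names:
                params[name] = list(checkpoint[name])
            n_resets += 1
    return params, n_resets


if __name__ == "__main__":
    noise = random.Random(1)

    def toy_grad(p, t):
        # Hypothetical toy objective 0.5 * ||w - 1||^2 with Gaussian gradient noise.
        return {"head": [w - 1.0 + noise.gauss(0.0, 0.1) for w in p["head"]]}

    start = {"head": [0.0, 0.0]}
    ckpt = {"head": [0.9, 0.9]}  # hypothetical parameters saved earlier in training
    final, n_resets = sgd_with_reset(start, toy_grad, ckpt, lr=0.1, r=0.05, steps=200)
    print(final, n_resets)
```

With `r=0` this reduces to plain SGD, and with `r=1` the reset groups end every iteration at the checkpoint; in a real network the dict would be replaced by the framework's own parameter containers.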

Paper Structure (30 sections, 23 equations, 14 figures, 9 tables, 1 algorithm)

Figures (14)

  • Figure 1: (a) Schematic of stochastic gradient descent (SGD) dynamics with stochastic resetting. The network parameter vector $\bm{\theta}$ evolves via SGD to find an optimal value $\bm{\theta}^*$ on the training risk landscape $\mathcal{R}_{\tilde{\mathcal{D}}_{\rm tr}}$ (upper colormap), which differs from the true risk landscape $\mathcal{R}_{\mathcal{D}}$ (lower colormap) due to corrupted data. Here, $\bm{\theta}$ is reset to the checkpoint $\bm{\theta}_c$ (home icon) with reset probability $r$. (b) Fraction of correctly predicted data with wrong labels during training with SGD (gray) and SGD with reset (green). The inset shows the validation losses during training.
  • Figure 2: (a) Schematic of $-\bm{\nabla}_{\bm{\theta}} \mathcal{R}_{\tilde{\mathcal{D}}_{\rm tr}}(\bm{\theta})$, decomposed into two orthogonal terms $-\bm{\nabla}_{\bm{\theta}}\hat{\mathcal{R}}_{\tilde{\mathcal{D}}_{\rm tr}^c}(\bm{\theta})$ and $-\bm{\nabla}_{\bm{\theta}}\hat{\mathcal{R}}_{\tilde{\mathcal{D}}_{\rm tr}^w}(\bm{\theta})$. (b) Cosine similarity between $-\bm{\nabla}_{\bm{\theta}} \mathcal{R}_{\tilde{\mathcal{D}}_{\rm tr}}(\bm{\theta})$ and $-\bm{\nabla}_{\bm{\theta}}\hat{\mathcal{R}}_{\tilde{\mathcal{D}}_{\rm tr}^w}(\bm{\theta})$ ($\cos{\phi_{tw}}$; red), and between $-\bm{\nabla}_{\bm{\theta}}\hat{\mathcal{R}}_{\tilde{\mathcal{D}}_{\rm tr}^c}(\bm{\theta})$ and $-\bm{\nabla}_{\bm{\theta}}\hat{\mathcal{R}}_{\tilde{\mathcal{D}}_{\rm tr}^w}(\bm{\theta})$ (gray), over all training iterations for varying noise rate $\tau$. (c) Magnitude difference $\| \bm{\nabla}_{\bm{\theta}}\hat{\mathcal{R}}_{\tilde{\mathcal{D}}_{\rm tr}^w}(\bm{\theta}) \| -\| \bm{\nabla}_{\bm{\theta}}\hat{\mathcal{R}}_{\tilde{\mathcal{D}}_{\rm tr}^c}(\bm{\theta}) \|$ over all training iterations for varying $\tau$. Here, the batch size is $B=8$ in Setting 1 described in Sec. 4. Darker colors represent larger values of $\tau$ in (b, c).
  • Figure 3: The mean first passage time (MFPT) $\langle T(\gamma) \rangle$, computed from the analytic MFPT expression with drift, as a function of the reset rate $\gamma$ for varying (a) diffusion coefficient $D$ and (b) drift $v$. Markers represent the minimum MFPT, $\langle T(\gamma^*)\rangle$, at the optimal reset rate $\gamma^*$. We set $v=1$ in (a), $D=1$ in (b), and $L=1$ in both.
  • Figure 4: (a) Test accuracies of SGD (gray) and SGD with our resetting method (green) during training. The inset shows the validation losses. (b, c) Relative difference of validation loss (RDVLoss) as a function of the reset probability $r$ for different checkpoints. In (b), RDVLoss is shown for checkpoints selected before (left) and after (right) the overfitting iteration $t_m$; $t_m + \delta t$ denotes the iteration at which the checkpoint is selected. In (c), RDVLoss is plotted for perturbed checkpoint parameters $\bm{\theta}_{c, \epsilon} \equiv \bm{\theta}_c + \epsilon \hat{\bm{n}}$, where $\bm{\theta}_c$ denotes the checkpoint and $\hat{\bm{n}}$ denotes a random unit vector. The shaded areas denote the standard error.
  • Figure 5: Relative difference of validation loss (RDVLoss, left) and relative difference of test accuracy (RDTAcc., right) as functions of the reset probability $r$ for (a) varying batch size $B$ and (b) varying noise rate $\tau$. We set $\tau = 0.4$ in (a) and $B=16$ in (b). The shaded areas denote the standard error.
  • ...and 9 more figures
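
Figure 2 rests on splitting the training-risk gradient into contributions from correctly labeled and mislabeled samples. The snippet below is a minimal sketch of that diagnostic under two assumptions of ours: a toy logistic-regression model stands in for the DNN, and both partial risks are normalized by the total sample count $N$, so the two parts sum exactly to the full gradient (the paper's hatted risks may be defined differently). It computes the quantities plotted in panels (b) and (c). Since $\cos(-a,-b)=\cos(a,b)$, the minus signs in the caption do not change the cosines. All function names and the toy data generator are hypothetical.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def per_sample_grad(w, x, y):
    """Gradient of binary cross-entropy for logistic regression, w.r.t. w."""
    err = sigmoid(sum(wi * xi for wi, xi in zip(w, x))) - y
    return [err * xi for xi in x]


def decompose_gradient(w, X, y_noisy, is_wrong):
    """Return (g_total, g_clean, g_wrong), each part normalised by the total N."""
    n, d = len(X), len(w)
    g_c, g_w = [0.0] * d, [0.0] * d
    for x, y, wrong in zip(X, y_noisy, is_wrong):
        acc = g_w if wrong else g_c
        for j, gj in enumerate(per_sample_grad(w, x, y)):
            acc[j] += gj / n
    g_t = [a + b for a, b in zip(g_c, g_w)]
    return g_t, g_c, g_w


def norm(u):
    return math.sqrt(sum(ui * ui for ui in u))


def cosine(u, v):
    return sum(a * b for a, b in zip(u, v)) / (norm(u) * norm(v))


def make_toy_data(n=200, tau=0.4, seed=0):
    """Two 2-D Gaussian blobs plus a constant bias feature; each label is flipped with prob. tau."""
    rng = random.Random(seed)
    X, y_noisy, is_wrong = [], [], []
    for i in range(n):
        y = i % 2
        c = 1.0 if y == 1 else -1.0
        X.append([rng.gauss(c, 1.0), rng.gauss(c, 1.0), 1.0])
        wrong = rng.random() < tau
        y_noisy.append(1 - y if wrong else y)
        is_wrong.append(wrong)
    return X, y_noisy, is_wrong


if __name__ == "__main__":
    X, y_noisy, is_wrong = make_toy_data(tau=0.4)
    w = [0.5, 0.5, 0.0]  # hypothetical, partially trained weights
    g_t, g_c, g_w = decompose_gradient(w, X, y_noisy, is_wrong)
    print("cos_tw =", round(cosine(g_t, g_w), 3))
    print("cos_cw =", round(cosine(g_c, g_w), 3))
    print("||g_w|| - ||g_c|| =", round(norm(g_w) - norm(g_c), 3))
```

The code measures, rather than assumes, how close the clean and mislabeled parts are to orthogonal; in the paper these quantities are tracked on real networks across training iterations and noise rates $\tau$.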