Improving Generalization and Convergence by Enhancing Implicit Regularization

Mingze Wang, Jinbo Wang, Haotian He, Zilin Wang, Guanhua Huang, Feiyu Xiong, Zhiyu Li, Weinan E, Lei Wu

TL;DR

This work proposes Implicit Regularization Enhancement (IRE), which decouples the dynamics of flat and sharp directions: it boosts sharpness reduction along flat directions while maintaining training stability in sharp directions. IRE can be combined with generic base optimizers without significant computational overhead.

Abstract

In this work, we propose an Implicit Regularization Enhancement (IRE) framework to accelerate the discovery of flat solutions in deep learning, thereby improving generalization and convergence. Specifically, IRE decouples the dynamics of flat and sharp directions, which boosts the sharpness reduction along flat directions while maintaining the training stability in sharp directions. We show that IRE can be practically incorporated with generic base optimizers without introducing significant computational overhead. Experiments show that IRE consistently improves the generalization performance for image classification tasks across a variety of benchmark datasets (CIFAR-10/100, ImageNet) and models (ResNets and ViTs). Surprisingly, IRE also achieves a $2\times$ speed-up compared to AdamW in the pre-training of Llama models (of sizes ranging from 60M to 229M) on datasets including Wikitext-103, Minipile, and Openwebtext. Moreover, we provide theoretical guarantees, showing that IRE can substantially accelerate the convergence towards flat minima in Sharpness-aware Minimization (SAM).
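The idea of decoupling flat and sharp directions on top of a base optimizer can be sketched in a few lines. This is a minimal illustration, not the authors' algorithm: the abstract does not say how flat directions are identified, so the sketch assumes a per-coordinate curvature estimate is available and treats the fraction `gamma` of coordinates with the smallest curvature as flat. The function name `ire_step`, the plain SGD base step, and the `curvature` input are all hypothetical.

```python
def ire_step(params, grads, curvature, lr=0.1, kappa=1.0, gamma=0.5):
    """One SGD step with an IRE-style boost on the flattest coordinates.

    Assumption: `curvature` holds a per-coordinate curvature estimate
    (for example a diagonal Hessian approximation); how it is obtained
    is outside the scope of this sketch.
    """
    n = len(params)
    k = int(gamma * n)  # number of coordinates treated as flat
    # Indices of the k coordinates with the smallest curvature estimate.
    flat = set(sorted(range(n), key=lambda i: curvature[i])[:k])
    new_params = []
    for i, (p, g) in enumerate(zip(params, grads)):
        # Flat coordinates get the step scaled by (1 + kappa);
        # sharp coordinates keep the base step, preserving stability.
        scale = 1.0 + kappa if i in flat else 1.0
        new_params.append(p - lr * scale * g)
    return new_params


# Usage on L(theta) = 0.5 * (10 * theta_0^2 + 0.01 * theta_1^2).
h = [10.0, 0.01]                      # exact diagonal curvature
theta = [1.0, 1.0]
g = [h[0] * theta[0], h[1] * theta[1]]
base = ire_step(theta, g, h, lr=0.1, kappa=0.0)     # plain SGD
boosted = ire_step(theta, g, h, lr=0.1, kappa=9.0)  # flat coordinate boosted
```

With `kappa=0.0` the step reduces to the base update, and with `kappa>0` only the low-curvature coordinate moves further, while the high-curvature coordinate is updated exactly as before.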

Paper Structure

This paper contains 41 sections, 20 theorems, 80 equations, 6 figures, 8 tables, 1 algorithm.

Key Result

Theorem 5.5

Suppose the minima manifold assumption holds. If $\eta=\mathcal{O}(1)$ and $\rho=\mathcal{O}(\sqrt{\eta})$ in average SAM, $\kappa\leqslant 1/\rho$, and $\mathcal{P}_t=P_{m+1:p}(\nabla^2\mathcal{L}(\bm{\theta}_t))$ in the generic IRE update, then with high probability at least $1-T_{\rm II
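
As a reading aid for the notation in this statement, the projector and the resulting IRE step can be written out explicitly. This assumes the common convention, not spelled out on this page, that the Hessian eigenvalues are sorted in decreasing order. The symbol $\bm{g}_t$ (the update direction produced by the base optimizer, here SAM) and the additive form of the update are assumptions made for illustration.

$$
\nabla^2\mathcal{L}(\bm{\theta}_t)=\sum_{i=1}^{p}\lambda_i\bm{v}_i\bm{v}_i^\top,\quad \lambda_1\geqslant\cdots\geqslant\lambda_p,
\qquad
P_{m+1:p}(\nabla^2\mathcal{L}(\bm{\theta}_t))=\sum_{i=m+1}^{p}\bm{v}_i\bm{v}_i^\top,
$$

$$
\bm{\theta}_{t+1}=\bm{\theta}_t-\eta\,(\bm{g}_t+\kappa\,\mathcal{P}_t\bm{g}_t).
$$

Under this reading, $\mathcal{P}_t$ keeps only the $p-m$ flattest directions, and $\kappa=0$ recovers the base update.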

Figures (6)

  • Figure 1: A $2$-d toy example: $\mathcal{L}(u,v)=(1+u^2)v^2/2$. The gray arrows denote the minima manifold $\mathcal{M}=\{(u,v): v=0\}$, where the smaller the $|u|$, the flatter the minimizer. The red marker highlights the flattest minimizer $(0,0)$. (a) The dynamics of GD ($\eta=1$), which moves slowly towards flatter minima as it converges. (b) The dynamics of GD ($\eta=2$), which diverges due to the excessively large $\eta$. (c) The behavior of our IRE approach with varying $\kappa$ vs. GD ($\eta=1$). IRE significantly accelerates the dynamics of $u_t$, almost reaching the flattest minimum $(0,0)$ when $\kappa$ is very large.
  • Figure 2: Training WRN-16-8 on CIFAR-10 by SAM-IRE with varying $\gamma,\kappa$. In particular, the case $\kappa=0$ corresponds to the standard SAM.
  • Figure 3: Transformer on Wikitext-2.
  • Figure 4: AdmIRE outperforms AdamW in the pre-training of Llama models.
  • Figure 5: The results for tuning lr_max in AdamW.
  • ...and 1 more figure
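
The toy dynamics described in the Figure 1 caption can be reproduced in a few lines. The sketch below is illustrative and rests on assumptions not stated in the caption: the starting point $(u_0,v_0)=(0.9,0.5)$, the step counts, and the form of the IRE update, taken here as GD with the gradient component along the flat direction $u$ scaled by $(1+\kappa)$. The function names are hypothetical.

```python
def run_toy(eta, kappa, steps, u=0.9, v=0.5):
    """Iterate on L(u, v) = (1 + u^2) * v^2 / 2.

    kappa = 0 gives plain gradient descent. kappa > 0 scales the
    update along u (the direction tangent to the minima manifold
    v = 0) by (1 + kappa), leaving the sharp direction v untouched.
    """
    for _ in range(steps):
        grad_u = u * v * v            # dL/du
        grad_v = (1.0 + u * u) * v    # dL/dv
        u, v = u - eta * (1.0 + kappa) * grad_u, v - eta * grad_v
    return u, v


def sharpness(u):
    """Largest Hessian eigenvalue at the minimizer (u, 0)."""
    return 1.0 + u * u


# (a) GD with eta = 1 converges to the manifold but barely moves along u.
u_gd, v_gd = run_toy(eta=1.0, kappa=0.0, steps=50)

# (c) Same eta with the flat direction boosted: ends at a smaller |u|.
u_ire, v_ire = run_toy(eta=1.0, kappa=2.0, steps=50)

# (b) GD with eta = 2 is unstable; only a few steps are run because
# the iterates grow very quickly.
u_div, v_div = run_toy(eta=2.0, kappa=0.0, steps=5)

print("GD  sharpness:", sharpness(u_gd))
print("IRE sharpness:", sharpness(u_ire))
print("GD eta=2, |v| after 5 steps:", abs(v_div))
```

In this sketch the boosted run lands on a minimizer with smaller $|u|$, hence lower sharpness $1+u^2$, than plain GD at the same learning rate, whereas simply doubling $\eta$ makes the sharp direction $v$ blow up. For this starting point, $\kappa$ cannot be made arbitrarily large: once $(1+\kappa)v_0^2$ exceeds 2, the first step overshoots along $u$.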

Theorems & Definitions (30)

  • Example 2.1
  • Remark 2.2: The generality
  • Definition 5.2: Limiting map of GF
  • Definition 5.3: Attraction set of $\mathcal{M}$
  • Remark 5.4: The mechanism of IRE's success
  • Theorem 5.5: IRE on average SAM
  • Theorem 5.6: IRE on standard SAM
  • Lemma D.1: [arora2022understanding, Lemma B.2]
  • Lemma D.2: Key properties of $\Phi(\cdot)$ [arora2022understanding]
  • Lemma D.3: Continuity of $P_{m+1:p}$
  • ...and 20 more