Convergence of Sharpness-Aware Minimization Algorithms using Increasing Batch Size and Decaying Learning Rate

Hinata Harada, Hideaki Iiduka

TL;DR

We numerically compare SAM (GSAM) with and without an increasing batch size and conclude that using an increasing batch size or a decaying learning rate finds flatter local minima than using a constant batch size and learning rate.

Abstract

The sharpness-aware minimization (SAM) algorithm and its variants, including surrogate gap guided SAM (GSAM), have been successful at improving the generalization capability of deep neural network models by finding flat local minima of the empirical loss in training. Meanwhile, it has been shown theoretically and empirically that increasing the batch size or decaying the learning rate avoids sharp local minima of the empirical loss. In this paper, we consider the GSAM algorithm with increasing batch sizes or decaying learning rates, such as cosine annealing or a linearly decaying learning rate, and theoretically show its convergence. Moreover, we numerically compare SAM (GSAM) with and without an increasing batch size and conclude that using an increasing batch size or a decaying learning rate finds flatter local minima than using a constant batch size and learning rate.
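Cosine annealing and linear decay are the two decaying learning-rate schedules named above. The sketch below is a minimal illustration, assuming the common textbook forms of both schedules indexed by a counter `t` running from 0 to `total`; the function names `cosine_lr` and `linear_lr` are hypothetical, and the paper's exact parameterization (per step or per epoch) may differ. The values 0.1 and 0.001 are the maximum and minimum learning rates reported for the cosine runs in Figure 2; `total = 200` is only an example.

```python
import math


def cosine_lr(t: int, total: int, lr_max: float, lr_min: float) -> float:
    """Cosine annealing: lr_max at t = 0, decaying smoothly to lr_min at t = total."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * t / total))


def linear_lr(t: int, total: int, lr_max: float, lr_min: float) -> float:
    """Linear decay: lr_max at t = 0, dropping by a constant amount per step to lr_min at t = total."""
    return lr_max - (lr_max - lr_min) * t / total


if __name__ == "__main__":
    total = 200  # example horizon, not taken from the paper
    for t in (0, 50, 100, 150, 200):
        print(t, round(cosine_lr(t, total, 0.1, 0.001), 5), round(linear_lr(t, total, 0.1, 0.001), 5))
```

Both schedules share the same endpoints; cosine annealing stays close to `lr_max` early and close to `lr_min` late, whereas linear decay lowers the rate uniformly.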

Paper Structure (20 sections, 13 theorems, 124 equations, 7 figures, 4 tables, 1 algorithm)

Key Result

Theorem 2.1

Suppose that Assumption 1 holds and define $\bm{\omega}_t \in \mathbb{R}^d$ for all $t \in \mathbb{N} \cup \{0\}$ by $\bm{\omega}_t := \hat{\bm{\omega}}_t + \alpha \nabla f_{S_t \perp}(\bm{x}_t)$, where $\bm{x}_t$ is generated by Algorithm 1, and assume that $G_{\perp} := \sup_{t \in \ldots}$ … where $\mathbb{E}[\cdot]$ stands for the total expectation defined by $\mathbb{E} = \mathbb{E}_{\ldots}$.
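For intuition about the quantities in Theorem 2.1, here is a minimal pure-Python sketch of one GSAM update on a toy quadratic loss. It follows the GSAM update rule as originally proposed: take the SAM gradient at the perturbed point, split the plain gradient into components parallel and orthogonal to it, and subtract $\alpha$ times the orthogonal part. This is not the paper's Algorithm 1. It uses full-batch gradients instead of a mini-batch $S_t$, and the names `gsam_step`, `rho`, `alpha`, and the toy loss are illustrative. If $\hat{\bm{\omega}}_t$ denotes this GSAM direction (our assumption), then adding back $\alpha \nabla f_{S_t \perp}(\bm{x}_t)$, as in the definition of $\bm{\omega}_t$, recovers the SAM gradient, which the sketch also returns.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def norm(u: Vector) -> float:
    return math.sqrt(dot(u, u))


def gsam_step(
    x: Vector,
    grad: Callable[[Vector], Vector],
    lr: float,
    rho: float,
    alpha: float,
    eps: float = 1e-12,
) -> Tuple[Vector, Vector, Vector, Vector]:
    """One full-batch GSAM update (illustrative, not the paper's Algorithm 1).

    Returns (x_next, d_hat, g_perp, g_sam): the new iterate, the GSAM direction,
    the part of grad(x) orthogonal to g_sam, and the SAM gradient g_sam.
    """
    g = grad(x)
    # SAM ascent: move a distance rho along the normalized gradient.
    step = rho / (norm(g) + eps)
    g_sam = grad([xi + step * gi for xi, gi in zip(x, g)])
    # Decompose g into components parallel and orthogonal to g_sam.
    coef = dot(g, g_sam) / (dot(g_sam, g_sam) + eps)
    g_perp = [gi - coef * si for gi, si in zip(g, g_sam)]
    # GSAM direction: SAM gradient minus alpha times the orthogonal component.
    d_hat = [si - alpha * pi for si, pi in zip(g_sam, g_perp)]
    x_next = [xi - lr * di for xi, di in zip(x, d_hat)]
    return x_next, d_hat, g_perp, g_sam


def f(x: Vector) -> float:
    """Toy ill-conditioned quadratic f(x) = 0.5 * (x0^2 + 10 * x1^2)."""
    return 0.5 * (x[0] ** 2 + 10.0 * x[1] ** 2)


def grad_f(x: Vector) -> Vector:
    return [x[0], 10.0 * x[1]]


if __name__ == "__main__":
    alpha = 0.1
    x = [1.0, 1.0]
    for _ in range(100):
        x, d_hat, g_perp, g_sam = gsam_step(x, grad_f, lr=0.05, rho=0.05, alpha=alpha)
    # omega = d_hat + alpha * g_perp, mirroring the definition of omega_t above.
    omega = [d + alpha * p for d, p in zip(d_hat, g_perp)]
    print(f(x), omega, g_sam)
```

The `eps` terms only guard against division by zero when a gradient vanishes.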

Figures (7)

  • Figure 1: (Left) Loss function value in training and (Right) accuracy score in testing for the algorithms versus the number of epochs in training Wide-ResNet-28-10 on the CIFAR100 dataset. The learning rate of each algorithm was fixed at 0.1. In SGD/SAM/GSAM, the batch size was fixed at 128. In SGD/SAM/GSAM + increasing_batch, the batch size was set to 8 for the first 40 epochs and doubled every 40 epochs thereafter, i.e., 16 for epochs 41-80, 32 for epochs 81-120, etc.
  • Figure 2: (Left) Loss function value in training and (Right) accuracy score in testing for the algorithms versus the number of epochs in training Wide-ResNet-28-10 on the CIFAR100 dataset. The batch size of each algorithm was fixed at 128. In SGD/SAM/GSAM, the learning rate was held constant at 0.1. In SGD/SAM/GSAM + Cosine, the learning rate followed cosine annealing with a maximum of 0.1 and a minimum of 0.001.
  • Figure 3: (Left) Loss function value in training and (Right) accuracy score in testing for the optimizers versus the number of epochs in training ViT-Tiny on the CIFAR100 dataset. The learning rate of each optimizer was fixed at 0.001 after a linear warmup from an initial learning rate of 0.00001 during the first 10 epochs. In Adam/SAM/GSAM, the batch size was fixed at 128. In Adam/SAM/GSAM + increasing batch, the batch size was set to 64 for the first 25 epochs and doubled every 25 epochs thereafter, i.e., 128 for epochs 26-50, 256 for epochs 51-75, etc.
  • Figure 4: (Left) Loss function value in training and (Right) accuracy score in testing for the optimizers versus the number of epochs in training ViT-Tiny on the CIFAR100 dataset. The batch size of each optimizer was fixed at 128. In Adam/SAM/GSAM, the learning rate was held constant at 0.001 after a linear warmup from an initial learning rate of 0.00001 during the first 10 epochs. In Adam/SAM/GSAM + Cosine, the learning rate followed cosine annealing with a maximum of 0.001 and a minimum of 0.00001, with a linear warmup during the first 10 epochs.
  • Figure 5: (Left) Loss function value in training and (Right) accuracy score in testing for the optimizers versus the number of epochs in training ResNet18 on the CIFAR100 dataset. The learning rate of each optimizer was fixed at 0.1. In SGD/SAM/GSAM, the batch size was fixed at 128. In SGD/SAM/GSAM + increasing_batch, the batch size was set to 16 for the first 40 epochs and doubled every 40 epochs thereafter, i.e., 32 for epochs 41-80, 64 for epochs 81-120, 128 for epochs 121-160, and 256 for epochs 161-200.
  • ...and 2 more figures
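The increasing-batch schedules in Figures 1, 3, and 5 all start from an initial batch size and double it at a fixed epoch interval while the learning rate stays constant. Below is a minimal sketch of that rule, assuming 1-indexed epochs and that doubling continues until training ends (the captions of Figures 1 and 3 leave the later stages as "etc."); `doubling_batch_size` is a hypothetical helper, not a function from the paper.

```python
def doubling_batch_size(epoch: int, initial: int, interval: int) -> int:
    """Batch size for a 1-indexed epoch: `initial`, doubled once per completed block of `interval` epochs."""
    if epoch < 1 or initial < 1 or interval < 1:
        raise ValueError("epoch, initial, and interval must be positive")
    return initial * 2 ** ((epoch - 1) // interval)


# Figure 5 (ResNet18): 16 for epochs 1-40, doubling every 40 epochs up to epoch 200.
print([doubling_batch_size(e, 16, 40) for e in (1, 40, 41, 81, 121, 161, 200)])
# → [16, 16, 32, 64, 128, 256, 256]
```

Figure 1 corresponds to `doubling_batch_size(epoch, 8, 40)` and Figure 3 to `doubling_batch_size(epoch, 64, 25)`. Under i.i.d. sampling, the variance of a mini-batch gradient estimate scales as 1/b, so each doubling roughly halves the gradient-noise variance at a fixed learning rate.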

Theorems & Definitions (13)

  • Theorem 2.1: Upper bound on $\mathbb{E}\eta_t \|\bm{\omega}_t\|_2$
  • Theorem 2.2: Lower bound on $\mathbb{E}\eta_t \|\bm{\omega}_t\|_2$
  • Theorem 2.3: $\epsilon$-approximation of GSAM with an increasing batch size and a constant learning rate
  • Theorem 2.4: $\epsilon$-approximation of GSAM with a constant batch size and a decaying learning rate
  • Proposition A.1
  • Proposition A.2
  • Theorem B.1: $\epsilon$-approximation of GSAM with an increasing batch size and a decaying learning rate
  • Lemma B.1: Descent lemma
  • Proposition B.1
  • Proposition B.2
  • ...and 3 more