Adaptive Adversarial Cross-Entropy Loss for Sharpness-Aware Minimization

Tanapat Ratchatorn, Masayuki Tanaka

TL;DR

The paper tackles generalization in highly overparameterized models by refining the perturbation step in Sharpness-Aware Minimization (SAM). It introduces Adaptive Adversarial Cross-Entropy (AACE), whose loss $L_{\mathrm{AACE}}$ and its gradient grow as training converges, and pairs it with a non-normalized perturbation that preserves exploration and stabilizes perturbation directions near convergence. Empirical results on Wide ResNet and PyramidNet across CIFAR-family datasets show improved validation performance and generalization compared with standard SAM and SGD. The approach yields more robust perturbations in late training stages, and reproduction code is provided for verification.
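
The key property claimed for AACE is a perturbation loss that grows while cross-entropy shrinks. It can be illustrated numerically with the minimal sketch below, assuming NumPy and a 3-class toy logit vector. It does not reproduce the authors' AACE definition, which is given in the paper's equations. Instead, `adversarial_ce` is a hypothetical stand-in whose target excludes the true class, chosen only because it shares the same trend.

```python
import numpy as np

def log_softmax(z):
    # Numerically stable log-softmax of a 1-D logit vector.
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

def cross_entropy(z, y):
    """Standard cross-entropy and its gradient with respect to the logits."""
    logp = log_softmax(z)
    p = np.exp(logp)
    return -logp[y], p - np.eye(len(z))[y]

def adversarial_ce(z, y):
    """Hypothetical adversarial cross-entropy (illustrative assumption, not AACE).

    The target q gives zero mass to the true class y and spreads the rest over
    the other classes in proportion to their current probabilities. q is held
    constant, so the gradient with respect to the logits is p - q.
    """
    logp = log_softmax(z)
    p = np.exp(logp)
    q = p.copy()
    q[y] = 0.0
    q /= q.sum()
    return -np.sum(q * logp), p - q

# Sweep the true-class margin to mimic a model approaching convergence.
for margin in [0.0, 2.0, 4.0, 6.0, 8.0]:
    z = np.array([margin, 0.0, 0.0])  # class 0 is the true class
    ce, g_ce = cross_entropy(z, 0)
    adv, g_adv = adversarial_ce(z, 0)
    print(f"p_y={np.exp(log_softmax(z))[0]:.3f}  "
          f"CE={ce:.3f} |grad CE|={np.linalg.norm(g_ce):.3f}  "
          f"ADV={adv:.3f} |grad ADV|={np.linalg.norm(g_adv):.3f}")
```

As $p_y$ rises, the cross-entropy and the norm of its logit gradient $p - e_y$ both shrink toward zero. The stand-in loss and its gradient norm grow instead, which is the kind of contrast Figure 1 draws for the real AACE.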

Abstract

Recent advances in learning algorithms have shown that the sharpness of the loss surface is an effective measure for narrowing the generalization gap. Building on this idea, Sharpness-Aware Minimization (SAM) was proposed to enhance model generalization and achieved state-of-the-art performance. SAM consists of two main steps: a weight perturbation step and a weight updating step. However, the perturbation in SAM is determined only by the gradient of the training loss, i.e., the cross-entropy loss. As the model approaches a stationary point, this gradient becomes small and oscillates. This leads to inconsistent perturbation directions and risks a vanishing gradient. Our research introduces a new approach to further enhance model generalization. We propose the Adaptive Adversarial Cross-Entropy (AACE) loss function to replace the standard cross-entropy loss in SAM's perturbation step. Unlike cross-entropy, the AACE loss and its gradient increase as the model nears convergence, ensuring a consistent perturbation direction and addressing the vanishing-gradient issue. Additionally, we propose a novel perturbation-generating function that uses the AACE loss without normalization, enhancing the model's exploratory capability near the optimum. Image classification experiments with Wide ResNet and PyramidNet across various datasets confirm the effectiveness of AACE. The reproduction code is available online.
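
The two SAM steps and the switch from a normalized to a non-normalized perturbation can be sketched on a toy linear softmax classifier. Standard SAM perturbs the weights by $\hat{\epsilon} = \rho \nabla L(w) / \lVert \nabla L(w) \rVert$ and then updates them with the gradient evaluated at $w + \hat{\epsilon}$. The non-normalized branch (`eps = rho * d`) and the `adv_ascent_dir` direction are assumptions made for illustration. The latter reuses a hypothetical adversarial cross-entropy, not the paper's AACE, and its scaling and sign conventions may differ from the authors' method. All function names and hyperparameters (`lr`, `rho`) are hypothetical.

```python
import numpy as np

def log_softmax(Z):
    # Row-wise numerically stable log-softmax.
    Z = Z - Z.max(axis=1, keepdims=True)
    return Z - np.log(np.exp(Z).sum(axis=1, keepdims=True))

def ce_grad(W, X, y):
    """Gradient of the mean cross-entropy of a linear softmax model (logits = X @ W)."""
    P = np.exp(log_softmax(X @ W))
    Y = np.eye(W.shape[1])[y]
    return X.T @ (P - Y) / len(y)

def adv_ascent_dir(W, X, y):
    """Hypothetical adversarial direction (an assumption, not the paper's AACE).

    Returns the negative gradient of an adversarial cross-entropy whose target Q
    drops the true class and renormalizes the other probabilities (Q held fixed).
    In logit space it lowers the true-class logit and raises the others, so for
    each sample it points uphill on the CE loss.
    """
    P = np.exp(log_softmax(X @ W))
    Q = P.copy()
    Q[np.arange(len(y)), y] = 0.0
    Q /= Q.sum(axis=1, keepdims=True)
    return -(X.T @ (P - Q) / len(y))

def sam_step(W, X, y, lr=0.1, rho=0.05, perturb_dir=ce_grad, normalize=True):
    """One SAM iteration; returns the new weights and the perturbation distance."""
    # Perturbation step.
    d = perturb_dir(W, X, y)
    if normalize:
        eps = rho * d / (np.linalg.norm(d) + 1e-12)  # standard SAM: ||eps|| = rho
    else:
        eps = rho * d  # non-normalized: ||eps|| scales with ||d||
    # Updating step: descend the CE gradient evaluated at the perturbed weights.
    return W - lr * ce_grad(W + eps, X, y), np.linalg.norm(eps)

rng = np.random.default_rng(0)
X = rng.normal(size=(128, 5))
y = np.argmax(X[:, :3], axis=1)  # toy 3-class labels
W_sam = np.zeros((5, 3))
W_alt = np.zeros((5, 3))
for _ in range(200):
    W_sam, dist_sam = sam_step(W_sam, X, y)
    W_alt, dist_alt = sam_step(W_alt, X, y, perturb_dir=adv_ascent_dir, normalize=False)
print(f"perturbation distance: SAM={dist_sam:.4f}, non-normalized adversarial={dist_alt:.4f}")
```

With normalization, $\lVert \epsilon \rVert$ is pinned to $\rho$ no matter how small the gradient becomes. Without it, the perturbation distance tracks the magnitude of the direction. A perturbation loss whose gradient grows near convergence therefore keeps the perturbation from vanishing.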

Paper Structure (10 sections, 17 equations, 6 figures, 3 tables)

Figures (6)

  • Figure 1: Comparison of the loss and gradient of standard cross-entropy loss and Adaptive Adversarial Cross-Entropy loss at an early stage ($w_{i}$) and a later stage ($w_{t}$) of training.
  • Figure 2: Probability distributions and trend patterns of standard cross-entropy loss and Adaptive Adversarial Cross-Entropy loss.
  • Figure 3: Diagram of the perturbation step and the updating step in the original SAM and in our proposed method.
  • Figure 4: Loss comparison of standard SAM and SAM with AACE. Each data point is the average loss over the epoch.
  • Figure 5: Comparison of the magnitudes of the perturbation loss’s gradients and of the perturbation distances between SAM and our method. Each data point is the average value over the samples in the epoch.
  • ...and 1 more figure