Layer-wise Adaptive Gradient Norm Penalizing Method for Efficient and Accurate Deep Learning

Sunwoo Lee

TL;DR

This paper tackles the high computational cost of sharpness-aware minimization (SAM) by introducing a layer-wise adaptive gradient norm penalty that selectively perturbs only a subset of layers. By identifying and perturbing the layers with the largest gradient norms, the method suppresses the overall gradient norm with far less cost than perturbing the entire network, while preserving or even improving generalization. The authors provide a convergence guarantee for partial perturbations and validate the approach empirically on computer vision benchmarks and language-model fine-tuning, achieving accuracy comparable to full SAM with throughput near that of SGD. The results suggest a practical pathway to deploy SAM-like generalization benefits in real-world, large-scale DL training, with a tunable trade-off between cost and performance.

Abstract

Sharpness-aware minimization (SAM) is known to improve the generalization performance of neural networks. However, it is not yet widely used in real-world applications because of its expensive model perturbation cost. A few variants of SAM have been proposed to tackle this issue, but they generally do not reduce the cost noticeably. In this paper, we propose a lightweight layer-wise gradient norm penalizing method that tackles the expensive computational cost of SAM while maintaining its superior generalization performance. Our study empirically shows that the gradient norm of the whole model can be effectively suppressed by penalizing the gradient norm of only a few critical layers. We also theoretically show that such a partial model perturbation does not harm the convergence rate of SAM, allowing it to be safely adopted in real-world applications. To demonstrate the efficacy of the proposed method, we perform extensive experiments comparing it to mini-batch SGD and conventional SAM on representative computer vision and language modeling benchmarks.
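
The idea of perturbing only the layers with the largest gradient norms can be illustrated with a small sketch. This is not the authors' implementation: the function names (`layer_norms`, `top_k_layers`, `layerwise_sam_step`), the toy quadratic loss, the per-layer normalization of the ascent step, and the choice to update every layer with the gradient taken at the partially perturbed point are all assumptions made for illustration. A toy example like this also says nothing about the real compute savings.

```python
import math


def layer_norms(grads):
    """L2 norm of the gradient of each layer (each layer is a list of floats)."""
    return [math.sqrt(sum(g * g for g in layer)) for layer in grads]


def top_k_layers(norms, k):
    """Indices of the k layers with the largest gradient norm, in ascending index order."""
    order = sorted(range(len(norms)), key=lambda i: norms[i], reverse=True)
    return sorted(order[:k])


def toy_grad(params, curvatures):
    """Gradient of the toy loss sum_l 0.5 * c_l * ||w_l||^2 (hypothetical stand-in for backprop)."""
    return [[c * w for w in layer] for layer, c in zip(params, curvatures)]


def layerwise_sam_step(params, grad_fn, lr, rho, k):
    """One SAM-style step that perturbs only the k layers with the largest gradient norm."""
    grads = grad_fn(params)
    norms = layer_norms(grads)
    selected = top_k_layers(norms, k)

    # Gradient-ascent perturbation applied to the selected layers only.
    # Assumption: each selected layer is normalized by its own gradient norm.
    perturbed = [list(layer) for layer in params]
    for i in selected:
        scale = rho / (norms[i] + 1e-12)
        perturbed[i] = [w + scale * g for w, g in zip(params[i], grads[i])]

    # Descent step on the original weights using the gradient at the perturbed point.
    grads_p = grad_fn(perturbed)
    new_params = [
        [w - lr * g for w, g in zip(layer, g_layer)]
        for layer, g_layer in zip(params, grads_p)
    ]
    return new_params, selected


params = [[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]]
grad_fn = lambda p: toy_grad(p, [1.0, 1.0, 1.0])
new_params, selected = layerwise_sam_step(params, grad_fn, lr=0.1, rho=0.5, k=1)
print(selected)  # prints [2]: the third layer has the largest gradient norm
```

Here `k` plays the role of the tunable cost/performance knob: with `k` equal to the number of layers the sketch reduces to a full-model perturbation, and with `k = 0` it reduces to a plain gradient step.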

Paper Structure

This paper contains 16 sections, 5 theorems, 15 equations, 6 figures, 4 tables, 1 algorithm.

Key Result

Lemma A.1

Given a $\beta$-smooth loss function $L(x)$, we have the following bound for any $x \in \mathbb{R}^d$.
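
As background for the smoothness assumption, the block below gives the standard definition of $\beta$-smoothness and its usual quadratic upper bound. It is textbook material included for orientation, not a restatement of the bound in Lemma A.1; the symbol $y$ is introduced here only for illustration.

```latex
\|\nabla L(x) - \nabla L(y)\| \le \beta \,\|x - y\|
\quad \forall\, x, y \in \mathbb{R}^d
\;\Longrightarrow\;
L(y) \le L(x) + \langle \nabla L(x),\, y - x \rangle + \frac{\beta}{2}\,\|y - x\|^2 .
```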

Figures (6)

  • Figure 1: The layer-wise gradient norm curves of a) CIFAR-10 (ResNet20) training and b) Oxford_Flowers102 (Wide-ResNet16) training. All the layers show consistent gradient norms throughout the training epochs.
  • Figure 2: The loss landscape of Wide-ResNet16 trained on Oxford_Flowers102. Each layer has a noticeably different loss landscape.
  • Figure 3: The learning curves of CIFAR-10, CIFAR-100, and Oxford_Flowers102. The hyper-parameters are shown in Table \ref{tab:compare}.
  • Figure 4: The full model gradient norm curves of CIFAR-10 (ResNet20), CIFAR-100 (Wide-ResNet28), and Oxford_Flowers102 (Wide-ResNet16). The norm is noticeably reduced when any SAM method is applied. The proposed layer-wise method suppresses the gradient norm about as effectively as the full model perturbation method.
  • Figure 5: The number of perturbations at each individual layer. The generalized SAM perturbs all the layers at every iteration. LookSAM's gradient ascent re-calculation interval is $3$, $2$, and $3$ for the three benchmarks, respectively. Our proposed method selectively perturbs only the $k$ critical layers, with $k$ set to $4$, $4$, and $2$ for the three benchmarks, respectively.
  • ...and 1 more figure

Theorems & Definitions (5)

  • Lemma A.1
  • Lemma A.2
  • Lemma A.3
  • Lemma A.4
  • Theorem A.5