Layer-wise Adaptive Gradient Norm Penalizing Method for Efficient and Accurate Deep Learning
Sunwoo Lee
TL;DR
This paper tackles the high computational cost of sharpness-aware minimization (SAM) by introducing a layer-wise adaptive gradient norm penalty that perturbs only a subset of layers. By identifying and perturbing the layers with the largest gradient norms, the method suppresses the gradient norm of the whole model at far lower cost than perturbing the entire network, while preserving or even improving generalization. The authors provide a convergence guarantee for partial perturbation and validate the approach empirically on computer vision benchmarks and language-model fine-tuning, achieving accuracy comparable to full SAM with throughput close to that of SGD. The results suggest a practical path to bringing SAM-like generalization benefits to real-world, large-scale deep learning training, with a tunable trade-off between cost and performance.
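The selection-and-perturbation idea described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the author's implementation: it uses NumPy and a toy deep linear network trained with mean squared error, ranks layers by the Frobenius norm of their gradients, perturbs only the top-k layers with a SAM-style ascent step of radius rho normalized over the selected layers, and re-selects the layers at every step. The function names `forward_backward`, `perturb_top_k`, and `layerwise_sam_step` and the hyperparameters `lr`, `rho`, and `k` are hypothetical. The sketch still runs two full forward-backward passes per step, so it shows the layer-selection logic only, not the mechanism by which the method approaches SGD throughput.

```python
import numpy as np


def forward_backward(weights, X, Y):
    """Loss and per-layer gradients for a toy deep *linear* network
    X @ W1 @ ... @ WL under mean squared error (stand-in for a real model)."""
    acts = [X]
    for W in weights:
        acts.append(acts[-1] @ W)
    n = X.shape[0]
    diff = acts[-1] - Y
    loss = 0.5 * np.sum(diff ** 2) / n
    delta = diff / n  # dLoss/dOutput
    grads = [None] * len(weights)
    for i in reversed(range(len(weights))):
        grads[i] = acts[i].T @ delta  # dLoss/dW_i
        delta = delta @ weights[i].T  # propagate to this layer's input
    return loss, grads


def perturb_top_k(weights, grads, rho, k):
    """Hypothetical helper: SAM-style ascent step applied only to the k layers
    with the largest gradient norms; all other layers are left unperturbed."""
    norms = np.array([np.linalg.norm(g) for g in grads])
    selected = np.argsort(norms)[-k:]  # indices of the k largest norms
    # Normalize over the selected layers so the total perturbation length is rho.
    sel_norm = max(np.sqrt(np.sum(norms[selected] ** 2)), 1e-12)
    perturbed = [W.copy() for W in weights]
    for i in selected:
        perturbed[i] = weights[i] + rho * grads[i] / sel_norm
    return perturbed, selected


def layerwise_sam_step(weights, X, Y, lr=0.01, rho=0.05, k=1):
    # 1) Gradient at the current weights: used to rank layers and build the perturbation.
    _, grads = forward_backward(weights, X, Y)
    # 2) Perturb only the top-k layers (re-selected every step in this sketch).
    perturbed, selected = perturb_top_k(weights, grads, rho, k)
    # 3) Gradient at the perturbed point updates *all* layers, as in SAM.
    _, sam_grads = forward_backward(perturbed, X, Y)
    new_weights = [W - lr * g for W, g in zip(weights, sam_grads)]
    return new_weights, selected


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.normal(size=(64, 8))
    Y = X @ rng.normal(size=(8, 4))
    weights = [rng.normal(scale=0.3, size=s) for s in [(8, 8), (8, 8), (8, 4)]]
    loss0, _ = forward_backward(weights, X, Y)
    for _ in range(200):
        weights, selected = layerwise_sam_step(weights, X, Y, k=1)
    loss1, _ = forward_backward(weights, X, Y)
    print(f"loss {loss0:.4f} -> {loss1:.4f}, last selected layers: {selected}")
```

Normalizing by the norm over the selected layers keeps the total perturbation length at rho; when k equals the number of layers, the step reduces to the standard full-model SAM perturbation.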
Abstract
Sharpness-aware minimization (SAM) is known to improve the generalization performance of neural networks. However, it is not yet widely used in real-world applications because its model perturbation step is expensive. Several variants of SAM have been proposed to address this issue, but most of them do not noticeably reduce the cost. In this paper, we propose a lightweight layer-wise gradient norm penalizing method that reduces the high computational cost of SAM while maintaining its superior generalization performance. Our study empirically shows that the gradient norm of the whole model can be effectively suppressed by penalizing the gradient norm of only a few critical layers. We also theoretically show that such partial model perturbation does not harm the convergence rate of SAM, so it can be safely adopted in real-world applications. To demonstrate the efficacy of the proposed method, we perform extensive experiments comparing it to mini-batch SGD and conventional SAM on representative computer vision and language modeling benchmarks.
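Why perturbing a few layers amounts to penalizing their gradient norm can be seen from a first-order Taylor expansion, given here as a hedged worked derivation. The notation (g_l for the loss gradient with respect to layer l, S for the set of selected layers, rho for the perturbation radius) is assumed for illustration and is not taken from the paper, whose exact formulation may differ.

```latex
% Assumed notation: g_l = \nabla_{w_l} L(w); S = set of selected layers; \rho = radius
\|g_S\| = \Big( \sum_{l \in S} \|g_l\|^2 \Big)^{1/2}, \qquad
\epsilon_l =
\begin{cases}
  \rho \, g_l / \|g_S\| & l \in S \\
  0 & l \notin S
\end{cases}

L(w + \epsilon) \approx L(w) + \sum_{l} g_l^\top \epsilon_l
  = L(w) + \rho \sum_{l \in S} \frac{\|g_l\|^2}{\|g_S\|}
  = L(w) + \rho \, \|g_S\|
```

Because the squared gradient norm of the whole model is the sum of the per-layer squared norms, choosing S as the k layers with the largest gradient norms maximizes ||g_S|| over all k-layer subsets, so the penalty covers as much of the full gradient norm as any k-layer choice can. With S equal to all layers, the expression reduces to the familiar first-order view of SAM as a penalty of rho times the full gradient norm.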
