Bilateral Sharpness-Aware Minimization for Flatter Minima
Jiaxin Deng, Junbiao Pang, Baochang Zhang, Qingming Huang
TL;DR
This work identifies a Flatness Indicator Problem (FIP) in Sharpness-Aware Minimization (SAM). The problem arises because SAM considers only Max-Sharpness (MaxS), which measures flatness in the gradient-ascent direction alone. The work introduces Min-Sharpness (MinS) and combines it with MaxS into Bilateral Sharpness (BilS), which captures flatness in both the gradient-ascent and gradient-descent directions. BilS is then incorporated into SAM as Bilateral SAM (BSAM). The authors prove that BSAM converges. Extensive experiments on classification, transfer learning, human pose estimation, and network quantization show that BSAM finds flatter minima, generalizes better, and is more robust than SAM and SGD. Because the method works across these diverse tasks, the results suggest that accounting for perturbations on both sides of the current point guides optimization toward flatter regions of the loss landscape more effectively than considering only one side.
Abstract
Sharpness-Aware Minimization (SAM) enhances generalization by reducing Max-Sharpness (MaxS), the difference between the maximum loss over a neighborhood of the current weights and the training loss. Despite SAM's practical success, we empirically found that MaxS, which underlies SAM's generalization gains, suffers from a "Flatness Indicator Problem" (FIP). SAM considers flatness only in the direction of gradient ascent, so the next region it minimizes is not sufficiently flat. Because SAM is by nature a greedy search method, a better Flatness Indicator (FI) should lead to better generalization of neural networks. In this paper, we propose to use the difference between the training loss and the minimum loss over the neighborhood surrounding the current weights, which we call Min-Sharpness (MinS). By merging MaxS and MinS, we create a better FI that indicates a flatter direction during optimization. Specifically, we combine this FI with SAM in the proposed Bilateral SAM (BSAM), which finds flatter minima than SAM. Our theoretical analysis proves that BSAM converges to local minima. Extensive experiments demonstrate that BSAM offers better generalization and robustness than vanilla SAM across various tasks, namely classification, transfer learning, human pose estimation, and network quantization. Code is publicly available at: https://github.com/ajiaaa/BSAM.
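The two one-sided quantities described above can be made concrete with a small numerical sketch. This is not the authors' BSAM implementation. The toy loss and the function names (toy_loss, toy_grad, one_sided_sharpness) are hypothetical. As an assumption, the inner maximization and minimization over the ρ-ball are each approximated, in the style of SAM, by a single step of length ρ along or against the normalized gradient. The abstract does not specify how MaxS and MinS are merged into the Flatness Indicator, so the sketch reports them separately.

```python
import numpy as np


def toy_loss(w):
    # Hypothetical toy loss: an anisotropic quadratic plus a cubic term,
    # so the curvature differs on the two sides of a point.
    return 0.5 * (4.0 * w[0] ** 2 + w[1] ** 2) + 0.3 * w[0] ** 3


def toy_grad(w):
    # Analytic gradient of toy_loss.
    return np.array([4.0 * w[0] + 0.9 * w[0] ** 2, w[1]])


def one_sided_sharpness(w, rho=0.05, loss=toy_loss, grad=toy_grad):
    """Single-step estimates of Max-Sharpness and Min-Sharpness.

    MaxS(w) = max_{||e|| <= rho} L(w + e) - L(w)
    MinS(w) = L(w) - min_{||e|| <= rho} L(w + e)
    Both extrema are approximated by one step of length rho along (+)
    or against (-) the normalized gradient. This approximation is an
    assumption of the sketch, not necessarily how BSAM computes them.
    """
    w = np.asarray(w, dtype=float)
    g = grad(w)
    # Perturbation of length rho along the gradient (zero if the gradient vanishes).
    e = rho * g / (np.linalg.norm(g) + 1e-12)
    base = loss(w)
    max_s = loss(w + e) - base  # ascent side, as in SAM
    min_s = base - loss(w - e)  # descent side
    return max_s, min_s


if __name__ == "__main__":
    w = np.array([1.0, 1.0])
    max_s, min_s = one_sided_sharpness(w)
    first_order = 0.05 * np.linalg.norm(toy_grad(w))
    print(f"MaxS ~ {max_s:.4f}, MinS ~ {min_s:.4f}, rho*||g|| = {first_order:.4f}")
```

For w = (1, 1) and ρ = 0.05, both estimates land near the first-order value ρ‖∇L(w)‖ ≈ 0.25. MaxS comes out slightly above MinS because the toy loss curves upward along the step. At the minimizer w = (0, 0), both estimates are zero.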
