Sparse Layer Sharpness-Aware Minimization for Efficient Fine-Tuning
Yifei Cheng, Xianglin Yang, Guoxia Wang, Chao Huang, Fei Ma, Dianhai Yu, Xiaochun Cao, Li Shen
TL;DR
This work tackles the high computational cost of Sharpness-Aware Minimization (SAM) during fine-tuning by introducing Sparse-Layer SAM (SL-SAM), which imposes adaptive layerwise sparsity and frames layer selection as a multi-armed bandit problem to choose active layers for both the perturbation and update steps. By maintaining a layer-selection distribution and updating it with an EXP3-based rule guided by gradient norms, SL-SAM achieves a two-sided sparsity that drastically reduces gradient computations while preserving SAM’s generalization benefits; it also provides a formal convergence guarantee with a rate of $\frac{1}{T}\sum_{t=1}^T \mathbb{E}\|\nabla f(x_t)\|_1 = \mathcal{O}(T^{-1/4})$. Empirically, SL-SAM delivers competitive or state-of-the-art performance across DeiT, RoBERTa, and Llama-3 fine-tuning tasks, while reducing memory and time costs substantially (e.g., roughly 20–25% reductions in GPU memory and epoch time). The approach extends to single-step SAM variants and demonstrates robustness via ablations, making SAM-based fine-tuning feasible for large-scale models in practice.
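To make the layer-selection idea concrete, the following is a minimal sketch of a textbook EXP3 update applied to layers as arms. It is not the paper's exact rule: the function names, the choice of one layer per round, the fixed gradient norms, and the reward scaling (gradient norm divided by the largest norm) are all assumptions made for illustration.

```python
import math
import random


def exp3_probs(weights, gamma):
    """Mix the normalized weights with a uniform exploration term."""
    total = sum(weights)
    k = len(weights)
    return [(1.0 - gamma) * w / total + gamma / k for w in weights]


def exp3_update(weights, probs, chosen, reward, gamma):
    """Exponential update of the chosen arm with an importance-weighted reward."""
    k = len(weights)
    updated = list(weights)
    updated[chosen] *= math.exp(gamma * (reward / probs[chosen]) / k)
    return updated


def run_layer_bandit(grad_norms, rounds=1000, gamma=0.1, seed=0):
    """Run EXP3 over layers; rewards are gradient norms scaled into [0, 1].

    Assumption: gradient norms are held fixed here, whereas in training
    they would be re-measured at every iteration.
    """
    rng = random.Random(seed)
    max_norm = max(grad_norms)
    rewards = [g / max_norm for g in grad_norms]
    weights = [1.0] * len(grad_norms)
    for _ in range(rounds):
        probs = exp3_probs(weights, gamma)
        layer = rng.choices(range(len(weights)), weights=probs)[0]
        weights = exp3_update(weights, probs, layer, rewards[layer], gamma)
    return exp3_probs(weights, gamma)


# Hypothetical per-layer gradient norms for a four-layer model.
layer_probs = run_layer_bandit([0.1, 0.2, 0.9, 0.3])
```

In this sketch the selection distribution concentrates on the layer with the largest gradient norm, while the uniform term keeps every layer's probability at or above `gamma / k`, so no layer is permanently excluded.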
Abstract
Sharpness-aware minimization (SAM) seeks minima in flat regions of the loss landscape to improve generalization in machine learning tasks, including fine-tuning. However, its extra parameter-perturbation step doubles the computation cost, which becomes the bottleneck of SAM in practice. In this work, we propose SL-SAM, an approach that breaks this bottleneck by applying sparsity at the layer level. Our key innovation is to frame the dynamic selection of layers for both the gradient ascent (perturbation) and descent (update) steps as a multi-armed bandit problem. At the beginning of each iteration, SL-SAM samples a subset of the model's layers according to their gradient norms to participate in the backpropagation of the subsequent parameter perturbation and update steps, thereby reducing the computational complexity. We then provide an analysis that guarantees the convergence of SL-SAM. In fine-tuning experiments across several tasks, SL-SAM achieves performance comparable to state-of-the-art baselines, including a #1 rank on LLM fine-tuning. Meanwhile, SL-SAM significantly reduces the ratio of active parameters in backpropagation compared to vanilla SAM (SL-SAM activates 47%, 22%, and 21% of the parameters on the vision model, the moderate-size language model, and the large language model, respectively, while vanilla SAM always activates 100%), verifying the efficiency of our proposed algorithm.
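The two-sided sparsity described above can be sketched on a toy problem, where only the sampled layers take part in both the ascent (perturbation) and the descent (update) step. Everything here is an illustrative assumption rather than the authors' implementation: the quadratic loss, the helper names, the sampling-without-replacement scheme, and the values of `lr` and `rho`.

```python
import math
import random


def loss(params, scales):
    """Toy loss: 0.5 * sum_l scales[l] * ||params[l]||^2."""
    return 0.5 * sum(s * sum(v * v for v in p) for p, s in zip(params, scales))


def layer_grad(params, scales, layer):
    """Gradient of the toy loss with respect to a single layer."""
    return [scales[layer] * v for v in params[layer]]


def sample_layers(probs, m, rng):
    """Draw m distinct layers, each draw proportional to the remaining probabilities."""
    remaining = list(range(len(probs)))
    chosen = []
    for _ in range(m):
        idx = rng.choices(remaining, weights=[probs[i] for i in remaining])[0]
        chosen.append(idx)
        remaining.remove(idx)
    return sorted(chosen)


def sparse_layer_sam_step(params, scales, active, lr=0.1, rho=0.05):
    """One SAM step in which only the active layers are perturbed and updated."""
    # Ascent: gradients are computed for the active layers only.
    grads = {l: layer_grad(params, scales, l) for l in active}
    norm = math.sqrt(sum(g * g for gl in grads.values() for g in gl)) + 1e-12
    perturbed = [list(p) for p in params]
    for l in active:
        perturbed[l] = [v + rho * g / norm for v, g in zip(params[l], grads[l])]
    # Descent: gradients at the perturbed point, again for active layers only.
    new_params = [list(p) for p in params]
    for l in active:
        g_adv = layer_grad(perturbed, scales, l)
        new_params[l] = [v - lr * g for v, g in zip(params[l], g_adv)]
    return new_params


rng = random.Random(0)
params = [[1.0, -2.0], [0.5, 0.5], [3.0, 1.0]]  # three hypothetical layers
scales = [1.0, 2.0, 0.5]
sampled = sample_layers([0.5, 0.2, 0.3], m=2, rng=rng)
new_params = sparse_layer_sam_step(params, scales, active=[0, 2])
```

In this toy setting each layer's gradient is independent of the others, so the savings are only notional; in a real network, skipping inactive layers is what shortens backpropagation. Layers outside the active set are left untouched by both steps.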
