
S$^{2}$FT: Efficient, Scalable and Generalizable LLM Fine-tuning by Structured Sparsity

Xinyu Yang, Jixuan Leng, Geyang Guo, Jiawei Zhao, Ryumei Nakada, Linjun Zhang, Huaxiu Yao, Beidi Chen

TL;DR

S$^{2}$FT introduces structured sparse fine-tuning for LLMs by sparsely selecting attention heads and FFN channels and densely computing within co-permuted dense submatrices, enabling end-to-end training with partial backpropagation. Theoretical and empirical results show improved generalization under distribution shifts, competitive-to-superior task performance (commonsense, arithmetic, and instruction following) compared to LoRA and even full fine-tuning in some settings, while delivering substantial training memory and latency savings. The method supports scalable serving through adapter fusion, fast switching, and parallelism, enabling efficient deployment of thousands of fine-tuned models. Overall, S$^{2}$FT closes the gap among high accuracy, training efficiency, and scalable serving, making large-scale fine-tuning more practical and robust across domains.
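The co-permutation idea can be illustrated with a minimal sketch for a SwiGLU-style FFN, which uses the gate/up/down layout named in Figure 4. It assumes row-major weights stored as plain Python lists. All function names (`co_permute`, `trainable_block`, `heads_to_channels`) are hypothetical and do not come from the authors' code. The intermediate channel index is shared by the rows of the gate and up projections and the columns of the down projection. Permuting all three consistently therefore leaves the FFN output unchanged and gathers the selected channels into one dense block.

```python
# Hypothetical helpers for illustration; not the S^2FT reference implementation.
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]  # row-major: list of rows


def matvec(W: Matrix, x: Sequence[float]) -> List[float]:
    """Return W @ x for a row-major matrix."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def silu(z: float) -> float:
    return z / (1.0 + math.exp(-z))


def ffn(W_gate: Matrix, W_up: Matrix, W_down: Matrix, x: Sequence[float]) -> List[float]:
    """SwiGLU-style FFN: W_down @ (silu(W_gate @ x) * (W_up @ x))."""
    h = [silu(g) * u for g, u in zip(matvec(W_gate, x), matvec(W_up, x))]
    return matvec(W_down, h)


def co_permute(W_gate: Matrix, W_up: Matrix, W_down: Matrix,
               selected: Sequence[int]) -> Tuple[Matrix, Matrix, Matrix, List[int]]:
    """Move the selected intermediate channels to the front of every matrix that
    shares the intermediate dimension: rows of W_gate/W_up, columns of W_down."""
    chosen = set(selected)
    perm = list(selected) + [c for c in range(len(W_up)) if c not in chosen]
    gate_p = [W_gate[c] for c in perm]
    up_p = [W_up[c] for c in perm]
    down_p = [[row[c] for c in perm] for row in W_down]
    return gate_p, up_p, down_p, perm


def trainable_block(gate_p: Matrix, up_p: Matrix, down_p: Matrix, k: int):
    """After co-permutation, the k selected channels form dense submatrices:
    the first k rows of gate/up and the first k columns of down."""
    return gate_p[:k], up_p[:k], [row[:k] for row in down_p]


def heads_to_channels(heads: Sequence[int], head_dim: int) -> List[int]:
    """Expand selected attention heads into the channel indices they occupy
    (rows of W_q/W_k/W_v and columns of W_o, assuming standard MHA)."""
    return [h * head_dim + i for h in heads for i in range(head_dim)]
```

Because every matrix that shares the intermediate index is permuted together, only the order of terms in the down-projection sum changes. The selected channels become contiguous slices that a dense kernel can update. The sketch selects the same channels on both sides. Figure 1 notes that asymmetric selection is also supported, and this code does not cover it.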

Abstract

Current PEFT methods for LLMs can achieve high quality, efficient training, or scalable serving, but not all three simultaneously. To address this limitation, we investigate sparse fine-tuning and observe a remarkable improvement in generalization ability. Building on this key insight, we propose a family of Structured Sparse Fine-Tuning (S$^{2}$FT) methods for LLMs, which concurrently achieve state-of-the-art fine-tuning performance, training efficiency, and inference scalability. S$^{2}$FT accomplishes this by "selecting sparsely and computing densely". It selects a few heads in the MHA module and a few channels in the FFN module of each Transformer block. Next, it co-permutes the weight matrices on both sides of the coupled structures in LLMs so that the selected components in each layer form a dense submatrix. Finally, S$^{2}$FT performs in-place gradient updates on all submatrices. Theoretical analysis and empirical results show that our method prevents forgetting while simplifying optimization. It delivers SOTA performance on both commonsense and arithmetic reasoning, with 4.6% and 1.3% average improvements over LoRA, and surpasses full FT by 11.5% when generalizing to various domains after instruction tuning. Using our partial backpropagation algorithm, S$^{2}$FT reduces training memory by up to 3$\times$ and improves latency by 1.5-2.7$\times$ compared to full FT, while delivering an average 10% improvement over LoRA on both metrics. We further demonstrate that the weight updates in S$^{2}$FT can be decoupled into adapters, enabling effective fusion, fast switching, and efficient parallelism for serving multiple fine-tuned models.
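The weight updates can be decoupled into adapters because the trainable weights are restricted to a few selected rows or columns. As a result, the update is sparse in that structure. A minimal sketch of storing such an update, fusing it, switching it, and serving it unfused follows. It assumes a row-sparse update, for example selected rows of an up/gate projection. A column-sparse update would be handled symmetrically. The `RowAdapter` type and all function names are hypothetical illustrations, not the paper's serving system.

```python
# Hypothetical row-sparse adapter helpers; not the S^2FT serving implementation.
from typing import Dict, List, Sequence

Matrix = List[List[float]]
RowAdapter = Dict[int, List[float]]  # row index -> additive update for that row


def extract_adapter(W_base: Matrix, W_tuned: Matrix, rows: Sequence[int]) -> RowAdapter:
    """Keep only the rows that fine-tuning was allowed to change."""
    return {r: [t - b for t, b in zip(W_tuned[r], W_base[r])] for r in rows}


def apply_adapter(W: Matrix, adapter: RowAdapter, sign: float = 1.0) -> None:
    """Fuse (sign=+1) or unfuse (sign=-1) a row-sparse update into W in place."""
    for r, delta in adapter.items():
        W[r] = [w + sign * d for w, d in zip(W[r], delta)]


def switch_adapter(W: Matrix, old: RowAdapter, new: RowAdapter) -> None:
    """Swap adapters by touching only the rows each adapter owns."""
    apply_adapter(W, old, -1.0)
    apply_adapter(W, new, +1.0)


def unfused_forward(W_base: Matrix, adapter: RowAdapter, x: Sequence[float]) -> List[float]:
    """Serve an adapter without fusing it: a shared base matmul plus a small
    correction computed only on the adapter's rows."""
    y = [sum(w * xi for w, xi in zip(row, x)) for row in W_base]
    for r, delta in adapter.items():
        y[r] += sum(d * xi for d, xi in zip(delta, x))
    return y
```

Under these assumptions, switching costs time proportional to the number of selected rows rather than the full matrix. The unfused path allows requests for different adapters to share one base matmul, which is one plausible route to the parallel serving the abstract mentions.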

Paper Structure

This paper contains 52 sections, 15 theorems, 111 equations, 6 figures, 6 tables.

Key Result

Theorem 4.2

Suppose Assumption [distribution shift] holds. Consider $n \to \infty$. If $B^{(\textnormal{i})} = \overline{W}^\textnormal{pre}_{\ell+1} \tilde{B}^{(\textnormal{i})} \underline{W}^\textnormal{pre}_{\ell-1}$ holds for some $\tilde{B}^{(\textnormal{i})} \in \mathbb{R}^{d_\ell \times d_{\ell-1}}$,

Figures (6)

  • Figure 1: An Overview of the S$^2$FT Family for LLMs: First, we perform sparse selection of specific attention heads and channels within the coupled structures of the MHA and FFN modules. Next, we apply co-permutation to the weight matrices on both sides of these structures, enabling dense gradient computation only for the selected components. While we demonstrate S$^2$FT by selecting the same heads/channels on both sides for clarity, our approach also supports asymmetric selection strategies.
  • Figure 2: Accuracy comparison of sparse fine-tuning (SpFT), LoRA, and full FT at varying ratios of trainable parameters in various settings. SpFT exhibits strong generalization ability, while full FT excels at memorization.
  • Figure 3: Grouped model weights with basic structure and residual structure. All highlighted weights must be permuted simultaneously. Residual structures require additional permutation during runtime.
  • Figure 4: The impact of different components in fine-tuning, including Query, Key, Value, Output, Up, Gate, and Down projection. We fix the trainable parameter budget and only fine-tune one component.
  • Figure 5: Comparison of memory and computation efficiency during training on the LLaMA2-7B/13B with varying sequence lengths and batch sizes. Average latency and peak memory usage are reported. S$^2$FT significantly improves training latency while reducing memory footprint compared to baselines.
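Figure 5 reports training memory and latency. One source of savings from partial backpropagation can be sketched for a single linear layer $Y = XW^\top$. When only selected weights are trainable, only the matching slices of $dW = dY^\top X$ need to be computed. For selected input columns, only the matching activation columns need to be saved. This is a generic illustration under those assumptions, not the paper's partial backpropagation algorithm. Input gradients for earlier layers are omitted, and every function name is hypothetical.

```python
# Hypothetical sketch of computing only the needed weight-gradient slices.
from typing import Dict, List, Sequence

Matrix = List[List[float]]  # X: tokens x d_in, dY: tokens x d_out


def full_weight_grad(dY: Matrix, X: Matrix) -> Matrix:
    """Reference: dW = dY^T X, shape d_out x d_in (what full FT computes)."""
    n, d_out, d_in = len(X), len(dY[0]), len(X[0])
    return [[sum(dY[t][i] * X[t][j] for t in range(n)) for j in range(d_in)]
            for i in range(d_out)]


def grad_selected_rows(dY: Matrix, X: Matrix, rows: Sequence[int]) -> Dict[int, List[float]]:
    """Gradient for trainable output rows only (e.g. selected rows of up/gate)."""
    n, d_in = len(X), len(X[0])
    return {r: [sum(dY[t][r] * X[t][j] for t in range(n)) for j in range(d_in)]
            for r in rows}


def save_for_backward(X: Matrix, cols: Sequence[int]) -> Matrix:
    """Store only the activation columns that a column-sparse update needs."""
    return [[row[c] for c in cols] for row in X]


def grad_selected_cols(dY: Matrix, X_saved: Matrix, cols: Sequence[int]) -> Dict[int, List[float]]:
    """Gradient for trainable input columns only (e.g. selected columns of down).
    Returns each column of dW as a list over output rows."""
    n, d_out = len(dY), len(dY[0])
    return {c: [sum(dY[t][i] * X_saved[t][k] for t in range(n)) for i in range(d_out)]
            for k, c in enumerate(cols)}
```

In the column case, each token stores `len(cols)` activations instead of `d_in`, and the gradient work shrinks by the same ratio. This sketch does not model how those savings combine across layers in practice.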

Theorems & Definitions (27)

  • Theorem 4.2: Out-of-distribution Excess Risk, Informal
  • Theorem F.7
  • Theorem F.8: Restatement of Theorem [sft and lora ood informal]
  • Proof: Intuition of the proof of Theorem [sft and lora ood]
  • Lemma F.9: Excess Risk
  • Proof: Proof of Lemma [excess risk lora]
  • Theorem F.10: Restatement of Theorem [sft and lora id]: LoRA Part
  • Proof: Proof of Theorem [lora in-distribution]
  • Theorem F.11: Restatement of Theorem [sft and lora ood]: LoRA Part
  • Proof: Proof of Theorem [lora ood]