Boosting Large Language Models with Mask Fine-Tuning

Mingyuan Zhang, Yue Bai, Huan Wang, Yizhou Wang, Qihua Dong, Yun Fu

TL;DR

The paper questions whether an LLM's structure must be kept fully intact during fine-tuning and introduces Mask Fine-Tuning (MFT), which freezes a well-tuned model and learns a binary mask that removes selected weights. Guided by the standard fine-tuning objective, MFT optimizes the mask with a straight-through estimator and achieves gains beyond the full fine-tuning (FFT) upper bound across multiple backbones and domains. Extensive experiments and analyses, including layer-group ablations, masking-ratio studies, and data-ratio investigations, show consistent improvements and establish MFT as a practical post-fine-tuning protocol. By extending mask learning beyond pruning, the work offers a general way to improve LLM performance within existing fine-tuning pipelines and highlights the potential of sparsity-driven augmentation in large-scale models.

Abstract

The model is usually kept integral in the mainstream large language model (LLM) fine-tuning protocols. No works have questioned whether maintaining the integrity of the model is indispensable for performance. In this work, we introduce Mask Fine-Tuning (MFT), a brand-new LLM fine-tuning paradigm to show that properly breaking the integrity of the model can surprisingly lead to improved performance. Specifically, MFT learns a set of binary masks supervised by the typical LLM fine-tuning objective. Extensive experiments show that MFT gains a consistent performance boost across various domains and backbones (e.g., 1.95%/1.88% average gain in coding with LLaMA2-7B/3.1-8B). Detailed procedures are provided to study the proposed MFT from different hyperparameter perspectives for better insight. In particular, MFT naturally updates the current LLM training protocol by deploying it on a complete well-trained model. This study extends the functionality of mask learning from its conventional network pruning context for model compression to a more general scope.
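
To make the idea concrete, the following is a minimal sketch of mask learning on a frozen model, not the authors' implementation. It assumes a toy linear model, a real-valued score per weight, and a top-k rule that zeroes the lowest-scoring fraction of weights; the names `topk_mask`, `mft_step`, and `run_demo` are hypothetical. The gradient of the loss with respect to the binary mask is passed straight through to the scores, which is the straight-through estimator applied to this toy case.

```python
import random

def topk_mask(scores, mask_ratio):
    """Binary mask that zeroes the lowest-scoring fraction of weights."""
    n_drop = int(round(mask_ratio * len(scores)))
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    dropped = set(order[:n_drop])
    return [0.0 if i in dropped else 1.0 for i in range(len(scores))]

def forward(weights, mask, x):
    """Toy linear model whose weights are gated by the mask."""
    return sum(w * m * xi for w, m, xi in zip(weights, mask, x))

def mft_step(weights, scores, batch, mask_ratio, lr):
    """One update of the scores; the weights are frozen and never change."""
    mask = topk_mask(scores, mask_ratio)
    grads = [0.0] * len(scores)
    loss = 0.0
    for x, y in batch:
        err = forward(weights, mask, x) - y
        loss += err * err / len(batch)
        for i, xi in enumerate(x):
            # dL/dmask_i for the squared error. The straight-through
            # estimator (an assumption of this sketch) treats
            # d(mask_i)/d(score_i) as 1, so this is also the score gradient.
            grads[i] += 2.0 * err * weights[i] * xi / len(batch)
    new_scores = [s - lr * g for s, g in zip(scores, grads)]
    return new_scores, loss

def run_demo(steps=20, mask_ratio=0.25, lr=0.1, seed=0):
    rng = random.Random(seed)
    frozen = [1.0, -2.0, 0.5, 3.0]   # "fine-tuned" weights; the last one is harmful
    target = [1.0, -2.0, 0.5, 0.0]   # weights that generate the labels
    batch = []
    for _ in range(64):
        x = [rng.uniform(-1.0, 1.0) for _ in frozen]
        batch.append((x, sum(t * xi for t, xi in zip(target, x))))
    scores = [0.0] * len(frozen)
    losses = []
    for _ in range(steps):
        scores, loss = mft_step(frozen, scores, batch, mask_ratio, lr)
        losses.append(loss)
    return topk_mask(scores, mask_ratio), losses, frozen

if __name__ == "__main__":
    mask, losses, _ = run_demo()
    print("learned mask:", mask)
    print("first loss:", losses[0], "last loss:", losses[-1])
```

In this toy setup the mask starts by dropping an arbitrary weight and, after the first update, moves to the weight that hurts the objective, while the frozen weights stay untouched. A real LLM setting would apply the same pattern per weight matrix with an autograd framework.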

Paper Structure

This paper contains 15 sections, 5 equations, 10 figures, 6 tables.

Figures (10)

  • Figure 1: Typical LLM training consists of pre-training for foundation capacity and fine-tuning for domain knowledge, and the LLM structure is always kept intact throughout. We ask whether such integrity is necessary for good performance and propose MFT, which generally outperforms a model that has undergone sufficient FFT. MFT therefore naturally upgrades the classic fine-tuning pipeline: it follows the typical protocol and further refines well-optimized LLMs.
  • Figure 2: Performance trends under different fine-tuning strategies: FFT (blue line), LoRA (green line; Hu et al., 2021), and our MFT (red line). Random masks (orange dashed) and L1 masks (gray dashed) are included for comparison. We use three settings across LLaMA2 and LLaMA3.1 backbones, evaluated on GSM8K, HumanEval, and IF-Eval for the math, coding, and instruction domains, respectively. The x-axis is the number of training steps starting from the pre-trained backbone; the y-axis is evaluation performance. MFT (red line) starts from the best FFT model (yellow star) and breaks the upper bound with further improvements, whereas continued FFT leads to over-fitting. MFT also outperforms LoRA fine-tuning and the two vanilla mask baselines.
  • Figure 3: Ablation study of the local MFT strategy, using LLaMA2-7B and LLaMA3.1-8B backbones across the math, coding, and instruction domains. In each panel, we run MFT ablations with a 10% masking ratio on a domain-specific FFT model (black dashed line). We sweep the ablation at two local granularities, 8-layer (purple) and 4-layer (orange), from shallow to deep layers, using 10% of the fine-tuning set. We find that 1) MFT can outperform the strong FFT baseline and 2) MFT performs better in shallow layers (0-7) and relatively deep layers (20-27). Note that this ablation is intended only for quick intuition; the performance gain of MFT is limited when using the 10% subset. Based on this trend, we deploy MFT with the complete fine-tuning sets and obtain larger improvements (see the LLaMA2 and LLaMA3.1 results tables).
  • Figure 4: Masking-ratio ablation. We use the coding and instruction domains on LLaMA2 and LLaMA3.1 backbones. The original 10% ratio works well on coding but is not optimal on instruction, which indicates that the masking ratio matters and that the MFT strategy has further room for improvement.
  • Figure 5: Data-ratio ablation. We use the math and instruction domains on LLaMA2 and LLaMA3.1 backbones. Compared with the FFT upper bound (red dashed line), MFT (purple) always improves the model when using the full dataset, and it can still obtain a promising performance gain with less data.
  • ...and 5 more figures