Mitigating Parameter Interference in Model Merging via Sharpness-Aware Fine-Tuning

Yeoreum Lee, Jinwook Jung, Sungyong Baik

TL;DR

This work tackles parameter interference that arises when merging multiple task-specific models fine-tuned from a single pre-trained model. It introduces Sharpness-Aware Fine-Tuning (SAFT), inspired by sharpness-aware minimization (SAM), which jointly maximizes per-task performance and minimizes parameter interference (equivalently, improves weight disentanglement) by seeking flatter minima that tolerate the perturbations introduced by merging. The authors provide both empirical and theoretical support, showing improved weight disentanglement, better cross-task linearity, and a joint-task loss linearity property across various merging methods and backbones. The approach yields consistent performance gains for merged models and offers a practical, generalizable framework for efficient multi-task model merging in real-world settings.

Abstract

Large-scale deep learning models trained under the pretraining-finetuning paradigm have led to a surge of task-specific models fine-tuned from a common pre-trained model. Several recent efforts have focused on merging these large models into a single multi-task model, particularly through simple arithmetic on parameters. Such merging faces a central challenge: interference between model parameters fine-tuned on different tasks. A few recent works have designed new fine-tuning schemes that reduce parameter interference, but at the cost of the performance of each task-specific fine-tuned model, which in turn limits the performance of the merged model. To improve the performance of a merged model, we note that a fine-tuning scheme should aim for (1) smaller parameter interference and (2) better performance of each fine-tuned model on its corresponding task. In this work, we design a new fine-tuning objective function that works toward both goals. In the process, we find this objective to be strikingly similar to the sharpness-aware minimization (SAM) objective, which aims to improve generalization by finding flat minima. Drawing on this observation, we propose to fine-tune pre-trained models via sharpness-aware minimization. Experimental and theoretical results demonstrate the effectiveness and orthogonality of our approach, which improves performance on top of various merging and fine-tuning methods. Our code is available at https://github.com/baiklab/SAFT-Merge.
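The abstract ties SAFT to the SAM objective, $\min_{\bm{\theta}} \max_{\|\bm{\epsilon}\|_2 \le \rho} {\mathcal{L}}({\bm{\theta}} + \bm{\epsilon})$. As a rough illustration only, the sketch below implements the standard first-order SAM update of Foret et al. on a toy quadratic loss: perturb the weights along the normalized gradient, then descend using the gradient taken at the perturbed point. The radius `rho`, the learning rate, and the toy loss are assumptions chosen for illustration; the authors' actual training setup is in the linked repository and may differ.

```python
import numpy as np

def sam_step(theta, grad_fn, lr=0.1, rho=0.05, eps=1e-12):
    """One generic SAM update: ascend to the approximate worst-case point
    inside an L2 ball of radius rho, then descend with the gradient there."""
    g = grad_fn(theta)
    # First-order solution of max_{||e|| <= rho} L(theta + e): step along the normalized gradient.
    e = rho * g / (np.linalg.norm(g) + eps)
    # The gradient at the perturbed point drives the actual parameter update.
    g_sharp = grad_fn(theta + e)
    return theta - lr * g_sharp

# Hypothetical stand-in for a task loss: an anisotropic quadratic 0.5 * theta^T A theta.
A = np.diag([10.0, 1.0])

def loss(th):
    return 0.5 * th @ A @ th

def grad(th):
    return A @ th

theta = np.array([1.0, 1.0])
for _ in range(100):
    theta = sam_step(theta, grad, lr=0.05, rho=0.05)
final_loss = loss(theta)
print(final_loss)
```

Setting `rho=0.0` reduces the update to plain gradient descent, which is a quick sanity check. In SAFT, a step of this kind would be used while fine-tuning each task-specific model from the shared pre-trained weights.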

Paper Structure

This paper contains 29 sections, 2 theorems, 27 equations, 8 figures, 9 tables.

Key Result

Theorem 1

If models parameterized by ${\bm{\theta}}_s$ and ${\bm{\theta}}_t$ are obtained by fine-tuning a common pre-trained model via SAFT on their respective datasets, the resulting models better satisfy joint-task loss linearity (Definition 2). The proof is deferred to the appendix due to space constraints.
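For reference, the quantities in Theorem 1 can be written out using the notation of the figure captions below. This is an idealized sketch that assumes task vectors are defined as $\bm{\tau}_t = {\bm{\theta}}_t - {\bm{\theta}}_0$ (the usual task-arithmetic convention), that the joint-task loss (Definition 1) is the sum of per-task losses as in the Figure 4 caption, and that SAFT applies the SAM objective with a radius $\rho$; the symbols $\rho$ and ${\mathcal{L}}_{\text{joint}}$ are introduced here for illustration. The formal statement of joint-task loss linearity (Definition 2) is not reproduced.

```latex
\begin{aligned}
{\bm{\theta}}_t &\approx \arg\min_{{\bm{\theta}}} \max_{\|\bm{\epsilon}\|_2 \le \rho}
  {\mathcal{L}}({\bm{\theta}} + \bm{\epsilon}; {\mathcal{D}}^{(t)})
  \quad \text{(fine-tuning initialized at } {\bm{\theta}}_0\text{)}, \\
\bm{\tau}_t &= {\bm{\theta}}_t - {\bm{\theta}}_0, \qquad
{\bm{\theta}}_{\text{merge}} = {\bm{\theta}}_0 + \alpha_s \bm{\tau}_s + \alpha_t \bm{\tau}_t, \\
{\mathcal{L}}_{\text{joint}}({\bm{\theta}}_{\text{merge}}) &=
  {\mathcal{L}}({\bm{\theta}}_{\text{merge}}; {\mathcal{D}}^{(s)})
  + {\mathcal{L}}({\bm{\theta}}_{\text{merge}}; {\mathcal{D}}^{(t)}).
\end{aligned}
```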

Figures (8)

  • Figure 1: Weight disentanglement visualization of two task-specific models across two tasks. Each pixel in the heatmap corresponds to the weight disentanglement error $\xi(\alpha_1, \alpha_2)$ between a two-task-merged model, parameterized by ${\bm{\theta}}_{\text{merge}} = {\bm{\theta}}_0 + \alpha_1 \bm{\tau}_1 + \alpha_2 \bm{\tau}_2$, and two task-specific models, parameterized by ${\bm{\theta}}_0 + \alpha_1\bm{\tau}_1$ and ${\bm{\theta}}_0 + \alpha_2\bm{\tau}_2$, evaluated on task $1$ and task $2$. We use CLIP ViT-B/32 on EuroSAT-SUN397, DTD-EuroSAT, GTSRB-SVHN, and DTD-MNIST task pairs to plot these visualizations. The red box highlights the search space used to find the optimal task coefficients $\{\alpha_1, \alpha_2 \}$ for task arithmetic.
  • Figure 2: Weight disentanglement visualization of eight-task-merged models across two tasks. Each pixel in the heatmap corresponds to the disentanglement error $\xi(\alpha_1, \alpha_2)$ between an eight-task-merged model, parameterized by ${\bm{\theta}}_{\text{merge}} = {\bm{\theta}}_0 + \alpha_1 \bm{\tau}_1 + \alpha_2 \bm{\tau}_2 + \sum_{s \notin \{1, 2\}} \alpha_s\bm{\tau}_s$, and each task-specific model, evaluated on tasks $1$ and $2$. To visualize the landscape of the merged multi-task model on a 2D heatmap, we adjust only the two task coefficients corresponding to the evaluation tasks. The models, evaluation task pairs, and red box are the same as in Figure 1.
  • Figure 3: Verification of cross-task linearity (CTL) between the merged model and task-specific models. We compare $\mathbb{E}_{\mathcal{D}^{(s)}\cup\mathcal{D}^{(t)}}[1 - \cos^{(\ell)}({\bm{x}}; 2\lambda\bm{\tau}_s, 2\lambda\bm{\tau}_t)]$ for sharpness-aware fine-tuning and SGD. Values are reported for the last six blocks on the DTD-MNIST and EuroSAT-SUN397 task pairs, with the scaling factor $\lambda$ set to $0.3$.
  • Figure 4: Joint-task loss landscape visualization of two task-specific models across two tasks. Each pixel in the heatmap corresponds to the joint-task loss ${\mathcal{L}}({\bm{\theta}}_{\text{merge}};{\mathcal{D}}^{(1)}) + {\mathcal{L}}({\bm{\theta}}_{\text{merge}};{\mathcal{D}}^{(2)})$ of the two-task-merged model, parameterized by ${\bm{\theta}}_{\text{merge}} = {\bm{\theta}}_0 + \alpha_1 \bm{\tau}_1 + \alpha_2 \bm{\tau}_2$, evaluated on tasks $1$ and $2$. We use CLIP ViT-B/32 on the EuroSAT-SUN397 and DTD-MNIST task pairs. As in Figure 1, the red box highlights the search space used to find the optimal task coefficients $\{ \alpha_1, \alpha_2 \}$ for task arithmetic.
  • Figure 5: Joint-task loss landscape visualization of eight-task-merged models across two tasks. Each pixel in the heatmap corresponds to the joint-task loss ${\mathcal{L}}({\bm{\theta}}_{\text{merge}};{\mathcal{D}}^{(1)}) + {\mathcal{L}}({\bm{\theta}}_{\text{merge}};{\mathcal{D}}^{(2)})$ of the eight-task-merged model, parameterized by ${\bm{\theta}}_{\text{merge}} = {\bm{\theta}}_0 + \sum_{t=1}^8 \alpha_t\bm{\tau}_t$, evaluated on tasks $1$ and $2$. As in Figure 2, we adjust only the two task coefficients corresponding to the evaluation tasks to visualize the joint-task loss landscape on a 2D map. The model, task pairs, and red box are the same as in Figure 4.
  • ...and 3 more figures
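The diagnostics in Figures 1 to 5 can be reproduced in miniature. The NumPy sketch below uses a hypothetical two-layer classifier with random weights and random data standing in for CLIP ViT-B/32 and the real task pairs. It assumes (i) the weight disentanglement error $\xi(\alpha_1, \alpha_2)$ counts prediction disagreement between the merged model and each single-task model ${\bm{\theta}}_0 + \alpha_t \bm{\tau}_t$, following the usual definition in the task-arithmetic literature; (ii) the joint-task loss is the sum of per-task cross-entropies; and (iii) the CTL quantity compares hidden features of ${\bm{\theta}}_0 + \lambda\bm{\tau}_s + \lambda\bm{\tau}_t$ with the average of the features of ${\bm{\theta}}_0 + 2\lambda\bm{\tau}_s$ and ${\bm{\theta}}_0 + 2\lambda\bm{\tau}_t$. All names (`merge`, `disentanglement_error`, `joint_task_loss`, `ctl_gap`) are illustrative and are not the authors' code.

```python
import numpy as np

rng = np.random.default_rng(0)
D_IN, D_HID, N_CLS, N = 16, 32, 5, 200

def init_params(scale=0.5):
    # Hypothetical "pre-trained" two-layer network: hidden weights W, head V.
    return {"W": rng.normal(0, scale, (D_IN, D_HID)), "V": rng.normal(0, scale, (D_HID, N_CLS))}

def features(x, p):
    # Stand-in for the layer-l representation used by CTL (the single hidden layer here).
    return np.tanh(x @ p["W"])

def logits(x, p):
    return features(x, p) @ p["V"]

def cross_entropy(x, y, p):
    z = logits(x, p)
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    logp = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -logp[np.arange(len(y)), y].mean()

def merge(theta0, taus, alphas):
    """Task arithmetic: theta0 + sum_t alpha_t * tau_t, applied per parameter tensor."""
    return {k: theta0[k] + sum(a * tau[k] for a, tau in zip(alphas, taus)) for k in theta0}

# Hypothetical pre-trained weights, two "fine-tuned" models, task vectors, and task data.
theta0 = init_params()
thetas = [{k: v + rng.normal(0, 0.1, v.shape) for k, v in theta0.items()} for _ in range(2)]
taus = [{k: th[k] - theta0[k] for k in theta0} for th in thetas]
X = [rng.normal(size=(N, D_IN)) for _ in range(2)]
Y = [rng.integers(0, N_CLS, N) for _ in range(2)]

def disentanglement_error(a1, a2):
    """xi(a1, a2): prediction disagreement between the merged model and each
    single-task model theta0 + a_t * tau_t, summed over the two tasks."""
    merged = merge(theta0, taus, [a1, a2])
    err = 0.0
    for t, a in enumerate([a1, a2]):
        single = merge(theta0, [taus[t]], [a])
        err += np.mean(logits(X[t], merged).argmax(1) != logits(X[t], single).argmax(1))
    return err

def joint_task_loss(a1, a2):
    merged = merge(theta0, taus, [a1, a2])
    return sum(cross_entropy(X[t], Y[t], merged) for t in range(2))

def ctl_gap(lam):
    """Mean of 1 - cos between features of theta0 + lam*(tau_s + tau_t) and the
    average of features of theta0 + 2*lam*tau_s and theta0 + 2*lam*tau_t."""
    x = np.concatenate(X)
    merged = features(x, merge(theta0, taus, [lam, lam]))
    avg = 0.5 * (features(x, merge(theta0, [taus[0]], [2 * lam]))
                 + features(x, merge(theta0, [taus[1]], [2 * lam])))
    cos = (merged * avg).sum(1) / (np.linalg.norm(merged, axis=1) * np.linalg.norm(avg, axis=1))
    return np.mean(1.0 - cos)

# Coefficient sweep, analogous to the heatmaps in Figures 1 and 4.
alphas = np.linspace(0.0, 1.0, 11)
xi_grid = np.array([[disentanglement_error(a1, a2) for a2 in alphas] for a1 in alphas])
loss_grid = np.array([[joint_task_loss(a1, a2) for a2 in alphas] for a1 in alphas])
print(xi_grid.shape, loss_grid.shape, ctl_gap(0.3))
```

For the eight-task panels (Figures 2 and 5), one would pass all eight task vectors to `merge`, hold the remaining coefficients fixed, and sweep only $\alpha_1$ and $\alpha_2$. With real fine-tuned models, the figures suggest SAFT task vectors should yield a larger low-error region in `xi_grid` and a smaller `ctl_gap(0.3)` than SGD ones; this random toy setup does not demonstrate that effect.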

Theorems & Definitions (5)

  • Definition 1: Joint-task loss
  • Definition 2: Joint-task loss linearity
  • Theorem 1: SAFT induces joint-task loss linearity
  • Theorem 1: SAFT induces joint-task loss linearity
  • Proof