
Mitigating Training Imbalance in LLM Fine-Tuning via Selective Parameter Merging

Yiming Ju, Ziyi Ni, Xingrun Xing, Zhixiong Zeng, Hanyu Zhao, Siqi Fan, Zheng Zhang

TL;DR

The paper addresses training-order imbalance in supervised fine-tuning (SFT) of LLMs and shows that a sample's position in the first epoch can strongly affect its final training loss. It proposes merging multiple SFT models trained with different data orders, and introduces a parameter-selection merging strategy plus a resampling module. Across instruction-following, reasoning, and code-generation tasks on base models such as Llama-2-7b, merged models consistently outperform single SFT models, and resampling adds gains of about 2 percentage points on average. Merging can run on CPU and adds no inference cost. The authors present this as a practical way to make SFT more robust and point to larger models and multi-task merging as future work.
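To make the multi-order setup concrete, here is a minimal sketch of how the sub-model runs could differ. Each SFT run shuffles the same dataset with its own seed, so a given sample lands at a different first-epoch position in each run. This is the quantity the paper relates to final loss. The helpers `make_data_orders` and `first_epoch_positions` are hypothetical and not code from the paper. The per-run seeding scheme (`base_seed + k`) is also an assumption.

```python
import random


def make_data_orders(num_samples, num_models, base_seed=0):
    """Return one shuffled epoch order per SFT run (hypothetical helper).

    Each run uses its own seeded RNG, so the same sample is placed at a
    different first-epoch position in each run.
    """
    orders = []
    for k in range(num_models):
        rng = random.Random(base_seed + k)  # assumed seeding scheme: one seed per run
        order = list(range(num_samples))
        rng.shuffle(order)
        orders.append(order)
    return orders


def first_epoch_positions(order):
    """Map each sample index to its 0-based position in the first epoch."""
    positions = [0] * len(order)
    for pos, sample_idx in enumerate(order):
        positions[sample_idx] = pos
    return positions


# Three runs over an 8-sample toy dataset.
orders = make_data_orders(num_samples=8, num_models=3)
for k, order in enumerate(orders):
    print(k, order, first_epoch_positions(order))
```

Each order would then drive one independent SFT run, and the resulting models become the inputs to merging.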

Abstract

Supervised fine-tuning (SFT) is crucial for adapting Large Language Models (LLMs) to specific tasks. In this work, we demonstrate that the order of training data can lead to significant training imbalances, potentially resulting in performance degradation. Consequently, we propose to mitigate this imbalance by merging SFT models fine-tuned with different data orders, thereby enhancing the overall effectiveness of SFT. Additionally, we introduce a novel technique, "parameter-selection merging," which outperforms traditional weighted-average methods on five datasets. Further, through analysis and ablation studies, we validate the effectiveness of our method and identify the sources of performance improvements.
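The contrast between the two merging families can be sketched on flattened parameter vectors. Weighted averaging mixes every sub-model at each dimension. Parameter selection copies each dimension from exactly one sub-model. This is a minimal illustration under stated assumptions, not the paper's implementation. Parameters are plain Python lists for simplicity. The function names are hypothetical. The source sub-model for each dimension is assumed to be drawn uniformly at random.

```python
import random


def weighted_average_merge(sub_models, weights=None):
    """Baseline: each merged parameter is a weighted sum across all sub-models."""
    n = len(sub_models)
    if weights is None:
        weights = [1.0 / n] * n  # uniform averaging by default
    dim = len(sub_models[0])
    return [sum(w * m[i] for w, m in zip(weights, sub_models)) for i in range(dim)]


def parameter_selection_merge(sub_models, seed=0):
    """Parameter selection: each merged parameter is copied from ONE sub-model.

    Assumption: the source sub-model is drawn uniformly at random per dimension.
    Returns the merged vector and the chosen sub-model index per dimension.
    """
    rng = random.Random(seed)
    dim = len(sub_models[0])
    choices = [rng.randrange(len(sub_models)) for _ in range(dim)]
    merged = [sub_models[c][i] for i, c in enumerate(choices)]
    return merged, choices


# Three toy "SFT models", each a flattened 4-parameter vector.
models = [
    [0.1, 0.2, 0.3, 0.4],
    [0.5, 0.6, 0.7, 0.8],
    [0.9, 1.0, 1.1, 1.2],
]
print(weighted_average_merge(models))     # per-dimension averages
print(parameter_selection_merge(models))  # values copied from one model per dimension
```

The design difference is that averaging can produce values that no sub-model ever held. Selection keeps every merged parameter equal to a value some fine-tuned model actually learned.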

Paper Structure (18 sections, 3 equations, 4 figures, 6 tables)

Figures (4)

  • Figure 1: Impact of a training sample's position in the first epoch on its final loss (after 3 epochs of training). Panels (a) and (b) present the results on the GSM8k and Alpaca tasks, respectively. Panels (c) and (d) show the corresponding results from multiple experiments with different training orders.
  • Figure 2: Illustration comparing weighted-average method and the proposed parameter-selection method. Weighted-average merging calculates the weighted sum of all sub-model parameters at each parameter dimension, whereas parameter-selection merging selects parameters from a single sub-model. In the resampling module, parameters that equal those of the base model are replaced with parameters from alternative models.
  • Figure 3: Comparison of training losses across different models, with each sample's first-epoch position in the anchor model's training order on the x-axis. Green lines represent final training losses of the anchor model; blue 'x' markers indicate losses of SFT models trained with various data orders; red dots show losses of the merged model.
  • Figure 4: Comparison of validation loss between single and merged SFT models at various training steps.
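The resampling module described in the Figure 2 caption can be sketched as follows. The selection step picks one sub-model per dimension. If the picked value equals the base model's value, it is replaced with a value from another sub-model that did change that parameter. This is a hedged sketch, not the paper's code. The function name is hypothetical, and parameters are flattened Python lists. Uniform random selection is an assumption. So is the policy of trying alternatives in random order and keeping the base value when no sub-model differs.

```python
import random


def selection_merge_with_resampling(base, sub_models, seed=0):
    """Parameter-selection merge plus a resampling step (illustrative sketch).

    For each dimension a sub-model is chosen at random (assumed uniform).
    If the chosen value equals the base model's value, alternative sub-models
    are tried in random order and the first one whose value differs from the
    base is used. If every sub-model matches the base, the base value stays.
    """
    rng = random.Random(seed)
    n = len(sub_models)
    merged = []
    for i, base_val in enumerate(base):
        choice = rng.randrange(n)
        value = sub_models[choice][i]
        if value == base_val:
            # Resampling: look for another sub-model that changed this parameter.
            alternatives = [k for k in range(n) if k != choice]
            rng.shuffle(alternatives)
            for k in alternatives:
                if sub_models[k][i] != base_val:
                    value = sub_models[k][i]
                    break
        merged.append(value)
    return merged


base = [0.0, 0.0, 0.0]
models = [
    [0.0, 0.5, 0.0],  # left dimensions 0 and 2 at the base value
    [0.3, 0.0, 0.0],  # left dimensions 1 and 2 at the base value
]
print(selection_merge_with_resampling(base, models))  # → [0.3, 0.5, 0.0]
```

In this toy case the result does not depend on the seed. Dimensions 0 and 1 always recover the one fine-tuned value. Dimension 2 stays at the base value because neither sub-model changed it.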