Learning from the Undesirable: Robust Adaptation of Language Models without Forgetting

Yunhun Nam, Jaehyung Kim, Jongheon Jeong

TL;DR

Learning-from-the-Undesirable (LfU) tackles overfitting and forgetting when fine-tuning language models on limited data by enforcing representation-level consistency between the original model and an auxiliary model that has received an undesirable update. The auxiliary model is built by augmenting the original with either a LoRA adapter or a representation steering (RepS) component, taking one gradient ascent step on that component to push the model toward undesirable behavior, and then penalizing the divergence between the two models' internal representations across all layers with a mean-squared error loss. The resulting objective, $\ell_{\text{LfU}}(\boldsymbol{\theta}, \boldsymbol{\theta}_{\text{aux}}) = \ell_{\text{SFT}}(\boldsymbol{\theta}) + \lambda \cdot \ell_{\text{cons.}}(\boldsymbol{\theta}, \boldsymbol{\theta}_{\text{aux}})$, regularizes fine-tuning to preserve general capabilities while still allowing task specialization. Results in single-task and multi-task settings show that LfU improves task performance (e.g., a $16.8\%$ average improvement on math tasks over vanilla SFT) and increases robustness to prompt variations and adversarial fine-tuning, with RepS offering a lightweight, faster variant that remains competitive. Overall, LfU is a practical approach to robust LM adaptation that retains pretrained knowledge and improves generalization across diverse downstream tasks.
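To make the objective above concrete, here is a minimal pure-Python sketch of the representation steering variant on a toy model. It is an illustration under stated assumptions, not the authors' implementation: the elementwise `tanh` "layers", the squared-error stand-in for the SFT loss, the finite-difference gradient, the step size `lr`, and all function names are hypothetical choices made for this example. It also only evaluates the loss value; in actual training the loss would be differentiated with respect to the model parameters.

```python
import math

def forward(weights, x, steer=None):
    """Return the hidden representation of every layer (one vector per layer)."""
    hidden, h = [], list(x)
    for l, w in enumerate(weights):
        h = [math.tanh(w * v) for v in h]              # toy elementwise "layer" (assumption)
        if steer is not None:
            h = [v + s for v, s in zip(h, steer[l])]   # additive steering vector at layer l
        hidden.append(h)
    return hidden

def sft_loss(weights, x, y, steer=None):
    """Stand-in for the SFT loss: mean squared error of the last layer against y."""
    out = forward(weights, x, steer)[-1]
    return sum((o - t) ** 2 for o, t in zip(out, y)) / len(y)

def ascent_step(weights, x, y, lr=0.5, eps=1e-5):
    """One gradient-ascent step on zero-initialized steering vectors.

    Gradients are estimated with central finite differences to keep the
    sketch dependency-free; a real implementation would use autodiff.
    """
    dim = len(x)
    steer = [[0.0] * dim for _ in weights]
    grad = [[0.0] * dim for _ in weights]
    for l in range(len(weights)):
        for i in range(dim):
            steer[l][i] = eps
            up = sft_loss(weights, x, y, steer)
            steer[l][i] = -eps
            down = sft_loss(weights, x, y, steer)
            steer[l][i] = 0.0
            grad[l][i] = (up - down) / (2 * eps)
    # Ascent: move up the loss surface, i.e. toward undesirable behavior.
    return [[lr * g for g in row] for row in grad]

def consistency_loss(weights, x, steer):
    """MSE between original and auxiliary representations, averaged over layers."""
    clean = forward(weights, x)
    aux = forward(weights, x, steer)
    per_layer = [
        sum((a - b) ** 2 for a, b in zip(hc, ha)) / len(hc)
        for hc, ha in zip(clean, aux)
    ]
    return sum(per_layer) / len(per_layer)

def lfu_loss(weights, x, y, lam=1.0):
    """SFT loss plus lambda times the consistency loss."""
    steer = ascent_step(weights, x, y)
    return sft_loss(weights, x, y) + lam * consistency_loss(weights, x, steer)

if __name__ == "__main__":
    w, x, y = [0.8, 1.2], [0.5, -1.0], [1.0, 0.0]
    print("SFT loss:", sft_loss(w, x, y))
    print("LfU loss:", lfu_loss(w, x, y, lam=1.0))
```

Because the steering vectors start at zero, the auxiliary model initially coincides with the original one, and the single ascent step is the only source of divergence that the consistency term penalizes.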

Abstract

Language models (LMs) are often adapted through supervised fine-tuning (SFT) to specialize their capabilities for downstream tasks. However, in typical scenarios where the fine-tuning data is limited, e.g., compared to pre-training, SFT can lead LMs to overfit, causing them to rely on spurious patterns within the target task or to compromise other broadly useful capabilities as a side effect of narrow specialization. In this paper, we propose Learning-from-the-Undesirable (LfU), a simple yet effective regularization scheme for SFT to mitigate overfitting issues when fine-tuning LMs with limited data. Specifically, we aim to regularize the fine-tuning process to favor solutions that are resilient to "undesirable" model updates, e.g., gradient ascent steps that steer the model toward undesirable behaviors. To this end, we propose a novel form of consistency regularization that directly aligns internal representations of the model with those after an undesirable update. By leveraging representation-level data augmentation through undesirable updates, LfU effectively promotes generalization under limited data. Our experiments on diverse LM downstream tasks show that LfU serves as an effective prior that enhances adaptability while preserving pretrained knowledge. For example, our LM from LfU achieves a 16.8% average improvement on math tasks compared to vanilla SFT on the same dataset, where the latter even leads to degraded performance on those tasks. Furthermore, LfU exhibits improved robustness to prompt variations, e.g., yielding a 92.1% lower standard deviation in output performances compared to SFT, highlighting its versatile effects.

Paper Structure

This paper contains 52 sections, 7 equations, 6 figures, 19 tables, 2 algorithms.

Figures (6)

  • Figure 1: Illustration of forgetting in SFT vs. preservation in LfU (Ours): Fine-tuning via SFT on Task A causes the model to forget prior knowledge related to Task B. In contrast, LfU successfully learns Task A while preserving the prior knowledge about Task B.
  • Figure 2: Performance comparison between baselines and LfU fine-tuned on Llama-3.1-8B. (a) and (b) show results fine-tuned on GSM8k and evaluated on in-domain and out-of-domain data, respectively. These results indicate that while prior methods struggle to adapt to in-domain examples and lead to only marginal improvements on out-of-domain data, LfU consistently achieves the best performance in both cases. (c) presents results from fine-tuning on the multitask dataset Alpagasus Dolly 3k, with evaluation across all tasks. LfU achieves the best overall performance.
  • Figure 3: Attack Success Rate (ASR) on (a) HEx-PHI and (b) PureBad after a few steps of adversarial fine-tuning on BeaverTails (Ji et al., 2023). We first align a Llama-3.1-8B via SFT (or LfU) with the harmless subset of BeaverTails, and then continue fine-tuning the model on the harmful subset (of BeaverTails) using SFT.
  • Figure 4: Overview of LfU: LfU promotes stable internal representations by enforcing consistency between the internal representations of the original model $\boldsymbol{\theta}$ and those of an auxiliary model $\boldsymbol{\theta}_{\text{aux}}$ that has been updated for one step to induce undesirable behaviors. The auxiliary model $\boldsymbol{\theta}_{\text{aux}}$ is constructed by adding components to the original parameters $\boldsymbol{\theta}$, using either (1) a LoRA-based method, where trainable low-rank matrices are added to each layer, or (2) a representation-steering-based method, where a learnable steering vector is added to the internal representation at each layer. A gradient ascent step is then performed on the added components by computing the gradient of the SFT objective with respect to $\boldsymbol{\theta}_{\text{aux}}$. The consistency loss is defined as the Mean Squared Error (MSE) between the internal representations of the original and auxiliary models across all layers.
  • Figure 5: Performance distribution of GSM8k-fine-tuned Llama-3.1-8B models on five prompt variations of GSM8k. We use ChatGPT for generating the variations. LfU achieves the highest average accuracy and the lowest standard deviation, demonstrating strong resilience to prompt variations.
  • ...and 1 more figure
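
The procedure described in the Figure 4 caption can be written out roughly as follows. This is a sketch rather than the paper's exact formulation: the symbol $\boldsymbol{\phi}$ for the added components (LoRA matrices or steering vectors), their initial value $\boldsymbol{\phi}_0$, the step size $\eta$, the layer-$l$ representation $h_l(\cdot)$, and the $1/L$ averaging over $L$ layers are notational assumptions introduced here.

$$
\begin{aligned}
\boldsymbol{\theta}_{\text{aux}} &= (\boldsymbol{\theta}, \boldsymbol{\phi}), \qquad \boldsymbol{\phi} = \boldsymbol{\phi}_0 + \eta \, \nabla_{\boldsymbol{\phi}} \, \ell_{\text{SFT}}(\boldsymbol{\theta}, \boldsymbol{\phi}) \Big|_{\boldsymbol{\phi} = \boldsymbol{\phi}_0}, \\
\ell_{\text{cons.}}(\boldsymbol{\theta}, \boldsymbol{\theta}_{\text{aux}}) &= \frac{1}{L} \sum_{l=1}^{L} \operatorname{MSE}\big( h_l(\boldsymbol{\theta}), \, h_l(\boldsymbol{\theta}_{\text{aux}}) \big).
\end{aligned}
$$

The plus sign in the update is what makes it an ascent step: the added components move in the direction that increases the SFT loss, and the consistency term then asks the original model's representations to stay close to those of this perturbed model.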