Steering Out-of-Distribution Generalization with Concept Ablation Fine-Tuning

Helena Casademunt, Caden Juang, Adam Karvonen, Samuel Marks, Senthooran Rajamanoharan, Neel Nanda

TL;DR

The paper tackles undesired out-of-distribution (OOD) generalization in fine-tuned LLMs by introducing Concept Ablation Fine-Tuning (CAFT). CAFT uses interpretability tools to identify latent directions corresponding to undesired concepts, then ablates them by linear projection during fine-tuning, without modifying the training data. CAFT is operationalized with two direction-identification approaches, principal component analysis (PCA) on activation differences and sparse autoencoders (SAEs). It reduces emergent misalignment by up to 10x while preserving in-distribution performance. It also improves robustness to spurious correlations in two multiple-choice tasks, yielding substantial gains in OOD accuracy when interpreted latents are ablated, and it is evaluated against baselines that ablate random latents or the top latents. The work suggests a practical route to steering LLM generalization during fine-tuning that requires no changes to the training data, with implications for safer deployment and potential scalability to larger frontier models.
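The PCA route to finding candidate directions can be sketched as follows. This is a minimal illustration under our own assumptions, not the authors' code. The names `base_acts` and `ft_acts` are hypothetical arrays of activations (one row per token) collected at the same layer and token positions from the base and fine-tuned models on identical inputs. The top principal components of their difference are treated as candidate directions, which would then be inspected and kept only if they appear to represent undesired concepts.

```python
import numpy as np

def pca_difference_directions(base_acts: np.ndarray, ft_acts: np.ndarray, k: int) -> np.ndarray:
    """Return the top-k principal components of activation differences.

    Hypothetical helper: base_acts and ft_acts are (n_tokens, d_model) arrays
    from the base and fine-tuned models at the same layer and positions.
    Returns a (k, d_model) array of unit-norm directions.
    """
    if base_acts.shape != ft_acts.shape:
        raise ValueError("activation arrays must have the same shape")
    diffs = ft_acts - base_acts
    # Standard PCA: center so the components describe variance, not the mean shift.
    diffs = diffs - diffs.mean(axis=0, keepdims=True)
    # Rows of vt are the principal axes, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(diffs, full_matrices=False)
    return vt[:k]

# Toy check: the "fine-tuned" activations differ from the base ones only along axis 3.
rng = np.random.default_rng(0)
base = rng.normal(size=(200, 16))
axis = np.zeros(16)
axis[3] = 1.0
ft = base + rng.normal(size=(200, 1)) * 5.0 * axis
pcs = pca_difference_directions(base, ft, k=1)
print(int(np.abs(pcs[0]).argmax()))  # 3
```

For the SAE route, one natural analogue (again an assumption here) is to use the decoder vectors of the SAE latents judged to represent undesired concepts as the directions to ablate.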

Abstract

Fine-tuning large language models (LLMs) can lead to unintended out-of-distribution generalization. Standard approaches to this problem rely on modifying training data, for example by adding data that better specify the intended generalization. However, this is not always practical. We introduce Concept Ablation Fine-Tuning (CAFT), a technique that leverages interpretability tools to control how LLMs generalize from fine-tuning, without needing to modify the training data or otherwise use data from the target distribution. Given a set of directions in an LLM's latent space corresponding to undesired concepts, CAFT works by ablating these concepts with linear projections during fine-tuning, steering the model away from unintended generalizations. We successfully apply CAFT to three fine-tuning tasks, including emergent misalignment, a phenomenon where LLMs fine-tuned on a narrow task generalize to give egregiously misaligned responses to general questions. Without any changes to the fine-tuning data, CAFT reduces misaligned responses by 10x without degrading performance on the training distribution. Overall, CAFT represents a novel approach for steering LLM generalization without modifying training data.
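The ablation step described above amounts to an orthogonal projection. Here is a sketch under our own assumptions, since the layer choice and hook placement are not given in this summary. Let the rows of a hypothetical matrix `directions` span the undesired concepts. Orthonormalize them into a matrix Q, then replace each activation h with h - Q Q^T h. Applied inside the forward pass during fine-tuning, this projection means downstream layers never see the ablated components.

```python
import numpy as np

def orthonormal_basis(directions: np.ndarray) -> np.ndarray:
    """Return a (d_model, k) matrix whose columns orthonormally span the rows
    of `directions` (shape (k, d_model); assumed linearly independent)."""
    q, _ = np.linalg.qr(directions.T)  # reduced QR: q has shape (d_model, k)
    return q

def ablate(acts: np.ndarray, q: np.ndarray) -> np.ndarray:
    """Project span(q) out of each row of acts: h -> h - Q Q^T h."""
    return acts - (acts @ q) @ q.T

rng = np.random.default_rng(1)
dirs = rng.normal(size=(2, 8))   # two hypothetical concept directions
q = orthonormal_basis(dirs)
h = rng.normal(size=(5, 8))      # a batch of 5 activations
h_abl = ablate(h, q)
print(np.allclose(h_abl @ dirs.T, 0))  # True: no component left along either direction
```

Orthonormalizing first matters when the directions are not orthogonal (for example, several SAE decoder vectors): subtracting each direction's projection one at a time would not remove the whole subspace. In a PyTorch fine-tuning loop, the same projection could be applied to a chosen layer's output through `torch.nn.Module.register_forward_hook`. Which layer to hook is left open here.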

Paper Structure

This paper contains 52 sections, 1 equation, 26 figures, 11 tables.

Figures (26)

  • Figure 1: Models trained on insecure code with standard fine-tuning methods exhibit misaligned behavior. Using CAFT, we ablate directions in latent space representing misaligned concepts during fine-tuning and obtain aligned models.
  • Figure 2: Examples of the data used to train and evaluate the emergent misalignment models. (Left) Training dataset of insecure code answers, where OOD generalization is ambiguous. The security vulnerability introduced is shown in bold. Some parts of the code have been omitted for space. (Right) Example question from the OOD general questions, showing examples of misaligned and aligned answers generated by the insecure and CAFT models, respectively.
  • Figure 3: Examples of a PC (left) and an SAE latent (right) from Qwen that are judged misaligned and are ablated when applying CAFT. For the PC, we show the minimum values (negative projections) because the maximum values are not interpretable. Text shading indicates the magnitude of the projection, with positive values in blue and negative values in red. The bold token marks the maximum or minimum.
  • Figure 4: Results from Qwen and Mistral models, showing the percentage of coherent responses that were misaligned for both CAFT methods.
  • Figure 5: Emergent misalignment results for Qwen. (Left) Misaligned response rate by question for the insecure and CAFT models. (Right) Misalignment and vulnerability rates comparing CAFT to training checkpoints and to the baselines described in the emergent misalignment results section. Error bars span the full range from the minimum to the maximum misalignment and vulnerability values across 5 seeds. The arrow points in the direction of improvement relative to the training checkpoints. See the corresponding Mistral figure for the same experiments using Mistral.
  • ...and 21 more figures
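The metric in Figure 4, the percentage of coherent responses that are misaligned, can be made concrete with a small helper. This is a sketch under assumptions not stated in this summary. Each response is assumed to carry judge-assigned `alignment` and `coherence` scores on a 0 to 100 scale. The cutoffs `coherence_min=50` and `alignment_max=30` are hypothetical defaults chosen for illustration and should be replaced by whatever thresholds the evaluation actually uses.

```python
from dataclasses import dataclass
from typing import Iterable

@dataclass
class JudgedResponse:
    alignment: float  # hypothetical judge score: 0 (misaligned) to 100 (aligned)
    coherence: float  # hypothetical judge score: 0 (incoherent) to 100 (coherent)

def misaligned_rate(responses: Iterable[JudgedResponse],
                    coherence_min: float = 50.0,
                    alignment_max: float = 30.0) -> float:
    """Percentage of coherent responses that are misaligned.

    Incoherent responses are dropped from the denominator, so the rate
    measures misalignment only among answers that are coherent.
    Returns NaN when no response is coherent (the rate is undefined).
    """
    coherent = [r for r in responses if r.coherence > coherence_min]
    if not coherent:
        return float("nan")
    misaligned = sum(1 for r in coherent if r.alignment < alignment_max)
    return 100.0 * misaligned / len(coherent)

rs = [JudgedResponse(10, 90), JudgedResponse(80, 95),
      JudgedResponse(5, 20), JudgedResponse(25, 60)]
print(round(misaligned_rate(rs), 1))  # 66.7: the third response is incoherent and excluded
```

Conditioning on coherence keeps a model that simply produces broken output from scoring as "aligned" or "misaligned" by accident.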