Safety Alignment as Continual Learning: Mitigating the Alignment Tax via Orthogonal Gradient Projection

Guanglong Sun, Siyuan Zhang, Liyuan Wang, Jun Zhu, Hang Su, Yi Zhong

TL;DR

This work proposes Orthogonal Gradient Projection for Safety Alignment (OGPSA), a lightweight method that mitigates interference by constraining each safety update to be orthogonal to a learned subspace capturing general capabilities.

Abstract

Large Language Models (LLMs) often incur an alignment tax: safety post-training can reduce general utility (e.g., reasoning and coding). We argue that this tax primarily arises from continual-learning-style forgetting in sequential alignment, where distribution shift and conflicting objectives cause safety updates to overwrite pre-trained competencies. Accordingly, we cast safety alignment as a continual learning (CL) problem that must balance plasticity (acquiring safety constraints) and stability (preserving general abilities). We propose Orthogonal Gradient Projection for Safety Alignment (OGPSA), a lightweight method that mitigates interference by constraining each safety update to be orthogonal (in a first-order sense) to a learned subspace capturing general capabilities. Specifically, OGPSA estimates a low-rank capability subspace from gradients on a small reference set and projects the safety gradient onto its orthogonal complement before updating. This produces safety-directed updates that minimally perturb prior knowledge while retaining capacity for alignment. OGPSA is plug-and-play and integrates into standard post-training pipelines without large-scale replay, auxiliary objectives, or retraining. Across Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and sequential SFT→DPO settings, OGPSA consistently improves the safety-utility Pareto frontier over standard baselines. For instance, on Qwen2.5-7B-Instruct under SFT→DPO, OGPSA preserves strong safety while recovering general capability, improving SimpleQA from 0.53% to 3.03% and IFEval from 51.94% to 63.96%. Our source code is available at https://github.com/SunGL001/OGPSA.
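The two steps described in the abstract (estimate a low-rank capability subspace from reference gradients, then project the safety gradient onto its orthogonal complement) can be sketched in NumPy as below. This is a minimal illustration under stated assumptions, not the authors' implementation: gradients are assumed to be flattened into vectors, the "low-rank subspace" is assumed to be the top-`rank` right singular vectors of the stacked reference gradients (a truncated SVD), the update is plain gradient descent, and the function names `capability_basis`, `project_out`, and `projected_sgd_step` are hypothetical.

```python
import numpy as np


def capability_basis(ref_grads: np.ndarray, rank: int, tol: float = 1e-10) -> np.ndarray:
    """Return an orthonormal basis (d x k) for a low-rank capability subspace.

    ref_grads: (M, d) array, one flattened reference gradient g^(i) per row.
    Assumption (not from the paper): the subspace is spanned by the top-`rank`
    right singular vectors of the stacked reference gradients.
    """
    _, s, vt = np.linalg.svd(ref_grads, full_matrices=False)
    if s.size == 0 or s[0] == 0.0:
        return np.zeros((ref_grads.shape[1], 0))  # no usable directions
    # Keep at most `rank` directions and drop numerically zero ones.
    k = min(rank, int(np.sum(s > tol * s[0])))
    return vt[:k].T


def project_out(g_safe: np.ndarray, basis: np.ndarray) -> np.ndarray:
    """Project g_safe onto the orthogonal complement of span(basis)."""
    return g_safe - basis @ (basis.T @ g_safe)


def projected_sgd_step(theta: np.ndarray, g_safe: np.ndarray, basis: np.ndarray,
                       lr: float = 1e-2) -> np.ndarray:
    """One plain gradient-descent step along the projected safety gradient."""
    return theta - lr * project_out(g_safe, basis)


# Toy usage with random stand-ins for flattened gradients (not real model data).
rng = np.random.default_rng(0)
d, M = 50, 8
ref_grads = rng.normal(size=(M, d))
U = capability_basis(ref_grads, rank=4)
g_safe = rng.normal(size=d)
g_tilde = project_out(g_safe, U)

print(U.shape)                          # (50, 4)
print(np.allclose(U.T @ g_tilde, 0.0))  # True: no component left in the subspace
print(float(g_safe @ g_tilde) >= 0.0)   # True: -g_tilde does not increase L_safe to first order
```

With `rank=4`, the projected gradient is orthogonal only to the four retained directions, not to every one of the eight reference gradients; a larger rank presumably preserves more general capability at the cost of less room for safety updates. Applying this to a real model would additionally require collecting and flattening per-parameter gradients, which the sketch does not cover.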

Paper Structure (36 sections, 2 theorems, 22 equations, 4 figures, 8 tables, 1 algorithm)

Key Result

Proposition 4.1

Let $f(\theta)=\mathcal{L}_{\mathrm{safe}}(\theta)$ with gradient $g=\nabla f(\theta)$, and let $\mathcal{S}_{\mathrm{gen}}(\theta)=\mathrm{span}\{g^{(i)}(\theta)\}_{i=1}^M$. Among all unit vectors $v$ satisfying $\langle g^{(i)}(\theta), v\rangle=0$ for all $i$, the maximally descending direction is …
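The steepest-feasible-descent claim of Proposition 4.1 rests on a short Cauchy-Schwarz argument. The derivation below is a sketch of the standard result, assuming the proposition concludes with the normalized negative projected gradient; the matrix $U$, the projector $P_\perp$, the dimension $d$, and the step size $\eta$ are notation introduced here for illustration.

```latex
\begin{aligned}
&\text{Let } U \in \mathbb{R}^{d\times k} \text{ have orthonormal columns spanning } \mathcal{S}_{\mathrm{gen}}(\theta),
\qquad P_\perp = I - UU^\top. \\
&\text{Feasibility: } \langle g^{(i)}(\theta), v\rangle = 0 \;\; \forall i
\iff U^\top v = 0 \iff v = P_\perp v. \\
&\text{For any feasible unit } v:\quad
\langle g, v\rangle = \langle g, P_\perp v\rangle = \langle P_\perp g, v\rangle
\ge -\lVert P_\perp g\rVert \,\lVert v\rVert = -\lVert P_\perp g\rVert, \\
&\text{using } P_\perp^\top = P_\perp \text{ and Cauchy-Schwarz. If } P_\perp g \neq 0, \text{ equality holds exactly for} \\
&\qquad v^\star = -\frac{P_\perp g}{\lVert P_\perp g\rVert},
\qquad \text{which is feasible because } U^\top P_\perp g = 0. \\
&\text{Hence, with } \tilde g_{\text{safe}} = P_\perp g, \text{ the update } \Delta\theta = -\eta\,\tilde g_{\text{safe}}
\text{ satisfies } \langle g^{(i)}(\theta), \Delta\theta\rangle = 0 \text{ for every } i.
\end{aligned}
```

If $P_\perp g = 0$, every feasible direction leaves $f$ unchanged to first order, so no feasible descent direction exists at $\theta$.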

Figures (4)

  • Figure 1: Conceptual framework for reframing LLM Safety Alignment as a Constrained Continual Learning Problem. (A) Comparison of traditional CL and LLM Heterogeneous CL. (B) Safety alignment under anti-forgetting constraints.
  • Figure 2: Overall performance of alignment strategies on Qwen2.5-7B-Instruct. We report the aggregate Safety Score (avg. of 4 datasets) and General Capacity Score (avg. of 6 datasets); see the main results table (tab:main) for details and the appendix figure (fig:llama_results) for the Llama3.1-8B-Instruct results.
  • Figure 3: Schematic illustration of the proposed Orthogonal Gradient Projection for Safety Alignment (OGPSA) framework. $g_{\text{ref1}}, g_{\text{ref2}}$: reference gradients computed from representative general-capability datasets (e.g., helpfulness, truthfulness). $g_{\text{safe}}$: the standard gradient of the safety alignment objective. $\tilde{g}_{\text{safe}}$: the projected safety gradient, obtained by projecting $g_{\text{safe}}$ onto the orthogonal complement of the general capability subspace.
  • Figure 4: Overall performance of alignment strategies on Llama3.1-8B-Instruct. We report the aggregate Safety Score (avg. of 4 datasets) and General Capacity Score (avg. of 6 datasets); see the main results table (tab:main) for details.
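The projection pictured in Figure 3 can be reproduced on a toy 3-D example. The vectors below are invented for illustration and are not from the paper; an orthonormal basis of span{g_ref1, g_ref2} is built with NumPy's reduced QR factorization, and the first-order change of each reference loss under a step is approximated by the inner product of its gradient with the step (a first-order Taylor approximation, so second-order effects are ignored).

```python
import numpy as np

# Toy 3-D version of Figure 3; all vectors are made-up illustrative values.
g_ref1 = np.array([1.0, 1.0, 0.0])
g_ref2 = np.array([1.0, -1.0, 0.0])
g_safe = np.array([-2.0, 1.0, 5.0])

# Reduced QR gives an orthonormal basis Q (3 x 2) for span{g_ref1, g_ref2}.
Q, _ = np.linalg.qr(np.stack([g_ref1, g_ref2], axis=1))
g_tilde = g_safe - Q @ (Q.T @ g_safe)  # projected safety gradient

lr = 0.01
for name, g in (("unprojected", g_safe), ("projected", g_tilde)):
    step = -lr * g
    # First-order change of each reference loss: <g_ref, step>.
    deltas = [float(r @ step) for r in (g_ref1, g_ref2)]
    print(name, np.round(deltas, 6))
```

Here the two reference gradients span the first two coordinates, so projection keeps only the third component of `g_safe` (giving `[0, 0, 5]` up to floating-point error). The unprojected step has first-order changes of 0.01 and 0.03, meaning both reference losses would rise, while the projected step's first-order changes vanish up to rounding.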

Theorems & Definitions (4)

  • Proposition 4.1: Steepest Feasible Descent
  • proof
  • Lemma 2.1: Optimal Descent in Subspace
  • proof