Sharpness-Aware Minimization in Logit Space Efficiently Enhances Direct Preference Optimization

Haocheng Luo, Zehang Deng, Thanh-Toan Do, Mehrtash Harandi, Dinh Phung, Trung Le

Abstract

Direct Preference Optimization (DPO) has emerged as a popular algorithm for aligning pretrained large language models with human preferences, owing to its simplicity and training stability. However, DPO suffers from the recently identified squeezing effect (also known as likelihood displacement), where the probability of preferred responses decreases unintentionally during training. To understand and mitigate this phenomenon, we develop a theoretical framework that models the coordinate-wise dynamics in logit space. Our analysis reveals that negative-gradient updates cause residuals to expand rapidly along high-curvature directions, which underlies the squeezing effect, whereas Sharpness-Aware Minimization (SAM) can suppress this behavior through its curvature-regularization effect. Building on this insight, we investigate logits-SAM, a computationally efficient variant that perturbs only the output layer with negligible overhead. Extensive experiments on Pythia-2.8B, Mistral-7B, and Gemma-2B-IT across multiple datasets and benchmarks demonstrate that logits-SAM consistently improves the effectiveness of DPO and integrates seamlessly with other DPO variants. Code is available at https://github.com/RitianLuo/logits-sam-dpo.
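To make "perturbs only the output layer" concrete, here is a minimal NumPy sketch built on our own assumptions. It uses a toy policy $z = W\tanh(Ax)$ whose responses are single tokens, and a single-example DPO loss measured against a frozen reference (the initial policy). The standard normalized SAM ascent step $\rho\, g/\lVert g\rVert$ is applied to the output-layer weights $W$ alone, while the backbone weights $A$ stay unperturbed. The sizes, hyperparameters, and the helpers `dpo_loss_and_grads` and `logits_sam_step` are hypothetical. The sketch only shows where the perturbation sits; it is not the authors' implementation, and it does not reproduce how they keep the overhead negligible (the linked repository is the reference for that).

```python
import numpy as np

# Hypothetical toy sizes and hyperparameters; not taken from the paper.
V, D_IN, D_H = 5, 4, 3
BETA, RHO, LR = 0.1, 0.05, 0.5


def log_softmax(z):
    z = z - z.max()
    return z - np.log(np.exp(z).sum())


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def dpo_loss_and_grads(A, W, x, y_w, y_l, ref_w, ref_l):
    """Single-token DPO loss for the toy policy z = W tanh(A x), with manual backprop.

    ref_w and ref_l are frozen reference log-probabilities of y_w and y_l.
    """
    h = np.tanh(A @ x)                  # hidden features (the "backbone")
    z = W @ h                           # logits produced by the output layer W
    lp = log_softmax(z)
    margin = BETA * ((lp[y_w] - ref_w) - (lp[y_l] - ref_l))
    loss = np.logaddexp(0.0, -margin)   # equals -log sigmoid(margin)
    # dloss/dz = -BETA * sigmoid(-margin) * (e_{y_w} - e_{y_l}); the softmax terms cancel.
    dz = np.zeros(V)
    dz[y_w] -= BETA * sigmoid(-margin)
    dz[y_l] += BETA * sigmoid(-margin)
    gW = np.outer(dz, h)
    gA = np.outer((W.T @ dz) * (1.0 - h ** 2), x)
    return loss, gA, gW


def logits_sam_step(A, W, example):
    """One SAM-style step whose ascent perturbation touches only the output layer W."""
    _, _, gW = dpo_loss_and_grads(A, W, *example)
    eps = RHO * gW / (np.linalg.norm(gW) + 1e-12)   # backbone A is left unperturbed
    _, gA, gW_sam = dpo_loss_and_grads(A, W + eps, *example)
    return A - LR * gA, W - LR * gW_sam


rng = np.random.default_rng(0)
A = rng.normal(size=(D_H, D_IN))
W = rng.normal(size=(V, D_H))
x = rng.normal(size=D_IN)
y_w, y_l = 1, 3

# Use the initial policy as the frozen reference model.
ref_lp = log_softmax(W @ np.tanh(A @ x))
example = (x, y_w, y_l, ref_lp[y_w], ref_lp[y_l])

loss0 = dpo_loss_and_grads(A, W, *example)[0]
for _ in range(200):
    A, W = logits_sam_step(A, W, example)
loss_final = dpo_loss_and_grads(A, W, *example)[0]
print(loss0, loss_final)
```

Setting `RHO = 0` makes `eps` vanish, so the step reduces to plain gradient descent, which gives a convenient baseline for comparison.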

Paper Structure (54 sections, 14 theorems, 97 equations, 3 figures, 7 tables, 1 algorithm)

Key Result

Proposition 3.1

In coordinates, $\bm{H}_{\bm{W}} = \bigl(\phi\phi^\top\bigr) \otimes \bm{H}_{\bm{z}}$. Thus, if $\phi \neq \bm{0}$, then $\operatorname{rank}(\bm{H}_{\bm{W}}) = \operatorname{rank}(\bm{H}_{\bm{z}})$. Moreover, the second-order effect of any parameter perturbation depends only on the indu…
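The Kronecker form can be checked numerically. The sketch below assumes an output layer $z = W\phi$ with $W \in \mathbb{R}^{V\times d}$ and a fixed feature vector $\phi$. It uses softmax cross-entropy as a stand-in loss, for which $\bm{H}_{\bm{z}} = \operatorname{diag}(p) - pp^\top$. It vectorizes $W$ column by column, the convention under which $\operatorname{vec}(W\phi) = (\phi^\top \otimes I)\operatorname{vec}(W)$, so that $\bm{H}_{\bm{W}} = J^\top \bm{H}_{\bm{z}} J = (\phi\phi^\top)\otimes \bm{H}_{\bm{z}}$. Because $z$ is linear in $W$, the identity does not depend on which loss is used. All names, sizes, and values are illustrative.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def grad_W(W, phi, y):
    """Gradient of the cross-entropy -log softmax(W @ phi)[y] w.r.t. W: (p - e_y) phi^T."""
    p = softmax(W @ phi)
    p[y] -= 1.0
    return np.outer(p, phi)


def vec(M):
    """Column-major vectorization, under which vec(W @ phi) = (phi^T kron I) vec(W)."""
    return M.reshape(-1, order="F")


def unvec(v, shape):
    return v.reshape(shape, order="F")


rng = np.random.default_rng(0)
V, d, y = 4, 3, 1
W = 0.3 * rng.normal(size=(V, d))
phi = np.array([1.0, -2.0, 0.5])

# Logit-space Hessian of softmax cross-entropy.
p = softmax(W @ phi)
H_z = np.diag(p) - np.outer(p, p)

# Parameter-space Hessian via central finite differences of the analytic gradient.
n = V * d
H_W = np.zeros((n, n))
step = 1e-5
w0 = vec(W)
for k in range(n):
    e = np.zeros(n)
    e[k] = step
    g_plus = vec(grad_W(unvec(w0 + e, (V, d)), phi, y))
    g_minus = vec(grad_W(unvec(w0 - e, (V, d)), phi, y))
    H_W[:, k] = (g_plus - g_minus) / (2 * step)

predicted = np.kron(np.outer(phi, phi), H_z)
print(np.allclose(H_W, predicted, atol=1e-6))
print(np.linalg.matrix_rank(H_W, tol=1e-6), np.linalg.matrix_rank(H_z, tol=1e-6))
```

With row-major vectorization the factors swap order, giving $\bm{H}_{\bm{z}} \otimes \phi\phi^\top$; the rank is the same. For softmax cross-entropy both ranks equal $V-1$, because $\bm{H}_{\bm{z}}\mathbf{1} = 0$.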

Figures (3)

  • Figure 1: Training dynamics under different settings. (a–b) 1000-dimensional toy example with three classes, trained with a negative learning rate under GD, SAM ($\rho>0$), and SAM ($\rho<0$). Panel (a) shows the modal coefficients, and panel (b) shows the class residuals. (c) Real-data experiment on WebGPT Comparisons with GPT-2, comparing GD and SAM: the panel reports the log-probabilities of the chosen responses, the rejected responses, and max_conf, which denotes the model's most confident response. (d) Real-data experiment on the TL;DR dataset with Pythia-2.8B, showing the same three curves (chosen, rejected, and max_conf).
  • Figure 2: Efficiency comparison.
  • Figure 3: Learning dynamics of Mistral-7B on UltraFeedback. We compare AdamW and logits-SAM in terms of training loss, evaluation loss, and evaluation accuracy, and report curves for logits-SAM under different values of $\rho$.
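The squeezing mechanism that the chosen, rejected, and max_conf curves are meant to expose can be illustrated in logit space with a few lines. This is our own toy, not the paper's 1000-dimensional setup. The logits are parameterized directly, and we repeatedly lower the log-probability of a rejected class. Since $\partial \log p_y / \partial \bm{z} = \bm{e}_y - \bm{p}$, every other logit $z_k$ rises by $\eta p_k$ per step, so the most confident class gains the most and even the chosen class loses probability mass. The logits, labels, and step size are hypothetical, and the positive gradient that DPO also applies to the chosen response is omitted.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def negative_gradient_step(z, y_rej, lr):
    """Gradient descent on log p(y_rej) with directly parameterized logits z.

    Since d log p(y)/dz = e_y - p, each non-rejected logit z_k rises by lr * p_k,
    so the already most likely class gains the most.
    """
    p = softmax(z)
    grad = -p
    grad[y_rej] += 1.0
    return z - lr * grad


z = np.array([2.0, 1.0, 0.0, -1.0])   # hypothetical logits; class 0 is max_conf
y_chosen, y_rejected = 1, 3            # hypothetical chosen and rejected labels
history = [softmax(z)]
for _ in range(20):
    z = negative_gradient_step(z, y_rejected, lr=1.0)
    history.append(softmax(z))
history = np.array(history)

print("max_conf:", history[0, 0].round(3), "->", history[-1, 0].round(3))
print("chosen:  ", history[0, y_chosen].round(3), "->", history[-1, y_chosen].round(3))
print("rejected:", history[0, y_rejected].round(3), "->", history[-1, y_rejected].round(3))
```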

Theorems & Definitions (23)

  • Proposition 3.1: Geometry of the logit space; simplified version of Proposition B.1
  • Theorem 3.2: SAM dynamics in parameter and logit space; informal version of Theorem B.2
  • Proposition 3.3
  • Corollary 3.4: Modal dynamics in the eigenbasis of $\bm{H}_{\bm{z}}^{t}$
  • Lemma 3.5: One-step confidence ratios under GD, Lemma 1 of ren2024learning
  • Corollary 3.6: One-step confidence ratios under SAM; informal version of the appendix corollaries labeled app:coro prob and app:coro-y
  • Proposition B.1: Geometry of the logit space and the parameter-logit correspondence
  • proof
  • Theorem B.2: Dynamics of SAM
  • proof
  • ...and 13 more
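The abstract's claim that negative-gradient updates make residuals grow fastest along high-curvature directions can be seen in a simplified setting related to the modal view of Corollary 3.4. The sketch assumes a fixed quadratic loss in logit space with a hypothetical curvature matrix $\bm{H}$, which stands in for $\bm{H}_{\bm{z}}$ (in the paper $\bm{H}_{\bm{z}}^{t}$ may change with $t$). With a negative learning rate the residual evolves as $\bm{r} \leftarrow (\bm{I} + \eta\bm{H})\bm{r}$, so in the eigenbasis each modal coefficient is multiplied by $1 + \eta\lambda_i$ per step. It covers GD only and does not reproduce the paper's SAM dynamics.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical PSD curvature with known eigenvalues, standing in for H_z.
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))   # orthonormal eigenbasis
lams = np.array([0.1, 1.0, 5.0])               # low, medium, and high curvature
H = Q @ np.diag(lams) @ Q.T

eta, steps = 0.1, 20
r = Q @ np.ones(3)          # residual with unit modal coefficient in every eigendirection
for _ in range(steps):
    r = r + eta * (H @ r)   # negative learning rate: r <- (I + eta * H) r

modal = Q.T @ r             # modal coefficients in the eigenbasis of H
print(modal)
print((1 + eta * lams) ** steps)
```

After 20 steps the high-curvature mode has grown by $1.5^{20}$, while the low-curvature mode has barely moved, which is the anisotropic expansion the analysis attributes to negative-gradient updates.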