Revisiting Weak-to-Strong Generalization in Theory and Practice: Reverse KL vs. Forward KL

Wei Yao, Wenkai Yang, Ziqiao Wang, Yankai Lin, Yong Liu

TL;DR

This work addresses weak-to-strong generalization (W2SG) in AI alignment by replacing the forward KL divergence and standard cross-entropy (CE) losses with reverse KL and reverse CE. It derives universal and tightened information-theoretic bounds. These bounds show that the reverse losses offer guarantees at least comparable to the forward ones. Under certain conditions, notably when only the last linear layer of a sufficiently pre-trained strong model is fine-tuned, the bounds also show that the strong model can outperform its weak supervisor. Empirically, reverse KL and reverse CE generally outperform forward KL and standard CE across multiple model scales and alignment tasks (CAI-Harmless and HH-RLHF). They are also more resilient to label noise, and regularization brings additional gains. The results have practical implications for robustly leveraging weaker supervision to train stronger, safer models in high-stakes settings. The authors also acknowledge potential overconfidence under extreme noise and the need for careful regularization and model design.
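To make the four losses concrete, here is a minimal sketch for a single example with a discrete output distribution. It assumes the common convention that the weak teacher's prediction $F_w$ is the target in the forward direction. Under that convention, forward KL is $\mathrm{KL}(F_w, F_{sw})$ and reverse KL is $\mathrm{KL}(F_{sw}, F_w)$, matching the two choices of $d(F_w, F_{sw})$ in Lemma 1. The function names, the toy probabilities, and the `eps` floor are illustrative assumptions, not the authors' code.

```python
import math
from typing import Sequence

Dist = Sequence[float]  # probabilities over the output domain, summing to 1


def kl(p: Dist, q: Dist, eps: float = 1e-12) -> float:
    """KL(p, q) = sum_y p(y) * log(p(y) / q(y)), with the convention 0 * log 0 = 0."""
    return sum(pi * (math.log(pi) - math.log(max(qi, eps))) for pi, qi in zip(p, q) if pi > 0)


def ce(p: Dist, q: Dist, eps: float = 1e-12) -> float:
    """CE(p, q) = -sum_y p(y) * log q(y)."""
    return -sum(pi * math.log(max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)


def entropy(p: Dist) -> float:
    """Shannon entropy H(p) in nats."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


# Assumed convention: the weak teacher's prediction is the target in the forward direction.
def forward_kl_loss(weak: Dist, strong: Dist) -> float:
    return kl(weak, strong)    # KL(F_w, F_sw): mass-covering


def reverse_kl_loss(weak: Dist, strong: Dist) -> float:
    return kl(strong, weak)    # KL(F_sw, F_w): zero-forcing / mode-seeking


def ce_loss(weak: Dist, strong: Dist) -> float:
    return ce(weak, strong)    # CE(F_w, F_sw) = H(F_w) + KL(F_w, F_sw)


def reverse_ce_loss(weak: Dist, strong: Dist) -> float:
    return ce(strong, weak)    # CE(F_sw, F_w) = H(F_sw) + KL(F_sw, F_w)


weak = [0.6, 0.4]    # toy soft label from a weak teacher
strong = [0.9, 0.1]  # toy prediction of the strong student
print(f"forward KL {forward_kl_loss(weak, strong):.4f}, reverse KL {reverse_kl_loss(weak, strong):.4f}")
```

The comments record the identity $\mathrm{CE}(p, q) = H(p) + \mathrm{KL}(p, q)$. In the forward pair, $H(F_w)$ does not depend on the student, so CE and forward KL have the same gradients. In the reverse pair, the extra $H(F_{sw})$ term means that minimizing reverse CE also pushes the student toward lower-entropy, more confident predictions.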

Abstract

As large language models advance toward superhuman performance, ensuring their alignment with human values and abilities grows increasingly complex. Weak-to-strong generalization offers a promising approach by leveraging predictions from weaker models to guide stronger systems, but its effectiveness could be constrained by the inherent noise and inaccuracies in these weak predictions. To address this, we propose a theoretically grounded approach that replaces forward KL divergence with reverse KL divergence. The mass-covering behavior of forward KL risks overfitting to imperfect weak signals. In contrast, the zero-forcing effect of reverse KL divergence prioritizes high-confidence predictions, effectively mitigating the influence of unreliable weak supervision. Theoretically, we extend existing bounds and derive tighter lower bounds for both forward and reverse KL divergence, establishing that reverse KL achieves guarantees at least comparable to those of forward KL. Notably, when a sufficiently pre-trained strong model is fine-tuned on its last linear layer, reverse KL guarantees that it outperforms its weak supervisor by the magnitude of their disagreement. Empirically, we demonstrate that reverse KL and reverse cross-entropy enable strong models to outperform those trained with forward KL and standard cross-entropy across most settings, highlighting the practical advantages of these reverse losses.
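The last-layer result concerns fine-tuning only the final linear layer of a pre-trained strong model. The sketch below illustrates that training setup, not the theorem or its guarantee. It freezes hypothetical 2-D features and treats the strong model's head as a logistic-regression layer. It then fits the head to a weak teacher's soft labels by gradient descent on the binary reverse KL $\mathrm{KL}(F_{sw}, F_w)$. The toy data generator, the learning rate, the step count, and all helper names are assumptions made for this example.

```python
import math
import random


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def binary_reverse_kl(s: float, p: float) -> float:
    """KL(Bern(s) || Bern(p)): strong student s first, weak teacher p second."""
    return s * math.log(s / p) + (1 - s) * math.log((1 - s) / (1 - p))


def make_toy_data(n: int = 200, seed: int = 0):
    """Hypothetical frozen features x = (x1, x2) with noisy weak soft labels."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x = (rng.uniform(-1, 1), rng.uniform(-1, 1))
        p_weak = sigmoid(3.0 * x[0] + rng.gauss(0.0, 1.0))  # imperfect weak teacher
        p_weak = min(max(p_weak, 0.02), 0.98)               # keep logit(p_weak) finite
        data.append((x, p_weak))
    return data


def mean_loss(w, b, data) -> float:
    return sum(binary_reverse_kl(sigmoid(w[0] * x[0] + w[1] * x[1] + b), p)
               for x, p in data) / len(data)


def fit_last_layer(data, lr: float = 0.5, steps: int = 300):
    """Full-batch gradient descent on the linear head only; the features stay frozen."""
    w, b = [0.0, 0.0], 0.0
    n = len(data)
    for _ in range(steps):
        gw, gb = [0.0, 0.0], 0.0
        for x, p in data:
            z = w[0] * x[0] + w[1] * x[1] + b
            s = sigmoid(z)
            # d KL(Bern(s) || Bern(p)) / dz = s (1 - s) (z - logit(p))
            g = s * (1 - s) * (z - math.log(p / (1 - p)))
            gw[0] += g * x[0]
            gw[1] += g * x[1]
            gb += g
        w = [w[0] - lr * gw[0] / n, w[1] - lr * gw[1] / n]
        b -= lr * gb / n
    return w, b


data = make_toy_data()
w, b = fit_last_layer(data)
print("initial loss", mean_loss([0.0, 0.0], 0.0, data), "final loss", mean_loss(w, b, data))
```

For a forward-KL baseline, replace the loss with $\mathrm{KL}(\mathrm{Bern}(p), \mathrm{Bern}(s))$. Its gradient with respect to the logit $z$ is simply `s - p`.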

Paper Structure

This paper contains 47 sections, 10 theorems, 108 equations, 4 figures, 2 tables.

Key Result

Lemma 1

Let $L(\cdot, \cdot)$ be $\mathrm{KL}(\cdot, \cdot)$ or $\mathrm{CE}(\cdot, \cdot)$, and let the data domain $\mathcal{X}$, the output domain $\mathcal{Y}$, and the models $F_w, F^\star$ be as defined above. For any strong model $F_{sw}$, there holds where $C_1$ is a positive constant, $d(F_w, F_{sw})$ can be $\mathrm{KL}(F_w, F_{sw})$ or $\mathrm{KL}(F_{sw}, F_w)$, and $L(F^\star,F_{sw})$ and $L(F^\star,F_w)$ re

Figures (4)

  • Figure 1: Illustration of the mass-covering behavior of forward KL divergence and the mode-seeking behavior of reverse KL divergence, highlighting their roles in knowledge distillation (KD) and W2SG. A Gaussian mixture distribution represents the teacher's supervision in KD and W2SG. It is approximated by a single Gaussian, fitted once with forward KL divergence and once with reverse KL divergence as the loss function.
  • Figure 2: Results of the GPT-2 series. "SC" denotes the strong ceiling model, and "A to B" indicates the use of weak teacher "A" to supervise strong student "B". The terms CE, RCE, KL, and RKL refer to CE loss, reverse CE loss, forward KL divergence loss, and reverse KL divergence loss, respectively. Error bars represent the standard deviation across three runs of the experiment.
  • Figure 3: Results of the GPT-2 series on CAI-Harmless. "SC" denotes the strong ceiling model, and "A to B" indicates the use of weak teacher "A" to supervise strong student "B". The term "Conf. CE" refers to the auxiliary confidence loss combined with the vanilla cross-entropy loss (eq:confidence_loss). "Reve. Conf. CE" refers to the auxiliary confidence loss combined with the reverse cross-entropy loss (eq:reverse_confidence_loss). Error bars represent the standard deviation across three runs of the experiment.
  • Figure 4: Results of the Pythia series. "SC" denotes the strong ceiling model, and "A to B" indicates the use of weak teacher "A" to supervise strong student "B". The terms CE, RCE, KL, and RKL refer to cross-entropy loss, reverse cross-entropy loss, forward KL divergence loss, and reverse KL divergence loss, respectively. Error bars represent the standard deviation across three runs of the experiment.
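The Figure 1 experiment can be reproduced in miniature. The sketch below assumes an equal-weight two-component mixture $0.5\,\mathcal{N}(-3, 1) + 0.5\,\mathcal{N}(3, 1)$ as the teacher distribution. It fits a single Gaussian to that mixture in each direction:

- **Forward KL.** It uses the standard fact that $\mathrm{KL}(p, q)$ over Gaussians $q$ is minimized by matching the mean and variance of $p$.
- **Reverse KL.** It runs a brute-force grid search over $(\mu, \sigma)$, with $\mathbb{E}_q[\log p]$ computed by a Riemann sum.

The mixture parameters, the grid ranges, and the step sizes are assumptions for illustration, not the settings behind the paper's figure.

```python
import math

# Hypothetical stand-in for a multi-modal teacher: 0.5 * N(-3, 1) + 0.5 * N(3, 1).
WEIGHTS = (0.5, 0.5)
MEANS = (-3.0, 3.0)
STDS = (1.0, 1.0)
SQRT_2PI = math.sqrt(2.0 * math.pi)


def normal_pdf(x: float, mu: float, sigma: float) -> float:
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * SQRT_2PI)


def mixture_pdf(x: float) -> float:
    return sum(w * normal_pdf(x, m, s) for w, m, s in zip(WEIGHTS, MEANS, STDS))


def fit_forward_kl():
    """argmin_q KL(p || q) over Gaussians q: match the mixture's mean and variance."""
    mean = sum(w * m for w, m in zip(WEIGHTS, MEANS))
    second = sum(w * (s * s + m * m) for w, m, s in zip(WEIGHTS, MEANS, STDS))
    return mean, math.sqrt(second - mean * mean)


def fit_reverse_kl(dx: float = 0.05):
    """argmin_q KL(q || p) over Gaussians q, by grid search over (mu, sigma)."""
    xs = [-20.0 + i * dx for i in range(int(round(40.0 / dx)) + 1)]
    neg_log_p = [-math.log(mixture_pdf(x)) for x in xs]
    best = (math.inf, 0.0, 0.0)
    for i in range(41):                 # mu in [-5, 5], step 0.25
        mu = -5.0 + 0.25 * i
        for j in range(36):             # sigma in [0.5, 4.0], step 0.1
            sigma = 0.5 + 0.1 * j
            # KL(q || p) = -H(q) + E_q[-log p]; H(q) is closed-form for a Gaussian,
            # and E_q[-log p] is approximated by a Riemann sum on the grid.
            neg_entropy = -0.5 * math.log(2.0 * math.pi * math.e * sigma ** 2)
            cross = dx * sum(normal_pdf(x, mu, sigma) * nlp for x, nlp in zip(xs, neg_log_p))
            value = neg_entropy + cross
            if value < best[0]:
                best = (value, mu, sigma)
    return best[1], best[2]


fwd = fit_forward_kl()
rev = fit_reverse_kl()
print(f"forward KL fit: mu={fwd[0]:.2f}, sigma={fwd[1]:.2f}")
print(f"reverse KL fit: mu={rev[0]:.2f}, sigma={rev[1]:.2f}")
```

For these parameters, the forward-KL fit is wide and centred between the modes, with mean 0 and standard deviation $\sqrt{10} \approx 3.16$. The reverse-KL fit is narrow and sits on one of the two modes. This is the mass-covering versus mode-seeking contrast the caption describes.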

Theorems & Definitions (28)

  • Definition 1: KL divergence losses
  • Definition 2: Cross-entropy losses
  • Lemma 1 (proof: proof_lemma_inf)
  • Theorem 1 (proof: constant:theorem)
  • Remark
  • Proposition 1 (proof: proof:general_equation)
  • Remark
  • Theorem 2 (proof: theorem1_kl_loss)
  • Remark
  • Theorem 3 (proof: proof_non-realizable)
  • ...and 18 more