Better Estimation of the Kullback--Leibler Divergence Between Language Models

Afra Amini, Tim Vieira, Ryan Cotterell

TL;DR

Computing the exact KL divergence between two language models is intractable, and the standard Monte Carlo (MC) estimator suffers from high variance and can even return negative estimates of this non-negative quantity. The authors introduce a Rao–Blackwellized (RB) estimator that remains unbiased and provably has variance no greater than that of the MC estimator, with negligible additional cost. They extend the approach to estimating the gradient of the KL divergence, which yields more stable RLHF fine-tuning and better trade-offs between reward and KL divergence. In experiments on sentiment-controlled fine-tuning, the RB estimator substantially reduces variance and improves training stability, and models trained with the RB gradient estimator appear on the reward vs. KL Pareto frontier more often than models trained with the MC gradient estimator. Overall, the work provides a principled variance-reduction technique for KL estimation between language models, with practical impact on LM alignment workflows.

Abstract

Estimating the Kullback--Leibler (KL) divergence between language models has many applications, e.g., reinforcement learning from human feedback (RLHF), interpretability, and knowledge distillation. However, computing the exact KL divergence between two arbitrary language models is intractable. Thus, practitioners often resort to sampling-based estimators. While it is easy to fashion a simple Monte Carlo (MC) estimator that provides an unbiased estimate of the KL divergence between language models, this estimator notoriously suffers from high variance and can even result in a negative estimate of the KL divergence, a non-negative quantity. In this paper, we introduce a Rao--Blackwellized estimator that is unbiased and provably has variance less than or equal to that of the standard Monte Carlo estimator. In an empirical study on sentiment-controlled fine-tuning, we show that our estimator provides more stable KL estimates and reduces variance substantially. Additionally, we derive an analogous Rao--Blackwellized estimator of the gradient of the KL divergence, which leads to more stable training and produces models that more frequently appear on the Pareto frontier of reward vs. KL compared to the ones trained with the MC estimator of the gradient.
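The contrast between the two estimators described in the abstract can be illustrated with a small sketch. It assumes the Rao--Blackwellized estimator replaces each sampled token's log-ratio with the exact KL divergence between the two next-token distributions given the sampled prefix, which is one natural reading of the abstract rather than a reproduction of the authors' code. The toy models, the vocabulary size `V`, the fixed length `T`, and all function names are hypothetical and chosen only so that the exact KL can be enumerated for comparison.

```python
import itertools
import math
import random

V, T = 3, 4  # toy vocabulary size and fixed sequence length (assumptions)

def make_model(seed):
    """Toy autoregressive model: maps each prefix (tuple) to a next-token distribution."""
    rng = random.Random(seed)
    model = {}
    for t in range(T):
        for prefix in itertools.product(range(V), repeat=t):
            weights = [rng.random() + 0.1 for _ in range(V)]
            total = sum(weights)
            model[prefix] = [w / total for w in weights]
    return model

def step_kl(p_dist, q_dist):
    """Exact KL between two next-token distributions."""
    return sum(p * math.log(p / q) for p, q in zip(p_dist, q_dist))

def exact_kl(p, q):
    """Exact KL(p || q) by enumerating all V**T sequences (feasible only for toy sizes)."""
    total = 0.0
    for y in itertools.product(range(V), repeat=T):
        logp = sum(math.log(p[y[:t]][y[t]]) for t in range(T))
        logq = sum(math.log(q[y[:t]][y[t]]) for t in range(T))
        total += math.exp(logp) * (logp - logq)
    return total

def sample_terms(p, q, rng):
    """Draw one sequence from p and return (MC term, RB term) for that sample."""
    prefix, mc, rb = (), 0.0, 0.0
    for _ in range(T):
        p_dist, q_dist = p[prefix], q[prefix]
        rb += step_kl(p_dist, q_dist)              # exact per-step KL given the sampled prefix
        tok = rng.choices(range(V), weights=p_dist)[0]
        mc += math.log(p_dist[tok] / q_dist[tok])  # log-ratio of the sampled token only
        prefix += (tok,)
    return mc, rb

def mean_var(xs):
    """Sample mean and unbiased sample variance."""
    m = sum(xs) / len(xs)
    return m, sum((x - m) ** 2 for x in xs) / (len(xs) - 1)

p, q = make_model(0), make_model(1)
rng = random.Random(2)
terms = [sample_terms(p, q, rng) for _ in range(20000)]
kl = exact_kl(p, q)
mc_mean, mc_var = mean_var([mc for mc, _ in terms])
rb_mean, rb_var = mean_var([rb for _, rb in terms])
print(f"exact KL = {kl:.4f}")
print(f"MC: mean = {mc_mean:.4f}, per-sample variance = {mc_var:.4f}")
print(f"RB: mean = {rb_mean:.4f}, per-sample variance = {rb_var:.4f}")
```

Both estimators share the same sampled sequences, so any difference in spread comes from the estimator alone. In this sketch every RB term is a sum of per-step KL values and is therefore non-negative, whereas individual MC terms can be negative.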

Paper Structure

This paper contains 29 sections, 18 theorems, 53 equations, 4 figures, 2 tables.

Key Result

Proposition 0

Consider the control variate MC estimator $\mu_{\textsc{cv}}$ defined in eq:klcv, and assume that $\mathbb{E}[g(\boldsymbol{Y})] < \infty$. Then $\mu_{\textsc{cv}}$ is an unbiased estimator, and its variance is given by

Figures (4)

  • Figure 1: Standard deviation of KL estimators across various prompts in the IMDB dataset.
  • Figure 2: Comparison of the Monte Carlo (MC) and Rao--Blackwellized (RB) estimators in the RLHF fine-tuning loop. We perform RLHF with each estimator $5$ times and plot the mean and standard deviation (shaded) of the average reward values and the KL at each fine-tuning step. We observe that the MC estimator is not as stable as the RB estimator and that its performance varies significantly across runs. In contrast, the RB estimator reliably offers a good balance between low KL and high reward in all runs.
  • Figure 3: Compared to models trained with the MC estimator, models trained with the RB estimator appear on the Pareto front $78\%$ of the time.
  • Figure 4: Comparing the bias, variance, and consistency of the estimators as the sample size increases. The $\mu_{\textsc{cv}}$ estimator with $\alpha=1$ exhibits a higher standard deviation, particularly for neutral and negative prompts, where the variance becomes extremely large. In contrast, the RB estimator, $\mu_{\textsc{rb}}$, achieves the lowest standard deviation.

Theorems & Definitions (36)

  • Proposition 0
  • proof
  • Theorem 1
  • proof
  • Theorem 2: malagutti-etal-2024-role; Theorem 2.2
  • Theorem 3
  • proof
  • Theorem 4
  • proof
  • Proposition 4
  • ...and 26 more