Explaining and Improving Contrastive Decoding by Extrapolating the Probabilities of a Huge and Hypothetical LM

Haw-Shiuan Chang, Nanyun Peng, Mohit Bansal, Anil Ramakrishna, Tagyoung Chung

TL;DR

A new unsupervised decoding method, Asymptotic Probability Decoding (APD), that explicitly extrapolates the probability curves of LMs of different sizes to infer the asymptotic probabilities of an infinitely large LM, without incurring more inference cost than CD.

Abstract

Contrastive decoding (CD) (Li et al., 2023) improves the next-token distribution of a large expert language model (LM) using a small amateur LM. Although CD has been applied to various LMs and domains to enhance open-ended text generation, it is still unclear why CD often works well, when it could fail, and how it can be improved. To deepen our understanding of CD, we first theoretically prove that CD can be viewed as linearly extrapolating the next-token logits of a huge, hypothetical LM. We also highlight that this linear extrapolation can make CD unable to output the most obvious answers, namely those already assigned high probabilities by the amateur LM. To overcome this limitation, we propose a new unsupervised decoding method called $\mathbf{A}$symptotic $\mathbf{P}$robability $\mathbf{D}$ecoding (APD). APD explicitly extrapolates the probability curves of LMs of different sizes to infer the asymptotic probabilities of an infinitely large LM, without incurring more inference cost than CD. On FactualityPrompts, an open-ended text generation benchmark, sampling with APD significantly boosts factuality compared with CD sampling and its variants, and achieves state-of-the-art results for Pythia 6.9B and OPT 6.7B. Furthermore, on five commonsense QA datasets, APD is often significantly better than CD and achieves an effect similar to using a larger LLM. For example, the perplexity of APD on top of Pythia 6.9B is even lower than the perplexity of Pythia 12B on CommonsenseQA and LAMBADA.
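A minimal sketch of CD's token scoring and of the "obvious blindness" it can cause. It assumes the common log-probability-difference form of CD with the amateur LM's temperature $T$ (the quantity Theorem 1 refers to) and omits the adaptive plausibility filter of Li et al. (2023). The logits are hypothetical toy values chosen to mimic the Figure 1 question, not numbers from the paper; `log_softmax` and `cd_scores` are illustrative helper names.

```python
import math

def log_softmax(logits, temperature=1.0):
    """Log-probabilities from a dict of token -> logit at the given temperature."""
    scaled = {w: v / temperature for w, v in logits.items()}
    m = max(scaled.values())
    log_z = m + math.log(sum(math.exp(v - m) for v in scaled.values()))
    return {w: v - log_z for w, v in scaled.items()}

def cd_scores(expert_logits, amateur_logits, T=1.0):
    """Contrastive score log p_ELM(w) - log p_ALM(w; T) for every token.

    Up to a per-context constant this equals L_ELM(w) - L_ALM(w) / T.
    The adaptive plausibility filter of Li et al. (2023) is omitted here.
    """
    lp_e = log_softmax(expert_logits)
    lp_a = log_softmax(amateur_logits, temperature=T)
    return {w: lp_e[w] - lp_a[w] for w in expert_logits}

# Hypothetical logits: the amateur LM already strongly favors the obvious answer.
expert = {"Bees": 5.0, "Invertebrate": 3.0, "Ants": 1.0}
amateur = {"Bees": 4.0, "Invertebrate": 0.5, "Ants": 1.0}
scores = cd_scores(expert, amateur, T=1.5)
print(max(expert, key=expert.get))  # Bees
print(max(scores, key=scores.get))  # Invertebrate
```

Because "Bees" is already likely under the amateur LM, subtracting the amateur's log-probability removes most of its advantage, and the less common "Invertebrate" ends up with the highest CD score.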

Paper Structure

This paper contains 31 sections, 1 theorem, 12 equations, 10 figures, 7 tables, and 1 algorithm.

Key Result

Theorem 1

If a) the amateur LM's (ALM's) temperature $T>1$, and b) the logits of LMs are linear in the logarithm of the LM sizes, then the logit of contrastive decoding (CD) for the token $w$ is $L^{CD}_c(w) = (1-\frac{1}{T})L^{HLM}_c(w)$, where $L^{HLM}_c(w)$ is the logit of an LM with size $s^{HLM} = \dots$
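The closed form for $s^{HLM}$ is cut off in the statement above. As a sketch, assuming the CD logit is $L^{ELM}_c(w) - L^{ALM}_c(w)/T$ (up to a per-context constant) and writing assumption b) as $L_c(w) = \alpha_c(w) + \beta_c(w)\log s$, matching the coefficients of $\alpha$ and $\beta$ gives

$$\log s^{HLM} = \log s^{ELM} + \frac{\log s^{ELM} - \log s^{ALM}}{T-1}, \quad\text{i.e.}\quad s^{HLM} = \left(\frac{(s^{ELM})^{T}}{s^{ALM}}\right)^{\frac{1}{T-1}},$$

which does not depend on $w$, exceeds $s^{ELM}$ whenever $T>1$ and $s^{ELM} > s^{ALM}$, and grows without bound as $T \to 1$. This is derived from a) and b) alone and is not necessarily the form stated in the paper. The check below uses hypothetical $(\alpha, \beta)$ pairs; the 6.9B and 70M sizes are illustrative.

```python
import math

def cd_logit(l_elm, l_alm, T):
    # CD logit with amateur temperature T; per-context constants are dropped
    return l_elm - l_alm / T

def linear_logit(alpha, beta, log_size):
    # Assumption b): a token's logit is linear in the log of the LM size
    return alpha + beta * log_size

def hlm_log_size(log_s_elm, log_s_alm, T):
    # Derived above: log s_HLM = log s_ELM + (log s_ELM - log s_ALM) / (T - 1)
    if T <= 1:
        raise ValueError("assumption a) requires T > 1")
    return log_s_elm + (log_s_elm - log_s_alm) / (T - 1)

log_s_elm, log_s_alm, T = math.log(6.9e9), math.log(70e6), 1.5
log_s_hlm = hlm_log_size(log_s_elm, log_s_alm, T)

# Hypothetical (alpha, beta) pairs for three tokens; the identity must hold for each.
for alpha, beta in [(-3.0, 0.4), (1.0, -0.1), (0.5, 0.0)]:
    lhs = cd_logit(linear_logit(alpha, beta, log_s_elm),
                   linear_logit(alpha, beta, log_s_alm), T)
    rhs = (1 - 1 / T) * linear_logit(alpha, beta, log_s_hlm)
    assert math.isclose(lhs, rhs, rel_tol=1e-9, abs_tol=1e-12)

print(f"s_HLM ~ {math.exp(log_s_hlm):.3g} parameters")
```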

Figures (10)

  • Figure 1: Given a simple question with clues for which a tiny amateur LM could provide a correct answer, contrastive decoding (CD) can exhibit an "obvious blindness" (i.e., assigning a higher logit to an uncommon answer Invertebrate than to the most obvious answer Bees). In contrast, the proposed asymptotic probability decoding (APD) correctly assigns the highest probability to Bees by leveraging the probabilities from multiple LMs of different sizes to extrapolate the probabilities of an infinitely large, hypothetical LM.
  • Figure 2: Illustration of our proof of Theorem 1. Teal bars are original logits, and red bars are the logits scaled by $1-\frac{1}{T}$. $L^d = L^{ALM} - L^{ELM}$, and $S^d$ is the size difference between the ELM and the ALM in log space. We drop the word $w$ and the context $c$ from the notation in this figure for simplicity.
  • Figure 3: Fine-tuning the ALM to predict the asymptotic probability ($P_c^{AP}$). During training, the predicted $\hat{P}_c^{AP}(\text{Kenya})$ and the empirical probabilities from the LLM family $\{p(w|c,\theta_{s_i})\}_{i=1}^N$ are fed into an MLP. If $\hat{P}_c^{AP}(\text{Kenya})$ is too high to fit the empirical probabilities, the probability curve output by the MLP will be far from the empirical probabilities and thus incur high values of loss 1 and loss 2. The resulting gradients are then backpropagated through the MLP and the ALM to reduce $\hat{P}_c^{AP}(\text{Kenya})$. Finally, we add a regularization term (loss 3) to limit changes in the ALM's output logits.
  • Figure 4: Factuality evaluation of open-ended text generation using FactualityPrompts. The x-axis is a diversity metric (dist-2) and the y-axis is a hallucination metric (NE$_{ER}$), the proportion of generations containing potentially hallucinated named entities, so curves closer to the lower-right corner are better.
  • Figure 5: Comparison of empirical probability curves (blue) and probabilities predicted by APD (orange) and CD (green and red). The next token $w$ has the highest probability in ELM. The LLM family is Pythia.
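APD's own procedure (Figure 3) fine-tunes the ALM and fits probability curves with an MLP; the sketch below does not reproduce it. To illustrate only the underlying idea of reading a limit off each token's probability curve (as in Figure 5), it swaps in a much simpler, training-free extrapolator, Aitken's Δ² method, assuming three LMs equally spaced in log size and a geometric approach to each token's limit. All probabilities are hypothetical, and `aitken_asymptote` and `asymptotic_distribution` are illustrative names.

```python
def aitken_asymptote(p0, p1, p2):
    """Extrapolate the limit of p0, p1, p2 with Aitken's delta-squared method.

    Exact when p_k = p_inf + c * r**k (a geometric approach to the limit);
    k indexes LMs assumed to be equally spaced in log size.
    """
    d1, d2 = p1 - p0, p2 - p1
    denom = d2 - d1
    if abs(denom) < 1e-12:  # no curvature: no limit can be estimated, keep p2
        return p2
    return p2 - d2 * d2 / denom

def asymptotic_distribution(prob_curves):
    """prob_curves maps token -> [p(small LM), p(medium LM), p(large LM)]."""
    raw = {w: min(max(aitken_asymptote(*ps), 0.0), 1.0)
           for w, ps in prob_curves.items()}
    total = sum(raw.values())
    if total <= 0.0:  # every limit clipped to 0: fall back to the largest LM
        raw = {w: ps[-1] for w, ps in prob_curves.items()}
        total = sum(raw.values())
    return {w: p / total for w, p in raw.items()}

# Hypothetical next-token probabilities from three LMs of increasing size.
curves = {
    "Bees":         [0.50, 0.70, 0.80],
    "Invertebrate": [0.05, 0.10, 0.12],
    "Ants":         [0.30, 0.15, 0.06],
}
apd_like = asymptotic_distribution(curves)
print(max(apd_like, key=apd_like.get))               # Bees
print(round(aitken_asymptote(*curves["Bees"]), 6))   # 0.9
```

Unlike a linear extrapolation, this estimate saturates: a token that is already likely under the smallest LM and still rising keeps its lead instead of being penalized for the amateur's confidence.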

Theorems & Definitions (2)

  • Theorem 1
  • proof