
Superposed Decoding: Multiple Generations from a Single Autoregressive Inference Pass

Ethan Shen, Alan Fan, Sarah M. Pratt, Jae Sung Park, Matthew Wallingford, Sham M. Kakade, Ari Holtzman, Ranjay Krishna, Ali Farhadi, Aditya Kusupati

TL;DR

Superposed Decoding is a new decoding algorithm that generates $k$ drafts at the computation cost of one autoregressive inference pass, while being at least as coherent as Nucleus Sampling and at least as factual as Greedy Decoding.

Abstract

Many applications today provide users with multiple auto-complete drafts as they type, including GitHub's code completion, Gmail's smart compose, and Apple's messaging auto-suggestions. Under the hood, language models support this by running an autoregressive inference pass to provide a draft. Consequently, providing $k$ drafts to the user requires running an expensive language model $k$ times. To alleviate the computation cost of running $k$ inference passes, we propose Superposed Decoding, a new decoding algorithm that generates $k$ drafts at the computation cost of one autoregressive inference pass. We achieve this by feeding a superposition of the most recent token embeddings from the $k$ drafts as input to the next decoding step of the language model. At every inference step we combine the $k$ drafts with the top-$k$ tokens to get $k^2$ new drafts and cache the $k$ most likely options, using an n-gram interpolation with minimal compute overhead to filter out incoherent generations. Our experiments show that $k$ drafts from Superposed Decoding are at least as coherent and factual as Nucleus Sampling and Greedy Decoding respectively, while being at least $2.44\times$ faster for $k\ge3$. In a compute-normalized setting, user evaluations demonstrably favor text generated by Superposed Decoding over Nucleus Sampling. Superposed Decoding can also be combined with other decoding strategies, resulting in universal coverage gains when scaling inference-time compute. Code and more examples are open-sourced at https://github.com/RAIVNLab/SuperposedDecoding.
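The decoding step described in the abstract can be illustrated with a minimal sketch. It is not the authors' implementation: the function name `superposed_step`, the toy "LM" (`toy_lm`), the toy bigram scorer (`toy_ngram`), the softmax weighting of drafts, and the log-linear interpolation weight `alpha` are all assumptions made for illustration. The sketch only shows the shape of the computation: one model call on a superposed embedding, expansion of $k$ drafts into $k^2$ candidates, and pruning back to $k$.

```python
import numpy as np

def superposed_step(drafts, draft_logprobs, embed, lm_next_logprobs,
                    ngram_logprob, k, alpha=0.5):
    """One step: k drafts -> k^2 candidates -> k best drafts (illustrative only)."""
    # Assumption: weight each draft by its normalized running probability.
    w = np.exp(draft_logprobs - np.max(draft_logprobs))
    w /= w.sum()
    # Superpose the most recent token embedding of every draft into one vector.
    last_tokens = [d[-1] for d in drafts]
    x = w @ embed[last_tokens]
    # A single model call on the superposed input yields next-token log-probs.
    lm_lp = lm_next_logprobs(x)
    top = np.argsort(lm_lp)[::-1][:k]
    # Pair every draft with every top-k token: k * k candidates.
    candidates = []
    for d, lp in zip(drafts, draft_logprobs):
        for t in top:
            # Assumption: log-linear interpolation of LM and n-gram scores.
            step = (1 - alpha) * lm_lp[t] + alpha * ngram_logprob(d, int(t))
            candidates.append((float(lp + step), d + [int(t)]))
    # Keep the k most likely candidates.
    candidates.sort(key=lambda c: c[0], reverse=True)
    best = candidates[:k]
    return [c[1] for c in best], np.array([c[0] for c in best])

# Toy stand-ins (hypothetical): random embeddings and a linear output head.
rng = np.random.default_rng(0)
VOCAB, DIM = 20, 8
embed = rng.normal(size=(VOCAB, DIM))
head = rng.normal(size=(DIM, VOCAB))
lm_calls = []  # records each model call so the cost can be inspected

def toy_lm(x):
    lm_calls.append(1)
    logits = x @ head
    shifted = logits - logits.max()
    return shifted - np.log(np.exp(shifted).sum())  # log-softmax

def toy_ngram(draft, token):
    # Toy bigram model that discourages repeating the previous token.
    if token == draft[-1]:
        return np.log(0.01)
    return np.log(0.99 / (VOCAB - 1))

drafts = [[1], [2], [3]]
logps = np.log(np.array([0.5, 0.3, 0.2]))
new_drafts, new_logps = superposed_step(drafts, logps, embed, toy_lm, toy_ngram, k=3)
```

In this sketch the model is called once per step regardless of $k$; the remaining work is bookkeeping over $k^2$ candidates, which is where the n-gram lookup adds its small overhead.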

Paper Structure (34 sections, 10 equations, 16 figures, 6 tables, 1 algorithm)


Figures (16)

  • Figure 1: To generate $k$ auto-complete suggestions for a prefix using an LM, existing decoding methods such as Nucleus Sampling need $k$ inference passes. In contrast, Superposed Decoding can generate $k$ suggestions at the cost of a single inference pass while being as coherent and factual.
  • Figure 2: Superposed Decoding feeds a superposed token embedding, built from the most recent tokens of the current $k$ drafts, as the input to the autoregressive inference step. Combining the existing $k$ drafts with the top-$k$ output tokens at the new timestep yields $k^2$ new drafts. Finally, the top-$k$ drafts are kept after filtering with an n-gram interpolation to improve coherence.
  • Figure 3: Llama-2-7B maintains the linear relationship between superposed embeddings and the component token embeddings, with mean cosine similarity $\ge0.6$ for the first 10 timesteps.
  • Figure 4: Qualitative text generations in a compute-normalized setting for Superposed Decoding and Nucleus Sampling with prefixes sampled from OpenWebText. See Appendix \ref{appendix:spd_generations} for more.
  • Figure 5: Superposed Decoding is as accurate as Greedy Decoding for P@1 and increases the fact-based coverage using multiple drafts (P@2,3) on TriviaQA (left) and Natural Questions (right).
  • ...and 11 more figures
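
The linearity property mentioned in the Figure 3 caption can be probed with a small sketch. This is one possible way to set up such a check, not the paper's evaluation code: the helper names `cosine` and `linearity_score`, the random matrix `A`, and the toy maps are hypothetical. The idea is to compare the model output on a superposed input, $f(\sum_i w_i x_i)$, with the same superposition of the individual outputs, $\sum_i w_i f(x_i)$, using cosine similarity.

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity between two 1-D vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def linearity_score(f, inputs, weights):
    """Compare f(sum_i w_i x_i) against sum_i w_i f(x_i)."""
    weights = np.asarray(weights, dtype=float)
    superposed_in = weights @ inputs                      # superpose the inputs
    out_of_superposed = f(superposed_in)                  # one pass on the mixture
    superposed_out = weights @ np.stack([f(x) for x in inputs])  # mix of k passes
    return cosine(out_of_superposed, superposed_out)

# Toy stand-ins (hypothetical): a linear map and a nonlinear variant of it.
rng = np.random.default_rng(0)
A = rng.normal(size=(16, 16))
xs = rng.normal(size=(3, 16))
w = [0.5, 0.3, 0.2]

linear_score = linearity_score(lambda x: x @ A, xs, w)
nonlinear_score = linearity_score(lambda x: np.tanh(x @ A), xs, w)
```

A perfectly linear map scores 1 up to floating-point error. A real language model is nonlinear, so a score that stays well above zero over several timesteps, as the caption reports for Llama-2-7B, is what makes feeding a superposed embedding workable.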