Bridging the Training-Inference Gap in LLMs by Leveraging Self-Generated Tokens

Zhepeng Cen, Yao Liu, Siliang Zeng, Pratik Chaudhari, Huzefa Rangwala, George Karypis, Rasool Fakoor

TL;DR

Exposure bias causes a gap between how LLMs are trained (conditioning on ground-truth tokens) and how they generate text at inference time (conditioning on their own self-generated history). The paper introduces two offline methods built on self-generated tokens, Batch-Scheduled Sampling (BASH) and Reference-Answer-based Correction (RAC), which make training conditions resemble inference without altering the model architecture. Across summarization, general QA, and math QA benchmarks, BASH and RAC yield consistent improvements over strong demonstration-data baselines, and RAC additionally gives the model a self-correction capability. The approach is practical, scalable, and beneficial for downstream alignment pipelines, offering a concrete path to closer training-inference alignment in LLMs.

Abstract

Language models are often trained to maximize the likelihood of the next token given past tokens in the training dataset. At inference time, however, they are used differently: they generate text sequentially and auto-regressively, using previously generated tokens as input to predict the next one. Marginal differences in predictions at each step can cascade over successive steps, resulting in distributions that differ from those the models were trained on and potentially leading to unpredictable behavior. This paper proposes two simple approaches based on the model's own generations to address this discrepancy between training and inference. Our first approach is Batch-Scheduled Sampling, where, during training, we stochastically choose between the ground-truth token from the dataset and the model's own generated token as input to predict the next token. This is done in an offline manner, modifying the context window by interleaving ground-truth tokens with those generated by the model. Our second approach is Reference-Answer-based Correction, where we explicitly incorporate a self-correction capability into the model during training. This enables the model to effectively self-correct the gaps between the generated sequences and the ground-truth data without relying on an external oracle model. By incorporating our proposed strategies during training, we have observed an overall improvement in performance compared to baseline methods, as demonstrated by our extensive experiments using summarization, general question-answering, and math question-answering tasks.
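
The token-mixing step described for Batch-Scheduled Sampling can be illustrated with a minimal sketch. It is not the authors' implementation: the function name `mix_tokens`, the parameter `mix_prob`, and the toy token lists are hypothetical, and the sketch assumes the model's tokens were generated offline beforehand and are aligned position by position with the ground-truth tokens. It also assumes that the training targets remain the ground-truth tokens, which the abstract does not state explicitly.

```python
import random


def mix_tokens(ground_truth, generated, mix_prob, rng=None):
    """Build one mixed context by choosing, per position, between the
    dataset token and the model's own token.

    ground_truth, generated: equal-length token lists (assumed aligned).
    mix_prob: hypothetical probability of taking the model's token.
    """
    if len(ground_truth) != len(generated):
        raise ValueError("token sequences must have the same length")
    if not 0.0 <= mix_prob <= 1.0:
        raise ValueError("mix_prob must lie in [0, 1]")
    if rng is None:
        rng = random.Random()
    # Independent coin flip per position: model token with probability mix_prob.
    return [
        gen if rng.random() < mix_prob else gt
        for gt, gen in zip(ground_truth, generated)
    ]


# Toy data: the "generated" tokens stand in for model outputs produced offline.
reference = ["the", "cat", "sat", "on", "the", "mat"]
model_out = ["the", "dog", "sat", "in", "a", "mat"]

mixed_context = mix_tokens(reference, model_out, mix_prob=0.25, rng=random.Random(0))
targets = reference  # assumption: the loss is still computed against ground truth
print(mixed_context)
```

With `mix_prob=0.0` the context reduces to ordinary ground-truth conditioning, and with `mix_prob=1.0` it consists entirely of the model's own tokens; values in between interleave the two sources.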

Paper Structure

This paper contains 25 sections, 9 equations, 6 figures, 9 tables, 2 algorithms.

Figures (6)

  • Figure 1: How does RAC correct mistakes in model-generated responses? In this example, the SFT model makes a mistake in calculating $3 * \$160,000 + \$2,000$, as shown in yellow. RAC corrects the error by replacing the wrong token, $4$, with the correct token, $2$. This is achieved by forcing the model to fit labels $\bar{z}$ that differ from the original generated response (highlighted in purple), enabling it to self-correct. This example is based on the GSM8K dataset (Cobbe et al., 2021).
  • Figure 2: Visualization of the embedding distance between generated and reference responses. The left two figures are based on queries from the training set of the UltraChat-200K dataset, while the right two figures are from the test set. Corresponding queries for each figure are summarized at the bottom, with full queries available in the appendix. We generate 256 responses from models and compute their embedding distances to the reference responses. Each violin plot includes an inner box plot that displays the maximum, third quartile, median (indicated by a white line), first quartile, and minimum distances, while the shape of the violin represents the estimated probability density of the embedding distance.
  • Figure 3: Win rates on the summarization task with different dataset sizes. Each result is averaged over three seeds. Each seed uses a different subset of the training data, and we train SFT, BASH, and RAC models on that subset. The win rate is evaluated on the whole test set.
  • Figure 4: Comparison of AlpacaEval 2.0 LC win rates for our methods and SFT. Each result is averaged over three generations. We apply SFT on the UltraChat dataset to continue training the existing SFT model and compare its performance with our methods. In the figure, epoch 1 corresponds to the first iteration, and epochs 2 and 3 correspond to the second iteration for BASH and RAC. At the beginning of each iteration, we generate BASH sequences or RAC labels offline with the current model.
  • Figure 5: How does RAC correct mistakes in model-generated responses? In this example, the SFT model incorrectly calculates Peter's age in the initial step, deriving it as $60/2=30$, whereas the ground truth is 34. To address this error, RAC labels the next token after "Peter will be $60/2$" as "+" instead of "=", guiding the model towards the correct computation. After training with RAC, the model calculates Peter's age correctly and reaches the correct answer.
  • ...and 1 more figures

Theorems & Definitions (2)

  • Remark 1: Scheduled sampling is not scalable
  • Remark 2: Parameter $\beta$ should be chosen to be small