Reasoning Bias of Next Token Prediction Training

Pengxiao Lin, Zhongwang Zhang, Zhi-Qin John Xu

TL;DR

The paper investigates whether next-token prediction (NTP) or critical token prediction (CTP) better cultivates reasoning in transformer-based models. Through extensive from-scratch experiments on GPT-2-125M and a battery of reasoning datasets (including PrOntoQA, ProsQA, and anchor-function benchmarks), it shows that NTP's exposure to noisy supervision acts as a regularizer that improves generalization, robustness, and the emergence of reasoning capabilities during pretraining, while CTP often underperforms on reasoning tasks and is more effective for finetuning. The results reveal a clear reasoning bias induced by NTP: models trained with NTP learn to solve multi-step reasoning tasks more robustly and transfer earlier to related tasks, though they may be more susceptible to forgetting when finetuned on new data. The study also introduces the anchor function as a controlled synthetic probe to illustrate how NTP biases models toward simpler reasoning solutions and demonstrates that noise and initialization influence the bias, contributing to a nuanced view of training strategies for LLMs.

Abstract

Since the inception of Large Language Models (LLMs), training them efficiently for superior reasoning capabilities has been a pivotal challenge. The dominant training paradigm for LLMs is next token prediction (NTP). An alternative approach, Critical Token Prediction (CTP), focuses exclusively on specific critical tokens (such as the answer in a Q&A dataset), aiming to reduce overfitting to extraneous information and noise. Contrary to initial assumptions, our research reveals that, despite its exposure to noise during training, NTP surpasses CTP in reasoning ability. We attribute this counterintuitive outcome to the regularizing influence of noise on the training dynamics. Our empirical analysis shows that NTP-trained models exhibit enhanced generalization and robustness across various benchmark reasoning datasets, demonstrating greater resilience to perturbations and reaching flatter loss minima. These findings indicate that NTP is instrumental in fostering reasoning abilities during pretraining, whereas CTP is more effective for finetuning, thereby enriching our understanding of optimal training strategies in LLM development.
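The abstract ties NTP's robustness to flatter loss minima. One common way to make "flatness" concrete is to measure how much the loss rises when the weights around a minimum are randomly perturbed. The following is a minimal, generic sketch of that proxy under stated assumptions, not the paper's own measurement: `quadratic_loss` is a hypothetical toy loss, and `perturbation_sharpness`, `sigma`, and `trials` are illustrative names and values.

```python
import random


def quadratic_loss(weights, curvatures):
    """Hypothetical toy loss with its minimum at w = 0.

    Each curvature sets how steep the basin is along that coordinate,
    standing in for a real model's loss surface.
    """
    return 0.5 * sum(c * w * w for c, w in zip(curvatures, weights))


def perturbation_sharpness(loss_fn, weights, sigma=0.1, trials=200, seed=0):
    """Average loss increase under isotropic Gaussian weight noise of scale sigma.

    Larger values mean the loss rises faster around `weights`, i.e. a sharper
    minimum; smaller values mean a flatter one.
    """
    rng = random.Random(seed)
    base = loss_fn(weights)
    total = 0.0
    for _ in range(trials):
        noisy = [w + rng.gauss(0.0, sigma) for w in weights]
        total += loss_fn(noisy) - base
    return total / trials


if __name__ == "__main__":
    minimum = [0.0, 0.0, 0.0]
    flat = [0.5, 0.5, 0.5]      # low curvature: wide basin
    sharp = [50.0, 50.0, 50.0]  # high curvature: narrow basin
    s_flat = perturbation_sharpness(lambda w: quadratic_loss(w, flat), minimum)
    s_sharp = perturbation_sharpness(lambda w: quadratic_loss(w, sharp), minimum)
    print(f"flat basin:  {s_flat:.4f}")
    print(f"sharp basin: {s_sharp:.4f}")
```

Because both calls share a seed, they see identical noise, so the sharp basin's score is exactly the curvature ratio (100x) times the flat one's. For an actual network, one could apply the same idea by perturbing every parameter tensor and averaging the loss over held-out batches.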

Paper Structure

This paper contains 43 sections, 9 equations, and 12 figures.

Figures (12)

  • Figure 1: Schematic comparison of NTP and CTP. In arithmetic addition tasks, CTP's loss function covers only the answer, whereas NTP's loss spans the entire sequence and therefore introduces some noise into the optimization process.
  • Figure 2: (a) Accuracy of NTP and CTP on the original and cloze PrOntoQA tasks over training epochs. On the original task, NTP eventually reaches perfect accuracy, while CTP plateaus around 80%; on the cloze task, the gap between NTP and CTP widens. (b) 2-hop specific PrOntoQA: performance of NTP and CTP on the specified key-answer PrOntoQA task. NTP maintains high accuracy without overfitting, whereas CTP overfits to the training data and loses accuracy on the reverse test set. (c) 1-hop specific PrOntoQA with OOV data: NTP reaches nearly 100% accuracy, while CTP stabilizes around 70%.
  • Figure 3: Description of the PrOntoQA task variants (original, cloze, reverse, and OOV test) and of its variant ProsQA from Coconut.
  • Figure 4: Performance comparison of NTP and CTP across various tasks. NTP consistently outperforms CTP on reasoning tasks, while results on text classification tasks are mixed. For every task, a GPT-2 model (125M) is trained from scratch to exclude any effect of NTP during pretraining. Accuracy is reported once training has stabilized.
  • Figure 5: Accuracy of the non-reasoning and reasoning solutions of the anchor function for models with different numbers of layers. NTP stably switches from the non-reasoning solution to the reasoning solution. Error bars show the standard deviation over 3 runs of a post-norm GPT-2.
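The contrast in Figure 1 can be read as a choice of loss mask over the same next-token cross-entropy: NTP supervises every position, while CTP supervises only the answer. The sketch below assumes a character-level vocabulary and a list of per-position logits from some model; `masked_lm_loss`, `ntp_mask`, and `ctp_mask` are hypothetical helpers, and treating every token after `=` as the critical answer is an illustrative assumption that may differ from the paper's exact setup.

```python
import math
import random
from typing import List, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Negative log-softmax probability of `target` (numerically stable)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def masked_lm_loss(
    logits: Sequence[Sequence[float]],  # logits[t] predicts tokens[t + 1]
    tokens: Sequence[int],
    loss_mask: Sequence[int],           # loss_mask[i] == 1: token i is a target
) -> float:
    """Mean next-token cross-entropy over target positions that are unmasked."""
    total, count = 0.0, 0
    for t in range(len(tokens) - 1):
        if loss_mask[t + 1]:
            total += cross_entropy(logits[t], tokens[t + 1])
            count += 1
    if count == 0:
        raise ValueError("loss mask selects no target positions")
    return total / count


def ntp_mask(tokens: Sequence[int]) -> List[int]:
    # NTP: every token after the first is a prediction target.
    return [0] + [1] * (len(tokens) - 1)


def ctp_mask(tokens: Sequence[int], answer_start: int) -> List[int]:
    # CTP (sketch): only the answer tokens, e.g. "46" in "12+34=46", are targets.
    return [1 if i >= answer_start else 0 for i in range(len(tokens))]


if __name__ == "__main__":
    vocab = "0123456789+="
    text = "12+34=46"
    tokens = [vocab.index(ch) for ch in text]
    answer_start = text.index("=") + 1
    rng = random.Random(0)
    # Random logits stand in for a model's outputs at each position.
    logits = [[rng.gauss(0.0, 1.0) for _ in vocab] for _ in tokens]
    print("NTP targets:", sum(ntp_mask(tokens)))                  # 7
    print("CTP targets:", sum(ctp_mask(tokens, answer_start)))    # 2
    print("NTP loss:", masked_lm_loss(logits, tokens, ntp_mask(tokens)))
    print("CTP loss:", masked_lm_loss(logits, tokens, ctp_mask(tokens, answer_start)))
```

Under NTP, the gradient also flows from predicting the operands and operator, which are not determined by the preceding context; this is the "noise" the caption refers to. CTP discards those positions and keeps only the two answer digits.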