Reasoning Bias of Next Token Prediction Training
Pengxiao Lin, Zhongwang Zhang, Zhi-Qin John Xu
TL;DR
The paper investigates whether next-token prediction (NTP) or critical token prediction (CTP) better cultivates reasoning in transformer-based models. Through extensive from-scratch experiments on GPT-2-125M and a battery of reasoning datasets (including PrOntoQA, ProsQA, and anchor-function benchmarks), it shows that NTP's exposure to noisy supervision acts as a regularizer that improves generalization, robustness, and the emergence of reasoning capabilities during pretraining, while CTP often underperforms on reasoning tasks and is more effective for finetuning. The results reveal a clear reasoning bias induced by NTP: models trained with NTP learn to solve multi-step reasoning tasks more robustly and transfer earlier to related tasks, though they may be more susceptible to forgetting when finetuned on new data. The study also introduces the anchor function as a controlled synthetic probe to illustrate how NTP biases models toward simpler reasoning solutions and demonstrates that noise and initialization influence the bias, contributing to a nuanced view of training strategies for LLMs.
Abstract
Since the inception of Large Language Models (LLMs), efficiently training them for superior reasoning capabilities has been a pivotal challenge. The dominant training paradigm for LLMs is next token prediction (NTP). An alternative methodology, Critical Token Prediction (CTP), focuses exclusively on specific critical tokens (such as the answer in a Q&A dataset), aiming to reduce overfitting to extraneous information and noise. Contrary to initial assumptions, our research reveals that despite NTP's exposure to noise during training, it surpasses CTP in reasoning ability. We attribute this counterintuitive outcome to the regularizing influence of noise on the training dynamics. Our empirical analysis shows that NTP-trained models exhibit enhanced generalization and robustness across various benchmark reasoning datasets, demonstrating greater resilience to perturbations and achieving flatter loss minima. These findings indicate that NTP is instrumental in fostering reasoning abilities during pretraining, whereas CTP is more effective for finetuning, thereby enriching our understanding of optimal training strategies in LLM development.
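To make the distinction between the two objectives concrete, the following minimal sketch contrasts them on a single toy sequence. It assumes that CTP can be modeled as the usual per-token cross-entropy restricted by a mask to the answer token, which is one plausible reading of "focuses exclusively on specific critical tokens" and not necessarily the authors' exact implementation. The token sequence, the probabilities, and the helper names (`token_losses`, `masked_mean_loss`) are all hypothetical and chosen only for illustration.

```python
import math


def token_losses(probs_of_target):
    """Per-token cross-entropy, given the probability the model
    assigned to each correct next token."""
    return [-math.log(p) for p in probs_of_target]


def masked_mean_loss(losses, mask):
    """Average the per-token losses over the positions where mask is 1."""
    selected = [loss for loss, keep in zip(losses, mask) if keep]
    return sum(selected) / len(selected)


# Hypothetical toy sequence: a question followed by its answer.
tokens = ["Q:", "2", "+", "3", "=", "A:", "5"]
targets = tokens[1:]  # each position predicts the following token

# Hypothetical probabilities a model assigns to each correct target token.
probs = [0.5, 0.25, 0.5, 0.25, 0.5, 0.125]
losses = token_losses(probs)

# NTP (as assumed here): every target token contributes to the loss.
ntp_mask = [1] * len(targets)

# CTP (as assumed here): only the critical token, the final answer, contributes.
ctp_mask = [0] * (len(targets) - 1) + [1]

ntp_loss = masked_mean_loss(losses, ntp_mask)
ctp_loss = masked_mean_loss(losses, ctp_mask)

print(f"NTP loss: {ntp_loss:.4f}")
print(f"CTP loss: {ctp_loss:.4f}")
```

Under this reading, the only difference between the two objectives is the mask: NTP also receives gradient signal from the question tokens, including any noisy ones, while CTP receives it from the answer alone.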
