The Factorization Curse: Which Tokens You Predict Underlie the Reversal Curse and More

Ouail Kitouni, Niklas Nolte, Diane Bouchacourt, Adina Williams, Mike Rabbat, Mark Ibrahim

TL;DR

This work identifies the reversal curse as a specific instance of the factorization curse, where standard left-to-right autoregressive training fails to preserve the same joint distribution under alternate factorizations. It formalizes the problem, analyzes its implications for knowledge retrieval, and introduces factorization-agnostic objectives (PLM and MLM-$\mathcal{U}$) designed to store and retrieve information across all factorizations. Through controlled synthetic tasks and WikiReversal, a benchmark based on GenWiki/DBpedia content, the authors show that scale, naive bidirectional training, and fixed-masking MLM do not resolve the reversal curse, while factorization-agnostic training significantly mitigates it and shows signs of improved planning capabilities. The findings suggest new directions for finetuning on domain-specific data and improving knowledge storage, with practical impact on reducing hallucinations and enhancing reliable information retrieval in large language models.

Abstract

Today's best language models still struggle with hallucinations: factually incorrect generations, which impede their ability to reliably retrieve information seen during training. The reversal curse, where models cannot recall information when probed in a different order than was encountered during training, exemplifies this in information retrieval. We reframe the reversal curse as a factorization curse: a failure of models to learn the same joint distribution under different factorizations. Through a series of controlled experiments with increasing levels of realism, including WikiReversal, a setting we introduce to closely simulate a knowledge-intensive finetuning task, we find that the factorization curse is an inherent failure of the next-token prediction objective used in popular large language models. Moreover, we demonstrate that reliable information retrieval cannot be solved with scale, reversed tokens, or even naive bidirectional-attention training. Consequently, various approaches to finetuning on specialized data would necessarily provide mixed results on downstream tasks, unless the model has already seen the right sequence of tokens. Across five tasks of varying levels of complexity, our results uncover a promising path forward: factorization-agnostic objectives can significantly mitigate the reversal curse and hint at improved knowledge storage and planning capabilities.
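
To make "the same joint distribution under different factorizations" concrete, the chain rule of probability can be written out for a sequence of tokens. The block below is a sketch in generic notation (the symbols $x_t$, $T$, and $\sigma$ are assumptions introduced here and may not match the paper's Definition 1). It shows that the joint can be expanded in any token order, while left-to-right training only fits the conditionals of the identity order.

```latex
% Two-token case: both expansions describe the same joint distribution.
\begin{align}
p(x_1, x_2) &= p(x_1)\, p(x_2 \mid x_1) \\
            &= p(x_2)\, p(x_1 \mid x_2)
\end{align}

% General case: for any permutation \sigma of \{1, \dots, T\},
\begin{equation}
p(x_1, \dots, x_T) = \prod_{t=1}^{T} p\left(x_{\sigma(t)} \mid x_{\sigma(1)}, \dots, x_{\sigma(t-1)}\right)
\end{equation}

% Left-to-right next-token prediction trains only the terms with \sigma(t) = t:
\begin{equation}
p(x_1, \dots, x_T) = \prod_{t=1}^{T} p\left(x_t \mid x_1, \dots, x_{t-1}\right)
\end{equation}
```

Read this way, a model trained on "Paris" before "France" is fit to $p(\text{France} \mid \text{Paris})$, and nothing in the objective forces it to represent $p(\text{Paris} \mid \text{France})$, even though both follow from the same joint.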

Paper Structure (35 sections, 8 equations, 7 figures, 10 tables, 1 algorithm)

Figures (7)

  • Figure 1: (Left) Reversal curse from training a model on sentences with Paris before France. (Right) Left-to-right objective does not learn how to predict early tokens from later ones even if the information content is the same. The model overfits to a specific factorization of the joint distribution over tokens, and is unable to answer questions that require reasoning about a different factorization.
  • Figure 2: MLM struggles when entities span more tokens than the masked span. MLM-$\mathcal{U}$ encounters all possible masking fractions during training and does not suffer from this problem.
  • Figure 3: An example passage with a forward relation triple. The forward question queries the tail, backward queries the head. WikiReversal is a collection of passages and forward/backward QAs.
  • Figure 4: In panel (a) we compare MLM with varying masking ratios to MLM-$\mathcal{U}$. In panels (b) and (c) we visualize the two main principal components of representations learned via AR versus MLM-$\mathcal{U}$.
  • Figure 5: Star Graph Task: Illustration and Performance Comparison. The illustration shows the "Clever Hans" failure mode with teacher-forced AR (adapted from Bachmann & Nagarajan, 2024).
  • ...and 2 more figures
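
The caption of Figure 2 says that MLM-$\mathcal{U}$ encounters all possible masking fractions during training. A minimal sketch of what that could look like is given below, assuming the masking rate is drawn uniformly from [0, 1] once per sequence and each token is then masked independently at that rate. The names `mlm_u_mask`, `MASK_ID`, and `IGNORE` are hypothetical and are not taken from the paper, whose implementation may differ (for example in how the rate is shared across a batch).

```python
import random

MASK_ID = 0      # hypothetical id reserved for the mask token
IGNORE = -100    # hypothetical label for positions that contribute no loss


def mlm_u_mask(tokens, rng, mask_id=MASK_ID, ignore=IGNORE):
    """Mask a token sequence at a rate drawn uniformly from [0, 1).

    Returns (inputs, labels, rate). Masked positions hold `mask_id` in
    `inputs` and the original token in `labels`; unmasked positions keep
    the token in `inputs` and hold `ignore` in `labels`.
    """
    rate = rng.random()  # assumption: one masking rate per sequence
    inputs, labels = [], []
    for tok in tokens:
        if rng.random() < rate:
            inputs.append(mask_id)   # hide the token from the model
            labels.append(tok)       # and ask the model to predict it
        else:
            inputs.append(tok)       # token stays visible as context
            labels.append(ignore)    # no loss on visible tokens
    return inputs, labels, rate


def masked_fraction(inputs, mask_id=MASK_ID):
    """Fraction of positions that were replaced by the mask token."""
    return sum(1 for tok in inputs if tok == mask_id) / len(inputs)
```

Because the rate varies from near 0 to near 1, some training examples hide almost the whole sequence and others hide almost nothing, so an entity that spans several tokens is sometimes masked in full. A fixed-rate MLM, by contrast, would use the same `rate` for every sequence.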

Theorems & Definitions (1)

  • Definition 1