
End-to-end Planner Training for Language Modeling

Nathan Cornille, Florian Mai, Jingyuan Sun, Marie-Francine Moens

TL;DR

This work proposes using the planner's predicted label probabilities as mixing weights, so that the LM is conditioned on a weighted average of label embeddings in a differentiable manner. This enables joint fine-tuning of the planner and the LM, and lets the LM draw on the full label distribution predicted by the planner rather than on a single selected label, retaining more information.
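Why the mixture is differentiable can be seen from a short derivation. The notation is ours, not the paper's: $z \in \mathbb{R}^K$ are the planner's logits over $K$ labels, $p = \mathrm{softmax}(z)$ are the predicted probabilities, $e_i$ is the embedding of label $i$, and $\bar{e}$ is the mixed embedding the LM is conditioned on.

```latex
p_i = \frac{\exp(z_i)}{\sum_{j=1}^{K} \exp(z_j)}, \qquad
\bar{e} = \sum_{i=1}^{K} p_i \, e_i

\frac{\partial p_i}{\partial z_k} = p_i \left( \delta_{ik} - p_k \right)
\quad \Longrightarrow \quad
\frac{\partial \bar{e}}{\partial z_k}
  = \sum_{i=1}^{K} p_i \left( \delta_{ik} - p_k \right) e_i
  = p_k \left( e_k - \bar{e} \right)
```

Every logit therefore receives a nonzero gradient whenever $p_k > 0$ and $e_k \neq \bar{e}$, whereas conditioning on the single label $\arg\max_i z_i$ is piecewise constant in $z$ and has zero gradient almost everywhere.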

Abstract

Through end-to-end training to predict the next token, LLMs have become valuable tools for various tasks. Enhancing their core training in language modeling can improve numerous downstream applications. A successful approach to enhance language modeling uses a separate planning module to predict abstract labels of future sentences and conditions the LM on these predictions. However, this method is non-differentiable, preventing joint end-to-end tuning of the planner with the LM. We propose an effective method to improve this approach by enabling joint fine-tuning of the planner and the LM. We show that a naive way of approximating the gradient of selecting a label via the straight-through estimator is not effective. Instead, we propose to use the predicted label probabilities as mixing weights to condition the LM on a weighted average of label embeddings in a differentiable manner. This not only enables joint fine-tuning of the planner and the LM, but also allows the LM to draw on the full label distribution predicted by the planner, retaining more information. Our experimental results show consistent improvements in perplexity.
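A minimal NumPy sketch of the contrast described in the abstract, under our own assumptions: one planner prediction is a vector of logits, and label embeddings are stored in a table with one row per label. All function names are hypothetical and not taken from the authors' code. `soft_label_embedding` is the probability-weighted average; `hard_label_embedding` is the argmax selection whose forward value a straight-through estimator would keep while substituting a surrogate gradient in the backward pass; `soft_embedding_jacobian` is the analytic gradient of the soft mixture with respect to the logits.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over a 1-D vector of planner logits.
    shifted = logits - logits.max()
    exp = np.exp(shifted)
    return exp / exp.sum()


def soft_label_embedding(logits: np.ndarray, label_embeddings: np.ndarray) -> np.ndarray:
    # Differentiable conditioning: predicted label probabilities act as
    # mixing weights over the rows of the (num_labels, dim) embedding table.
    probs = softmax(logits)
    return probs @ label_embeddings


def hard_label_embedding(logits: np.ndarray, label_embeddings: np.ndarray) -> np.ndarray:
    # Non-differentiable conditioning: pick the single most likely label.
    # One common straight-through variant uses this value in the forward
    # pass but backpropagates the soft mixture's gradient instead.
    return label_embeddings[int(np.argmax(logits))]


def soft_embedding_jacobian(logits: np.ndarray, label_embeddings: np.ndarray) -> np.ndarray:
    # Gradient of the mixed embedding w.r.t. each logit:
    # row k is p_k * (e_k - mixed); result has shape (num_labels, dim).
    probs = softmax(logits)
    mixed = probs @ label_embeddings
    return probs[:, None] * (label_embeddings - mixed[None, :])


if __name__ == "__main__":
    logits = np.array([2.0, 1.0, 0.1])   # hypothetical planner scores
    table = np.array([[1.0, 0.0],        # hypothetical label embeddings
                      [0.0, 1.0],
                      [1.0, 1.0]])
    print("soft:", soft_label_embedding(logits, table))
    print("hard:", hard_label_embedding(logits, table))
    print("jacobian:\n", soft_embedding_jacobian(logits, table))
```

In a real model the logits and embeddings would be tensors in an autodiff framework, which computes this Jacobian automatically; the explicit version is only there so the gradient can be inspected and checked against finite differences.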

Paper Structure

This paper contains 38 sections, 3 equations, 4 figures, 3 tables.

Figures (4)

  • Figure 1: Illustration of our proposed improvement. The planner predicts a distribution over actions, which is used as mixing weights to compute a weighted average of the action embeddings. This allows the planner to be fine-tuned jointly with the LM.
  • Figure 2: Illustration of the probing locations inside the model.
  • Figure 3: Plots showing probing performance at different layers and for different distances to the probe's target token.
  • Figure 4: Relative improvement/worsening of our metrics as we increase the fraction of planner-predicted actions from zero (equivalent to OA in Cornille et al., 2024) to one (equivalent to PA in Cornille et al., 2024). Some metrics are inverted, so that higher is better for all metrics.
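The interpolation behind Figure 4 can be sketched as follows. This assumes that in the OA setting each sentence is conditioned on an oracle action label, that the planner predicts one action per sentence, and that the chosen fraction of sentences is picked uniformly at random. The function `mix_actions` and its random-selection rule are hypothetical; the caption does not say how the planner-predicted positions are chosen.

```python
import random
from typing import List, Sequence


def mix_actions(oracle: Sequence[int], predicted: Sequence[int],
                fraction: float, seed: int = 0) -> List[int]:
    # Replace a `fraction` of the per-sentence oracle actions with the
    # planner's predictions: 0.0 gives all oracle (OA), 1.0 all predicted (PA).
    if len(oracle) != len(predicted):
        raise ValueError("oracle and predicted must have the same length")
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must lie in [0, 1]")
    n = len(oracle)
    num_predicted = round(fraction * n)
    rng = random.Random(seed)  # fixed seed so each fraction is reproducible
    use_predicted = set(rng.sample(range(n), num_predicted))
    return [predicted[i] if i in use_predicted else oracle[i] for i in range(n)]


if __name__ == "__main__":
    oracle = [3, 1, 4, 1, 5]       # hypothetical oracle action ids
    predicted = [2, 7, 1, 8, 2]    # hypothetical planner predictions
    for f in (0.0, 0.4, 1.0):
        print(f, mix_actions(oracle, predicted, f))
```

Sweeping `fraction` between the two endpoints and re-evaluating the LM at each value would produce a curve of the kind shown in Figure 4.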