Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning

Tianduo Wang, Shichen Li, Wei Lu

TL;DR

This work demonstrates that the reasoning abilities of small-scale LMs can be enhanced through self-training, a process in which models learn from their own outputs, and shows that conventional self-training can be further augmented by a preference learning algorithm called Direct Preference Optimization (DPO).
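
The self-training idea above can be illustrated with a minimal Python sketch of one data-generation round. It is not the authors' code. For illustration only, it assumes that each training question comes with a gold final answer, that every rationale ends with a GSM8K-style `#### <answer>` line, and that self-generated rationales are labeled by comparing that answer with the gold one. Correct rationales then serve as pseudo-labels, and (correct, incorrect) pairs serve as preference data for DPO. The names `extract_answer`, `build_training_data`, and `toy_sampler` are hypothetical, and the sampler is a stand-in for sampling from the actual model.

```python
import re
from typing import Callable, Optional

# Assumed answer format: a final "#### <number>" line, as in GSM8K gold solutions.
ANSWER_RE = re.compile(r"####\s*(-?[\d,]*\.?\d+)")


def extract_answer(rationale: str) -> Optional[str]:
    """Return the number after '####' with commas removed, or None if absent."""
    match = ANSWER_RE.search(rationale)
    return match.group(1).replace(",", "") if match else None


def build_training_data(
    questions: list[tuple[str, str]],
    sample: Callable[[str, int], list[str]],
    num_samples: int = 4,
):
    """Label self-generated rationales by final-answer correctness.

    Returns (sft_data, preference_pairs):
      sft_data: (question, correct rationale) pseudo-labels for fine-tuning.
      preference_pairs: (question, chosen, rejected) triples for DPO.
    """
    sft_data, preference_pairs = [], []
    for question, gold in questions:
        samples = sample(question, num_samples)
        # dict.fromkeys drops duplicate samples while keeping their order.
        correct = list(dict.fromkeys(s for s in samples if extract_answer(s) == gold))
        wrong = list(dict.fromkeys(s for s in samples if extract_answer(s) != gold))
        sft_data.extend((question, r) for r in correct)
        # Pair correct and incorrect rationales one-to-one; unmatched ones are dropped.
        preference_pairs.extend((question, c, w) for c, w in zip(correct, wrong))
    return sft_data, preference_pairs


def toy_sampler(question: str, n: int) -> list[str]:
    """Hypothetical stand-in for sampling n rationales from the current model."""
    outputs = [
        "16 - 3 - 4 = 9 eggs left, and 9 * 2 = 18 dollars.\n#### 18",
        "16 - 3 = 13 eggs left, and 13 * 2 = 26 dollars.\n#### 26",
        "16 - 3 - 4 = 9 eggs left, and 9 * 2 = 18 dollars.\n#### 18",
    ]
    return outputs[:n]


sft, pairs = build_training_data(
    [("How much does Janet make per day?", "18")], toy_sampler, num_samples=3
)
print(len(sft), len(pairs))  # 1 1
```

The one-to-one `zip` pairing and the deduplication step are simplifying choices made for this sketch; the paper's actual sampling budget and pairing strategy may differ.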

Abstract

Effective training of language models (LMs) for mathematical reasoning tasks demands high-quality supervised fine-tuning data. Besides obtaining annotations from human experts, a common alternative is sampling from larger and more powerful LMs. However, this knowledge distillation approach can be costly and unstable, particularly when relying on closed-source, proprietary LMs like GPT-4, whose behaviors are often unpredictable. In this work, we demonstrate that the reasoning abilities of small-scale LMs can be enhanced through self-training, a process in which models learn from their own outputs. We also show that conventional self-training can be further augmented by a preference learning algorithm called Direct Preference Optimization (DPO). By integrating DPO into self-training, we leverage preference data to guide LMs towards more accurate and diverse chain-of-thought reasoning. We evaluate our method across various mathematical reasoning tasks using different base models. Our experiments show that this approach not only improves LMs' reasoning performance but also offers a more cost-effective and scalable solution compared to relying on large proprietary LMs.
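
For readers unfamiliar with DPO, the standard per-pair objective from Rafailov et al. (2023) is loss = −log σ(β · [(log πθ(y_w|x) − log π_ref(y_w|x)) − (log πθ(y_l|x) − log π_ref(y_l|x))]), where y_w is the preferred response, y_l the dispreferred one, πθ the model being trained, and π_ref a frozen reference model. The sketch below computes it from summed sequence log-probabilities using only the standard library. The default β = 0.1 is a common choice assumed here, not a value taken from this paper, and the function names are hypothetical.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss(
    policy_chosen: float,
    policy_rejected: float,
    ref_chosen: float,
    ref_rejected: float,
    beta: float = 0.1,  # assumed default, not taken from the paper
) -> float:
    """DPO loss for one preference pair, from summed sequence log-probabilities.

    chosen = preferred (e.g. correct) rationale, rejected = dispreferred one.
    The reference model is typically the frozen SFT model the policy starts from.
    """
    chosen_logratio = policy_chosen - ref_chosen
    rejected_logratio = policy_rejected - ref_rejected
    return -log_sigmoid(beta * (chosen_logratio - rejected_logratio))


def batch_dpo_loss(pairs, beta: float = 0.1) -> float:
    """Mean DPO loss over (policy_chosen, policy_rejected, ref_chosen, ref_rejected) tuples."""
    return sum(dpo_loss(*p, beta=beta) for p in pairs) / len(pairs)


# When the policy equals the reference model, the loss is log 2.
print(round(dpo_loss(-12.0, -15.0, -12.0, -15.0), 4))  # 0.6931
```

The loss falls below log 2 once the policy raises the likelihood of the chosen rationale relative to the reference more than it raises that of the rejected one, which is the signal that steers the model towards the preferred reasoning chains.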

Paper Structure

This paper contains 31 sections, 6 equations, 7 figures, 5 tables, and 2 algorithms.

Figures (7)

  • Figure 1: Our approach demonstrates superior performance on the GSM8K benchmark while minimizing the required compute cost, including both training and inference. Compute costs are calculated following the methodology of Yuan et al. (2023).
  • Figure 2: An illustration of the DPO-augmented self-training framework. The traditional self-training method uses the SFT model to generate pseudo-labels for subsequent iterations. In contrast, our method enhances the SFT model with Direct Preference Optimization (DPO) and uses the resulting DPO model to produce the pseudo-labels.
  • Figure 3: An example from the GSM8K dataset. The calculation annotations are highlighted in blue. All calculation steps are wrapped in the special tokens <<...>>. During decoding, the calculator is triggered whenever such a pattern appears, and the model's output tokens are overridden by the calculator's result. Following Cobbe et al. (2021), the calculation is performed with Python's built-in eval() function.
  • Figure 4: Inference speed of our method (with and without the calculator) compared with Calcformer (Kadlčík et al., 2023) across varying batch sizes. Results are measured on a single NVIDIA A40 GPU.
  • Figure 5: The performance of the proposed method on GSM8K over three iterations. For "iter 0", we report the performance of the SFT baselines, which are obtained after the warm-up stage.
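
The calculator mechanism described for Figure 3 can be sketched in Python as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: annotations follow GSM8K's `<<expression=result>>` format, the model is abstracted as a hypothetical `next_token(text)` callable that returns its next piece of text (an empty string to stop), and the calculator fires as soon as `<<expression=` has been generated, writing the result and the closing `>>` itself so that the model continues from the calculator's value rather than its own. As in the caption, evaluation uses Python's `eval()`; builtins are disabled here as a precaution, but `eval()` should still not be run on untrusted input. The names `calculator` and `generate_with_calculator` are illustrative.

```python
import re
from typing import Callable, Optional

# An unfinished annotation at the very end of the text, e.g. "... <<48/2="
OPEN_CALC_RE = re.compile(r"<<([^<>=]*)=\Z")


def calculator(expr: str) -> Optional[str]:
    """Evaluate an arithmetic expression with eval() (builtins disabled); None on failure."""
    try:
        value = eval(expr, {"__builtins__": {}}, {})
    except Exception:
        return None
    # Render whole-number floats without ".0" (e.g. 48/2 -> "24").
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    return str(value)


def generate_with_calculator(next_token: Callable[[str], str], max_steps: int = 256) -> str:
    """Decode step by step, overriding the result of every '<<expr=' annotation."""
    text = ""
    for _ in range(max_steps):
        token = next_token(text)
        if not token:  # empty string: the model has stopped
            break
        text += token
        match = OPEN_CALC_RE.search(text)
        if match:
            result = calculator(match.group(1))
            if result is not None:
                # Calculator triggered: append its result and close the annotation,
                # so the model never generates the result tokens itself.
                text += result + ">>"
    return text


# A scripted stand-in for the model; a real model would condition on `text`.
script = iter(["Natalia sold 48/2 = <<", "48/2=", "24 clips in May.\n", "#### 24"])
print(generate_with_calculator(lambda text: next(script, "")))
```

If the expression cannot be evaluated, this sketch simply leaves the annotation to the model; how the original system handles such failures is not stated in the caption.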