PORT: Preference Optimization on Reasoning Traces

Salem Lahlou, Abdalgader Abubaker, Hakim Hacid

TL;DR

This work aims to improve the reasoning abilities of LLMs by applying offline preference optimization to chain-of-thought (CoT) reasoning traces. It introduces PORT, which first performs supervised fine-tuning on reasoning traces and then a Direct Preference Optimization (DPO) stage. The DPO stage relies on two schemes for generating rejected answers (digit corruption and weak LLM prompting) to shape the model's stepwise reasoning. Empirical results on GSM8K, AQuA-RAT, ARC-Challenge, and LastLetterConcat show that DPO with digit corruption yields substantial accuracy gains on GSM8K and AQuA-RAT, and that the approach transfers to non-mathematical reasoning tasks. The findings suggest that constructing high-quality reasoning-trace datasets can meaningfully boost general reasoning abilities; the gains hold across two base models (Falcon2-11B and Mistral-7B), with evidence of cross-dataset transfer.
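Digit corruption turns a correct reasoning trace into a rejected one by altering its numbers. The exact corruption rule is not given in this summary (Figure 2 compares several corruption schemes), so the Python sketch below assumes one simple variant: replace a few randomly chosen digits with different random digits and leave the rest of the text untouched. The function name `corrupt_digits` and its `max_corruptions` parameter are hypothetical.

```python
import random
import re


def corrupt_digits(trace: str, rng: random.Random, max_corruptions: int = 1) -> str:
    """Return a copy of `trace` with up to `max_corruptions` digits changed.

    Hypothetical scheme (not necessarily the paper's): pick distinct digit
    positions uniformly at random and replace each with a different digit.
    """
    positions = [m.start() for m in re.finditer(r"\d", trace)]
    if not positions:
        return trace  # no digits: nothing to corrupt, trace is returned unchanged
    chars = list(trace)
    for pos in rng.sample(positions, k=min(max_corruptions, len(positions))):
        original = chars[pos]
        # Choose a replacement that is guaranteed to differ from the original digit.
        chars[pos] = rng.choice([d for d in "0123456789" if d != original])
    return "".join(chars)
```

For example, `corrupt_digits("3 + 4 = 7. The answer is 7.", random.Random(0))` returns the same sentence with exactly one digit changed. Because only digits change, the rejected trace stays fluent and structurally identical to the chosen one, which concentrates the preference signal on numerical correctness. A trace with no digits comes back unchanged; such a pair carries no preference signal, so a pipeline would typically skip it.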

Abstract

Preference optimization methods have been successfully applied not only to align large language models (LLMs) with human values, but also to improve performance on specific natural language tasks such as summarization and stylistic continuation. This paper proposes applying preference optimization to Chain-of-Thought steps in order to improve the mathematical reasoning performance of language models. While the chosen answers are taken from datasets that include reasoning traces, we propose two complementary schemes for generating rejected answers: weak LLM prompting and digit corruption. Our approach increases accuracy on the GSM8K and AQuA-RAT mathematical reasoning benchmarks for Falcon2-11B and Mistral-7B. Additionally, the improved abilities transfer to non-mathematical tasks, including the ARC benchmark and symbolic reasoning challenges. For example, our method can yield relative accuracy increases of up to 8.47% on GSM8K and 18.73% on AQuA, without any extra annotations. This work suggests that the path towards better language reasoning abilities runs through investing resources in the creation of high-quality datasets of reasoning traces.
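For reference, the standard DPO objective for one pair scores the chosen trace $y_w$ against the rejected trace $y_l$ by how much the policy's log-probability ratio relative to the frozen reference (SFT) model favors $y_w$: $\mathcal{L} = -\log \sigma\big(\beta[(\log \pi_\theta(y_w|x) - \log \pi_{\mathrm{ref}}(y_w|x)) - (\log \pi_\theta(y_l|x) - \log \pi_{\mathrm{ref}}(y_l|x))]\big)$. The Python sketch below computes it from summed sequence log-probabilities, which are assumed to be precomputed; `beta=0.1` is a placeholder, not the value selected in the paper's hyperparameter search (Figure 3), and the function names are illustrative.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,  # placeholder value, not taken from the paper
) -> float:
    """DPO loss for a single (prompt, chosen, rejected) pair.

    Each argument is the summed token log-probability of a full trace under
    either the policy being trained or the frozen reference (SFT) model.
    """
    # How much more (in log space) the policy favors each trace than the reference does.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    # Scaled margin between the chosen and rejected log-ratios.
    margin = beta * (chosen_logratio - rejected_logratio)
    return -log_sigmoid(margin)
```

When the policy and reference assign identical log-probabilities, the margin is zero and the loss equals $\log 2$; the loss falls below that as the policy shifts probability toward the chosen trace, relative to the reference, and away from the rejected one.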

Paper Structure

This paper contains 36 sections, 3 equations, 5 figures, 6 tables.

Figures (5)

  • Figure 1: Illustration of the creation process of a preference dataset with two complementary approaches to generating rejected answers. The preference dataset is used to fine-tune a reference model with Direct Preference Optimization (DPO) or one of its variants, after a supervised fine-tuning (SFT) step.
  • Figure 2: Robustness analysis using Mistral-7B as the base model: GSM8K accuracy for different corruption schemes.
  • Figure 3: DPO hyperparameter search. The y-axis shows accuracy on the GSM8K test set.
  • Figure 4: Hyperparameter search for DPO variants. The y-axis shows accuracy on the GSM8K test set. The learning rate ($8 \times 10^{-6}$) and number of epochs ($1$) are the same as for DPO.
  • Figure 5: DPO with rejected answers generated by a weak LLM: comparison of different versions of Llama-7B. The y-axis shows accuracy on the GSM8K test set. The learning rate is $8 \times 10^{-6}$ and the number of epochs is $1$.
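Figure 1 describes a pipeline in which chosen answers come from a dataset of reasoning traces and rejected answers come from one of two generators. The Python sketch below assumes a minimal record layout using the `prompt`/`chosen`/`rejected` field names that preference-training tools such as Hugging Face TRL's DPO trainer commonly accept. `build_preference_pairs`, `to_jsonl`, and the `RejectionScheme` interface are hypothetical, and any weak-LLM call would be supplied by the caller as one of the schemes.

```python
import json
from typing import Callable, Dict, Iterable, List, Tuple

# Hypothetical interface: a rejection scheme maps (question, chosen_trace) to a
# candidate rejected trace, e.g. a digit-corruption function or a wrapper that
# prompts a weak LLM with the question.
RejectionScheme = Callable[[str, str], str]


def build_preference_pairs(
    examples: Iterable[Tuple[str, str]],
    schemes: List[RejectionScheme],
) -> List[Dict[str, str]]:
    """Pair each reference reasoning trace (chosen) with generated rejected traces.

    Emits one record per (example, scheme); a candidate identical to the chosen
    trace is dropped because it carries no preference signal.
    """
    pairs: List[Dict[str, str]] = []
    for question, chosen in examples:
        for make_rejected in schemes:
            rejected = make_rejected(question, chosen)
            if rejected.strip() == chosen.strip():
                continue  # e.g. a corruption that found no digits to change
            pairs.append({"prompt": question, "chosen": chosen, "rejected": rejected})
    return pairs


def to_jsonl(pairs: List[Dict[str, str]]) -> str:
    """Serialize the pairs as JSON Lines, one object per line."""
    return "\n".join(json.dumps(p) for p in pairs)
```

This summary does not state whether the paper mixes both schemes for each question, as this loop does, or trains on each scheme separately; passing one or both schemes in `schemes` expresses either choice.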