GPO: Learning from Critical Steps to Improve LLM Reasoning

Jiahao Yu, Zelei Cheng, Xian Wu, Xinyu Xing

TL;DR

Guided Pivotal Optimization (GPO) is a fine-tuning strategy for LLMs that concentrates learning on the critical steps within multi-step reasoning trajectories. It estimates the advantage of each step, identifies the pivotal step, resets the rollout to that step, and learns from new explorations that start there. The strategy is compatible with both online methods such as PPO and offline methods such as DPO. Theoretical results bound the regret for online learning and connect offline preference optimization to advantage-weighted RL, while experiments across seven datasets and multiple baselines show consistent performance gains and generalizability. A user study finds that the steps GPO identifies align with human judgments, supporting the claim that focusing on pivotal moments yields more reliable reasoning. The work also discusses scalability considerations and provides open-source code to promote reproducibility and further research.

Abstract

Large language models (LLMs) are increasingly used in various domains, showing impressive potential on different tasks. Recently, reasoning LLMs have been proposed to improve the reasoning or thinking capabilities of LLMs to solve complex problems. Despite the promising results of reasoning LLMs, enhancing the multi-step reasoning capabilities of LLMs remains a significant challenge. While existing optimization methods have advanced LLM reasoning capabilities, they often treat reasoning trajectories as a whole, without considering the underlying critical steps within the trajectory. In this paper, we introduce Guided Pivotal Optimization (GPO), a novel fine-tuning strategy that dives into the reasoning process to enable more effective improvements. GPO first identifies the 'critical step' within a reasoning trajectory: a point at which the model must proceed carefully to solve the problem. We locate the critical step by estimating the advantage function. GPO then resets the policy to the critical step, samples new rollouts, and prioritizes learning on those rollouts. This focus allows the model to learn more effectively from pivotal moments within the reasoning process and thereby improve its reasoning performance. We demonstrate that GPO is a general strategy that can be integrated with various optimization methods. Besides theoretical analysis, our experiments across challenging reasoning benchmarks show that GPO can consistently and significantly enhance the performance of existing optimization methods, showcasing its effectiveness and generalizability in improving LLM reasoning by concentrating on pivotal moments within the generation process.

Paper Structure

This paper contains 33 sections, 5 theorems, 35 equations, 3 figures, 4 tables, 2 algorithms.

Key Result

Theorem 5.2

Under the assumption labeled bounded_q, with probability $1-\delta$, we have the following regret bound:

Figures (3)

  • Figure 1: Overview of our method. Given an initial trajectory generated by the policy $\pi$ for a reasoning task, GPO segments the trajectory into steps. It then identifies the most critical step via Monte Carlo (MC) simulation and resets the policy to that step to generate a new trajectory. The new trajectory is then added to the dataset or online buffer.
  • Figure 2: Ablation study results on BBH and MATH. We compare the standard GPO method with Satori's strategy, which selects the critical step in the trajectory at random. Each bar represents the average performance over 3 runs, with error bars indicating the standard deviation.
  • Figure 3: Scaling behavior of GPO. Performance impact of varying the number of MC samples (left) and of applying GPO across different model sizes (right) on MATH using DPO/KTO.
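
The pipeline described in the Figure 1 caption can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the `Sampler` interface, the helper names, and the toy policy are all hypothetical. It assumes a reward that is given only at the end of the trajectory and no discounting, so the advantage of a step is approximated by the MC value of the prefix that includes the step minus the MC value of the prefix that excludes it. It also assumes that the critical step is the one with the largest absolute advantage and that the reset point is the state just before that step; the paper's exact criteria may differ.

```python
from typing import Callable, List, Sequence, Tuple

Step = str
# Hypothetical interface: given a prefix of steps, sample one completion with
# the current policy and return (completed_trajectory, terminal_reward).
Sampler = Callable[[Sequence[Step]], Tuple[List[Step], float]]


def mc_value(prefix: Sequence[Step], sample: Sampler, n_samples: int) -> float:
    """Monte Carlo estimate of the expected terminal reward from `prefix`."""
    return sum(sample(prefix)[1] for _ in range(n_samples)) / n_samples


def step_advantages(steps: Sequence[Step], sample: Sampler,
                    n_samples: int = 8) -> List[float]:
    """Advantage of step t, estimated as V(steps[:t+1]) - V(steps[:t])."""
    values = [mc_value(steps[:t], sample, n_samples)
              for t in range(len(steps) + 1)]
    return [values[t + 1] - values[t] for t in range(len(steps))]


def critical_step(advantages: Sequence[float]) -> int:
    """Assumed criterion: index of the largest absolute advantage."""
    return max(range(len(advantages)), key=lambda t: abs(advantages[t]))


def gpo_collect(steps: Sequence[Step], sample: Sampler,
                buffer: list, n_samples: int = 8) -> int:
    """Find the critical step, reset to the prefix before it, and store
    one fresh rollout from that prefix in the dataset / online buffer."""
    t_star = critical_step(step_advantages(steps, sample, n_samples))
    new_trajectory, reward = sample(steps[:t_star])
    buffer.append((new_trajectory, reward))
    return t_star


def toy_sample(prefix: Sequence[Step]) -> Tuple[List[Step], float]:
    # Toy stand-in for a policy: the completion succeeds (reward 1.0)
    # unless the prefix already contains the step "wrong".
    return list(prefix) + ["done"], 0.0 if "wrong" in prefix else 1.0


buffer: list = []
t_star = gpo_collect(["a", "wrong", "c"], toy_sample, buffer, n_samples=4)
```

In the toy run, the value estimate drops only when the step `"wrong"` enters the prefix, so that step is selected and the new rollout restarts from the prefix `["a"]`. With a real LLM policy, `sample` would be stochastic, and the number of MC samples trades estimation noise against compute, which is the quantity varied in the left panel of Figure 3.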

Theorems & Definitions (5)

  • Theorem 5.2
  • Theorem 5.3
  • Lemma 1: Performance difference lemma [chang2024dataset]
  • Theorem C.1
  • Lemma 2 [song2023hybrid]