
Reinforce LLM Reasoning through Multi-Agent Reflection

Yurun Yuan, Tengyang Xie

TL;DR

The work tackles the challenge of improving LLM reasoning by enabling dynamic, multi-turn verification and refinement through a cooperative multi-agent framework. It introduces DPSDP, a direct policy search method that casts the refinement process as an MDP and trains an actor-critic LLM system, with a generative critic, using direct preference learning. Theoretical guarantees establish performance bounds, while extensive experiments across Ministral, Llama-3.1, and Qwen2.5 on MATH 500, GSM8K, and Olympiad/MMLU-Pro Math demonstrate notable gains in both first-turn and multi-turn accuracy, including strong out-of-distribution generalization. Ablation studies highlight the importance of multi-agent collaboration, Markovian state design, and restart data collection. Overall, DPSDP offers a scalable and effective approach to robust test-time reasoning in LLMs, with practical relevance to complex problem-solving tasks.

Abstract

Leveraging more test-time computation has proven to be an effective way to boost the reasoning capabilities of large language models (LLMs). Among various methods, the verify-and-improve paradigm stands out for enabling dynamic solution exploration and feedback incorporation. However, existing approaches often suffer from restricted feedback spaces and a lack of coordinated training among the different parties, leading to suboptimal performance. To address this, we model the multi-turn refinement process as a Markov Decision Process and introduce DPSDP (Direct Policy Search by Dynamic Programming), a reinforcement learning algorithm that trains an actor-critic LLM system to iteratively refine answers via direct preference learning on self-generated data. Theoretically, DPSDP can match the performance of any policy within the training distribution. Empirically, we instantiate DPSDP with various base models and show improvements on both in- and out-of-distribution benchmarks. For example, on the MATH 500 benchmark, majority voting over five refinement steps increases first-turn accuracy from 58.2% to 63.2% with Ministral-based models. An ablation study further confirms the benefits of multi-agent collaboration and out-of-distribution generalization.
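
The preference-learning step (self-generated candidates turned into pairwise preferences for DPO, as in Figure 2) can be sketched as below. This is a minimal illustration, not the authors' training code: it assumes each candidate's Q-value has already been estimated (for example, as the fraction of correct final answers among rollouts continued from it; the paper's estimator may differ). `Candidate`, `build_preference_pairs`, `margin`, and `softplus` are hypothetical names, and `dpo_loss` is the standard per-pair DPO objective.

```python
import math
from dataclasses import dataclass
from itertools import combinations
from typing import List, Tuple


@dataclass
class Candidate:
    """A sampled response (an answer or a piece of feedback) at one state. Hypothetical type."""
    text: str
    q_value: float  # assumed precomputed, e.g. mean correctness over rollouts


def build_preference_pairs(
    candidates: List[Candidate], margin: float = 0.0
) -> List[Tuple[Candidate, Candidate]]:
    """Return (chosen, rejected) pairs whose estimated Q-values differ by more than `margin`."""
    pairs = []
    for a, b in combinations(candidates, 2):
        if a.q_value - b.q_value > margin:
            pairs.append((a, b))
        elif b.q_value - a.q_value > margin:
            pairs.append((b, a))
        # ties (within the margin) carry no preference signal and are skipped
    return pairs


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


def dpo_loss(
    logp_chosen: float,
    logp_rejected: float,
    ref_logp_chosen: float,
    ref_logp_rejected: float,
    beta: float = 0.1,
) -> float:
    """Standard per-pair DPO loss: -log sigmoid(beta * (chosen log-ratio - rejected log-ratio))."""
    log_ratio_gap = (logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected)
    # -log(sigmoid(z)) equals softplus(-z)
    return softplus(-beta * log_ratio_gap)
```

For three candidates with Q-values 0.8, 0.2, and 0.8, the sketch yields the pairs (first, second) and (third, second); the tie between the first and third carries no preference and is dropped. When the policy equals the reference policy, every pair's loss is $\log 2$.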

Paper Structure

This paper contains 78 sections, 3 theorems, 44 equations, 3 figures, 6 tables, 3 algorithms.

Key Result

Theorem 1

Under the coverage and DPO-loss assumptions, if we choose $\beta = O\left(\frac{\sqrt{C_{\mathcal{S}}^\star C_{\mathcal{A}} \varepsilon_\mathsf{stat}}}{\log C_{\mathcal{A}}}\right)$, then the DPSDP policy $\widehat{\pi}$ satisfies $\mathcal{J}(\pi^\star)-\mathcal{J}(\widehat{\pi}) = O\left(H \sqrt{C_{\cdots} \cdots}\right)$ …

Figures (3)

  • Figure 1: Inference time. Given a problem $\boldsymbol{x}$, the actor $\pi_a$ generates an initial response $a_0$. The critic $\pi_c$ then provides feedback $a_1$, identifying potential errors in $a_0$. The actor iteratively refines its response based on the feedback, continuing this process for $L$ rounds. Finally, majority voting is applied to all generated answers to determine the final response $\tilde{a}$.
  • Figure 2: Model training. DPSDP first samples a complete trajectory $\tau = (\boldsymbol{x}, a_0, a_1, a_2)$ from the reference policy $\pi_\mathsf{ref}$. At each state along this trajectory, it generates $n$ responses to explore possible answers and feedback. The $Q$-values of these $n$ candidate responses are then estimated (see the paper's Q-value estimation paragraph), and a pairwise preference dataset is extracted for subsequent DPO training of both the actor and the critic.
  • Figure 3: Various metrics under different turns. Accuracies improve as the number of refinements increases. The rising pass1@turn-$k$ scores indicate that iterative refinement enables models to solve previously unsolved problems. The dip in maj1@t2 accuracy occurs because, with only two responses, a problem counts toward maj1@t2 only if both (2 out of 2) are correct.
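
The inference procedure of Figure 1 can be sketched as a simple loop. This is a minimal sketch under stated assumptions: `actor` and `critic` are hypothetical plain callables standing in for the two LLMs, each mapping the problem and the conversation so far to text; `num_rounds` plays the role of $L$ and is assumed to count critic-then-actor rounds; and `extract_answer` is a hypothetical parser for the final answer. For simplicity the sketch passes the full history to both models; given the paper's ablation on Markovian state design, the actual state each model sees may be restricted to the latest answer and feedback.

```python
from collections import Counter
from typing import Callable, List


def refine_and_vote(
    problem: str,
    actor: Callable[[str, List[str]], str],
    critic: Callable[[str, List[str]], str],
    num_rounds: int,
    extract_answer: Callable[[str], str] = lambda s: s.strip(),
) -> str:
    """Actor-critic refinement followed by majority voting (hypothetical sketch of Figure 1)."""
    history: List[str] = []
    answers: List[str] = []

    response = actor(problem, history)  # a_0: initial answer
    history.append(response)
    answers.append(extract_answer(response))

    for _ in range(num_rounds):
        feedback = critic(problem, history)  # critic points out potential errors
        history.append(feedback)
        response = actor(problem, history)  # actor revises using the feedback
        history.append(response)
        answers.append(extract_answer(response))

    # Majority vote over every answer the actor produced; Counter.most_common
    # orders equal counts by first occurrence, so ties go to the earliest answer.
    return Counter(answers).most_common(1)[0][0]
```

With a toy actor that answers "41" first and "42" after any feedback, two rounds produce the answers "41", "42", "42", and the vote returns "42". With one round the vote is tied and the earlier "41" wins, which shows why a larger $L$ helps majority voting recover from a wrong first turn.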

Theorems & Definitions (3)

  • Theorem 1
  • Lemma 2: Performance difference lemma
  • Lemma 3
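
For reference, a standard finite-horizon form of the performance difference lemma (Kakade and Langford, 2002) is shown below, assuming an $H$-step episodic MDP with a fixed initial-state distribution; the paper's Lemma 2 may use different notation or indexing. Here $d_h^{\pi'}$ is the distribution of the state at step $h$ under $\pi'$, and $Q_h^{\pi}$, $V_h^{\pi}$ are the step-$h$ value functions of $\pi$:

$$
\mathcal{J}(\pi') - \mathcal{J}(\pi) = \sum_{h=1}^{H} \mathbb{E}_{s_h \sim d_h^{\pi'}}\Big[\mathbb{E}_{a \sim \pi'(\cdot \mid s_h)}\big[Q_h^{\pi}(s_h, a)\big] - V_h^{\pi}(s_h)\Big]
$$

Taking $\pi' = \pi^\star$ and $\pi = \widehat{\pi}$ writes the suboptimality gap in Theorem 1 as a sum of $H$ per-step advantage terms evaluated on the states visited by $\pi^\star$.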