SMART: Self-learning Meta-strategy Agent for Reasoning Tasks

Rongxing Liu, Kumar Shridhar, Manish Prajapat, Patrick Xia, Mrinmaya Sachan

TL;DR

SMART (Self-learning Meta-strategy Agent for Reasoning Tasks) is a novel framework that enables LMs to autonomously learn and select the most effective strategy for a given reasoning task. It also reduces the computational cost of refinement-based strategies, paving the way for more efficient and intelligent reasoning in LMs.

Abstract

Tasks requiring deductive reasoning, especially those involving multiple steps, often demand adaptive strategies such as intermediate generation of rationales or programs, as no single approach is universally optimal. While Language Models (LMs) can enhance their outputs through iterative self-refinement and strategy adjustments, they frequently fail to apply the most effective strategy in their first attempt. This inefficiency raises the question: Can LMs learn to select the optimal strategy in the first attempt, without a need for refinement? To address this challenge, we introduce SMART (Self-learning Meta-strategy Agent for Reasoning Tasks), a novel framework that enables LMs to autonomously learn and select the most effective strategies for various reasoning tasks. We model the strategy selection process as a Markov Decision Process and leverage reinforcement learning-driven continuous self-improvement to allow the model to find the suitable strategy to solve a given task. Unlike traditional self-refinement methods that rely on multiple inference passes or external feedback, SMART allows an LM to internalize the outcomes of its own reasoning processes and adjust its strategy accordingly, aiming for correct solutions on the first attempt. Our experiments across various reasoning datasets and with different model architectures demonstrate that SMART significantly enhances the ability of models to choose optimal strategies without external guidance (+15 points on the GSM8K dataset). By achieving higher accuracy with a single inference pass, SMART not only improves performance but also reduces computational costs for refinement-based strategies, paving the way for more efficient and intelligent reasoning in LMs.

Paper Structure

This paper contains 10 sections, 7 equations, 7 figures, 3 tables, 1 algorithm.

Figures (7)

  • Figure 1: Our proposed methodology: In the first step (initial sampling), an agent (LM) chooses a strategy and solves the given task with it. If it is correct, the process ends successfully. If an incorrect strategy is chosen, the agent iteratively refines its strategy, taking previous strategies into account. The process stops when a correct strategy is chosen to solve a task, or when a stopping criterion such as the number of attempts is reached. All correct strategies are used to further refine the model, and the process is repeated. During testing, we sample once from LM$_t$ without refinement.
  • Figure 2: Change in strategy distribution over iterations for the Gemma 7B model on the GSM8K dataset.
  • Figure 3: Comparison of the effects of different starting data points for the Gemma 7B model. SMART is compared against two baselines: the pre-trained Gemma 7B model and the Gemma 7B model fine-tuned on Llama3 8B data.
  • Figure 4: Qualitative example demonstrating that the Gemma 7B model learnt to select the refined strategy in its initial sampling stage, removing the need for refinement.
  • Figure 5: 8-shot Chain of Thought demonstration.
  • ...and 2 more figures
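
The loop described in the Figure 1 caption (initial sampling, iterative refinement until success or an attempt limit, then an update from the correct strategies) can be sketched as a toy simulation. This is only an illustration under stated assumptions, not the authors' implementation: the "LM" is replaced by a table of sampling weights over strategies, the strategy names and the `solved_by` field are hypothetical, refinement is modelled as simply not repeating an already tried strategy, and "fine-tuning" is modelled as adding the count of successful uses to each strategy's weight.

```python
import random
from collections import Counter


def choose_strategy(policy, rng, exclude=()):
    """Sample one strategy in proportion to its weight, skipping excluded ones."""
    candidates = [s for s in policy if s not in exclude]
    weights = [policy[s] for s in candidates]
    return rng.choices(candidates, weights=weights, k=1)[0]


def solve_with_refinement(policy, task, rng, max_attempts=3):
    """Initial sampling followed by refinement.

    Returns the first strategy that solves the task, or None if the
    attempt limit (the stopping criterion) is reached first.
    Assumption: refinement never retries a previously chosen strategy.
    """
    tried = []
    for _ in range(min(max_attempts, len(policy))):
        strategy = choose_strategy(policy, rng, exclude=tried)
        tried.append(strategy)
        if strategy in task["solved_by"]:  # hypothetical correctness check
            return strategy
    return None


def smart_iteration(policy, tasks, rng, max_attempts=3):
    """One outer iteration: collect correct strategies, then update the policy.

    Assumption: the update is a stand-in for fine-tuning the LM on the
    collected (task, correct strategy) pairs.
    """
    correct = Counter()
    for task in tasks:
        strategy = solve_with_refinement(policy, task, rng, max_attempts)
        if strategy is not None:
            correct[strategy] += 1
    return {s: policy[s] + correct[s] for s in policy}


def first_attempt_accuracy(policy, tasks, rng):
    """Test-time behaviour: sample once, with no refinement."""
    hits = sum(choose_strategy(policy, rng) in t["solved_by"] for t in tasks)
    return hits / len(tasks)


if __name__ == "__main__":
    rng = random.Random(0)
    policy = {"strategy_a": 1.0, "strategy_b": 1.0, "strategy_c": 1.0}
    tasks = [{"solved_by": {"strategy_b"}} for _ in range(20)]
    for _ in range(3):
        policy = smart_iteration(policy, tasks, rng)
    print(policy)
    print(first_attempt_accuracy(policy, tasks, rng))
```

In this toy setting the weight of the strategy that succeeds during refinement grows with each outer iteration, so single-sample selection at test time increasingly favours it. A real system would condition the choice on the task text rather than use one global weight table.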