$\texttt{LM}^\texttt{2}$: A Simple Society of Language Models Solves Complex Reasoning

Gurusha Juneja, Subhabrata Dutta, Tanmoy Chakraborty

TL;DR

LM2 introduces a modular, three-model framework (decomposer, solver, verifier) to tackle complex multi-step reasoning with explicit coordination learned via PPO and LoRA-based fine-tuning. The decomposer generates concepts and subproblems informed by solver outputs and verifier feedback, the verifier supplies granular error classifications, and the solver (GPT-3.5 in experiments) produces solutions. Across MATH, MedQA, and JEEBench, LM2 achieves strong out-of-domain generalization and outperforms the best baselines by $8.1\%$ on MATH, $7.71\%$ on JEEBench, and $9.7\%$ on MedQA; ablations confirm the crucial roles of concepts, verification, and policy-based coordination. This approach offers a scalable path to robust reasoning by distributing cognitive tasks across specialized language models while controlling the reasoning trajectory through feedback-informed policy updates.

Abstract

Despite demonstrating emergent reasoning abilities, Large Language Models (LLMs) often lose track of complex, multi-step reasoning. Existing studies show that providing guidance via decomposing the original question into multiple subproblems elicits more robustness in LLM reasoning -- a decomposer generates the subproblems, and a solver solves each of these subproblems. However, these techniques fail to accommodate coordination between the decomposer and the solver modules (either in a single model or different specialized ones) -- the decomposer does not keep track of the ability of the solver to follow the decomposed reasoning. In this paper, we propose LM2 to address these challenges. LM2 modularizes the decomposition, solution, and verification into three different language models. The decomposer module identifies the key concepts necessary to solve the problem and generates step-by-step subquestions according to the reasoning requirement. The solver model generates solutions to the subproblems, which are then checked by the verifier module; depending upon the feedback from the verifier, the reasoning context is constructed using the subproblems and the solutions. These models are trained to coordinate using policy learning. Exhaustive experimentation suggests the superiority of LM2 over existing methods on in- and out-of-domain reasoning problems, outperforming the best baselines by $8.1\%$ on MATH, $7.71\%$ on JEEBench, and $9.7\%$ on MedQA problems (code available at https://github.com/LCS2-IIITD/Language_Model_Multiplex).

Paper Structure (22 sections, 5 equations, 3 figures, 2 tables)

This paper contains 22 sections, 5 equations, 3 figures, 2 tables.

Figures (3)

  • Figure 1: The inference procedure of LM2 on a question from the MATH dataset. A question (in blue) is provided to the Solver LM that produces an incorrect answer (in red). The question is then provided to the Decomposer LM that generates the concepts and step-by-step subquestions (in lilac). Each subquestion is answered by the Solver LM, and the sub-answer is verified by a Verifier LM. If the Verifier LM approves the sub-answer, that subquestion-subanswer pair is added to the context of reasoning steps; otherwise, a new subquestion is generated. The question, concepts, subquestions, and subanswers are provided in context to the Decomposer LM to generate the next subquestion. Finally, the question, concepts, subquestions, and subanswers are provided to the Solver LM to generate the final answer (in green).
  • Figure 2: Comparison of token generation cost. We depict the number of tokens generated by the solver model under each method to solve a given question, averaged over 50 questions from the JEEBench dataset.
  • Figure 3: Comparison of GPT-4, DaSLaM and LM2 on an example from the MATH dataset.
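
The inference loop described in the Figure 1 caption can be sketched in a few lines of Python. This is only an illustrative sketch, not the authors' implementation: the function `lm2_infer`, its callable parameters, the `max_attempts` budget, and the convention that the decomposer returns `None` when it has no further subquestion are all assumptions introduced here. The verifier is also reduced to a yes/no decision, whereas the paper's verifier supplies granular error classifications. The three language models are replaced by toy stand-in functions so the control flow can be run end to end.

```python
from typing import Callable, List, Optional, Tuple

Step = Tuple[str, str]  # an approved (subquestion, subanswer) pair


def lm2_infer(
    question: str,
    get_concepts: Callable[[str], List[str]],
    next_subquestion: Callable[[str, List[str], List[Step]], Optional[str]],
    solve: Callable[[str, List[str], List[Step], str], str],
    verify: Callable[[str, str], bool],
    max_attempts: int = 10,  # hypothetical budget; not specified in the source
) -> Tuple[str, List[Step]]:
    """Decompose, solve, and verify step by step, then produce a final answer."""
    concepts = get_concepts(question)  # decomposer: key concepts
    steps: List[Step] = []             # context of approved reasoning steps
    for _ in range(max_attempts):
        # Decomposer sees the question, concepts, and approved steps so far.
        subq = next_subquestion(question, concepts, steps)
        if subq is None:               # assumed stop signal
            break
        suba = solve(question, concepts, steps, subq)  # solver answers it
        if verify(subq, suba):         # verifier approves: extend the context
            steps.append((subq, suba))
        # On rejection the pair is discarded and a new subquestion is requested.
    # Finally the solver answers the original question with the full context.
    final_answer = solve(question, concepts, steps, question)
    return final_answer, steps


# --- Toy stand-ins for the three models (purely hypothetical) ---
QUESTION = "What is (2 + 3) * 4?"
PLAN = ["What is 2 + 3?", "What is 5 * 4?"]
TRUTH = {"What is 2 + 3?": "5", "What is 5 * 4?": "20"}
solver_calls = {"n": 0}


def toy_concepts(question: str) -> List[str]:
    return ["order of operations"]


def toy_next(question: str, concepts: List[str], steps: List[Step]) -> Optional[str]:
    # Re-asks the same subquestion until it has an approved answer.
    return PLAN[len(steps)] if len(steps) < len(PLAN) else None


def toy_solve(question: str, concepts: List[str], steps: List[Step], target: str) -> str:
    solver_calls["n"] += 1
    if target == question:             # final answer: reuse the last approved step
        return steps[-1][1] if steps else "unknown"
    if target == PLAN[0]:              # the first attempt is deliberately wrong
        return "6" if solver_calls["n"] == 1 else "5"
    return "20"


def toy_verify(subq: str, suba: str) -> bool:
    return TRUTH[subq] == suba


answer, approved_steps = lm2_infer(
    QUESTION, toy_concepts, toy_next, toy_solve, toy_verify
)
```

In this toy run the solver's first sub-answer is rejected by the verifier, so the decomposer is queried again before any pair enters the context; only approved pairs reach the final solver call.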