
Small Language Models Fine-tuned to Coordinate Larger Language Models improve Complex Reasoning

Gurusha Juneja, Subhabrata Dutta, Soumen Chakrabarti, Sunny Manchanda, Tanmoy Chakraborty

TL;DR

This work tackles the inefficiency of monolithic large language models for multi-step reasoning by decoupling problem decomposition from solution generation. It introduces DaSLaM, a framework where a 13B decomposer (finetuned with LoRA) generates subproblems for a solver LM treated as a black box, enabling solver-agnostic collaboration. Empirically, DaSLaM boosts performance on MATH, AQuA, and JEEBench, bringing GPT-3.5–level capabilities closer to GPT-4 with significantly fewer parameters and compute, and demonstrating robustness across different solver scales. Ablation studies show that a finetuned decomposer and feedback from the solver are crucial, and the modular approach outperforms larger decomposers driven by prompting alone. Overall, the paper demonstrates that heterogeneous, task-specific modules can achieve high reasoning performance with much greater efficiency and flexibility than monolithic, scale-driven models.

Abstract

Large Language Models (LLMs) prompted to generate chain-of-thought (CoT) exhibit impressive reasoning capabilities. Recent attempts at prompt decomposition toward solving complex, multi-step reasoning problems depend on the ability of the LLM to simultaneously decompose and solve the problem. A significant disadvantage is that foundational LLMs are typically not available for fine-tuning, making adaptation computationally prohibitive. We believe (and demonstrate) that problem decomposition and solution generation are distinct capabilities, better addressed in separate modules than by one monolithic LLM. We introduce DaSLaM, which uses a decomposition generator to decompose complex problems into subproblems that require fewer reasoning steps. These subproblems are answered by a solver. We use a relatively small (13B-parameter) LM as the decomposition generator, which we train using policy gradient optimization to interact with a solver LM (regarded as a black box) and guide it through subproblems, thereby rendering our method solver-agnostic. Evaluation on multiple reasoning datasets reveals that with our method, a 175-billion-parameter LM (text-davinci-003) can produce competitive or even better performance compared to its orders-of-magnitude larger successor, GPT-4. Additionally, we show that DaSLaM is not limited by the solver's capabilities as a function of scale: solver LMs of diverse sizes show significant performance improvements with our solver-agnostic decomposition technique. Exhaustive ablation studies evince the superiority of our modular fine-tuning technique over exorbitantly large decomposer LLMs based on prompting alone.
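The abstract states that the 13B decomposer is trained with policy gradient optimization while the solver stays a black box. The property worth seeing concretely is that a policy-gradient update needs only a scalar reward from the solver, never its gradients. Below is a minimal sketch of textbook REINFORCE on a toy categorical policy, assuming the decomposer merely picks one of a few fixed candidate decompositions and receives reward 1 when the (simulated) solver then answers correctly. The toy policy, the binary reward, the moving-average baseline, and the names `reinforce_step` and `train` are hypothetical simplifications; this summary does not specify the paper's actual reward or policy-gradient variant.

```python
import math
import random
from typing import Callable, List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # shift by the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_step(
    logits: List[float],
    reward_fn: Callable[[int], float],
    baseline: float,
    lr: float,
    rng: random.Random,
) -> float:
    """One REINFORCE update of a categorical policy (logits modified in place)."""
    probs = softmax(logits)
    action = rng.choices(range(len(logits)), weights=probs, k=1)[0]
    # The reward is just a number returned by a black box: no gradient is
    # taken through it, which is what lets the solver stay opaque.
    reward = reward_fn(action)
    advantage = reward - baseline
    for i, p in enumerate(probs):
        # d/d(logit_i) of log softmax(logits)[action] = 1[i == action] - p_i
        grad_log_prob = (1.0 if i == action else 0.0) - p
        logits[i] += lr * advantage * grad_log_prob  # gradient ascent
    return reward


def train(
    reward_fn: Callable[[int], float],
    n_actions: int,
    steps: int = 500,
    lr: float = 0.5,
    seed: int = 0,
) -> List[float]:
    """Train from uniform logits; return the final action probabilities."""
    rng = random.Random(seed)
    logits = [0.0] * n_actions
    baseline = 0.0
    for _ in range(steps):
        reward = reinforce_step(logits, reward_fn, baseline, lr, rng)
        # Exponential moving average of past rewards as a variance-reducing baseline.
        baseline = 0.9 * baseline + 0.1 * reward
    return softmax(logits)


# Hypothetical toy task: 3 candidate decompositions, and only candidate 2
# leads the (simulated) solver to a correct final answer.
final_probs = train(lambda a: 1.0 if a == 2 else 0.0, n_actions=3)
```

Because `reward_fn` is called only for its return value, swapping in a real solver (prompt it, then check its final answer) would leave the update rule unchanged. In the paper's setting the policy is a 13B LM that generates subproblems as text rather than making a three-way choice, so this illustrates the principle only.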

Paper Structure

This paper contains 22 sections, 10 equations, 4 figures, and 3 tables.

Figures (4)

  • Figure 1: Working example of DaSLaM on a mathematical reasoning question from the JEEBench dataset [jeebench]. In this example, the solver LM is text-davinci-003. First, the solver is prompted to answer the question (blue textbox) and fails to answer it correctly (red textbox). Next, a problem-decomposing LM generates subproblems (violet textboxes) conditioned on the original question and the solver's initial response. The solver then answers these subproblems iteratively, appending each subproblem and its answer to the prompt. Finally, the original problem is appended to the prompt, and the solver answers it correctly (green textbox).
  • Figure 2: An example case study on a problem from the MATH dataset. GPT-3.5 is used as the solver LM with three different prompting methods: standard CoT, Least-to-most, and DaSLaM. Only DaSLaM is able to guide the model to the correct answer.
  • Figure 3: A case study on LLaMA-13B.
  • Figure 4: A case study on LLaMA-33B as the solver model.
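The Figure 1 caption describes a four-stage interaction between the solver and the decomposer. A minimal sketch of that control flow follows, assuming both models are exposed as plain prompt-to-text callables. The names `decompose_and_solve`, `SolverFn`, and `DecomposerFn`, the newline-joined prompt format, and the choice to leave the solver's first attempt out of the final prompt are all hypothetical, not details taken from the paper.

```python
from typing import Callable, List

# Hypothetical interfaces: each model is treated as a black box that maps
# prompt text to completion text; nothing here depends on model internals.
SolverFn = Callable[[str], str]
DecomposerFn = Callable[[str, str], List[str]]


def decompose_and_solve(
    question: str, solver: SolverFn, decomposer: DecomposerFn
) -> str:
    """Run the four-stage loop from the Figure 1 caption (illustrative only)."""
    # Stage 1: the solver attempts the original question directly.
    initial_response = solver(question)

    # Stage 2: the decomposer conditions on the question AND the solver's
    # first attempt, and returns a list of subproblems.
    subproblems = decomposer(question, initial_response)

    # Stage 3: the solver answers the subproblems one at a time. Each
    # subproblem and its answer are appended to a running prompt, so later
    # subproblems see earlier answers. (Prompt format is an assumption.)
    prompt = ""
    for sub in subproblems:
        answer = solver(prompt + sub)
        prompt += f"{sub}\n{answer}\n"

    # Stage 4: the original question is appended last, and the solver's
    # answer to this augmented prompt is the final output.
    return solver(prompt + question)
```

Since `solver` is only ever called as a function from text to text, any LM could in principle be plugged in unchanged, which mirrors the solver-agnostic framing in the abstract. The caption does not say whether the solver's first (incorrect) attempt stays in the final prompt; this sketch omits it.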