Divide-or-Conquer? Which Part Should You Distill Your LLM?

Zhuofeng Wu, He Bai, Aonan Zhang, Jiatao Gu, VG Vinod Vydiswaran, Navdeep Jaitly, Yizhe Zhang

TL;DR

This paper devises a strategy that breaks reasoning tasks into a problem decomposition phase and a problem solving phase, and shows that this strategy outperforms a single-stage solution.

Abstract

Recent methods have demonstrated that Large Language Models (LLMs) can solve reasoning tasks better when they are encouraged to solve subtasks of the main task first. In this paper we devise a similar strategy that breaks down reasoning tasks into a problem decomposition phase and a problem solving phase and show that the strategy is able to outperform a single stage solution. Further, we hypothesize that the decomposition should be easier to distill into a smaller model compared to the problem solving because the latter requires large amounts of domain knowledge while the former only requires learning general problem solving strategies. We propose methods to distill these two capabilities and evaluate their impact on reasoning outcomes and inference cost. We find that we can distill the problem decomposition phase and at the same time achieve good generalization across tasks, datasets, and models. However, it is harder to distill the problem solving capability without losing performance and the resulting distilled model struggles with generalization. These results indicate that by using smaller, distilled problem decomposition models in combination with problem solving LLMs we can achieve reasoning with cost-efficient inference and local adaptation.
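
The two-phase strategy described in the abstract can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the `TextModel` type, the function names, and the prompt wording are all hypothetical, and each model is treated as a plain function from a prompt string to a completion string. The decomposer stands in for the smaller, distilled model and the solver for the larger problem solving LLM.

```python
from typing import Callable, List

# Hypothetical interface: any function mapping a prompt to a completion.
TextModel = Callable[[str], str]


def decompose(question: str, decomposer: TextModel) -> List[str]:
    """Phase 1: ask the decomposer for subquestions, one per line."""
    prompt = (
        "Break the question into subquestions, one per line.\n"
        f"Question: {question}"
    )
    raw = decomposer(prompt)
    # Keep non-empty lines only; surrounding whitespace is stripped.
    return [line.strip() for line in raw.splitlines() if line.strip()]


def solve(question: str, subquestions: List[str], solver: TextModel) -> str:
    """Phase 2: hand the question plus numbered subquestions to the solver."""
    numbered = "\n".join(f"{i}. {q}" for i, q in enumerate(subquestions, 1))
    prompt = (
        f"Question: {question}\n"
        "Answer these subquestions in order, then give the final answer.\n"
        f"{numbered}"
    )
    return solver(prompt)


def two_stage_answer(
    question: str, decomposer: TextModel, solver: TextModel
) -> str:
    """Run decomposition first, then solving, with two separate models."""
    return solve(question, decompose(question, decomposer), solver)
```

Keeping the two phases behind separate callables is what allows only the decomposer to be swapped for a cheaper student model while the solver stays unchanged.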

Paper Structure (36 sections, 2 figures, 10 tables)

Figures (2)

  • Figure 1: Reasoning with a long thought chain using the black box LLM can be expensive and inflexible. We propose to dissect the decomposition and solving of the task, and distill only the decomposition capability to a less costly and more flexible student model, while still maintaining the original performance.
  • Figure 2: Solver models get lost sometimes.