
PaD: Program-aided Distillation Can Teach Small Models Reasoning Better than Chain-of-thought Fine-tuning

Xuekai Zhu, Biqing Qi, Kaiyan Zhang, Xinwei Long, Zhouhan Lin, Bowen Zhou

TL;DR

This work tackles the challenge of imparting robust reasoning to small models so that they can be deployed without relying on inaccessible or very large LLMs. It introduces PaD, a pipeline that distills reasoning by synthesizing executable reasoning programs from LLMs, automatically filtering faulty data, and strengthening learning through self-refinement and step-by-step verification. Empirical results show that PaD can match or surpass certain larger models on arithmetic and symbolic reasoning tasks with substantially fewer parameters and less data, although this specialization comes at some cost to general-purpose abilities. The approach offers a practical route to deploying reasoning-capable systems in resource-constrained settings and motivates further work on generalization and broader reasoning capabilities.
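The automatic filtering step can be illustrated with a minimal Python sketch. It assumes, hypothetically, that each synthesized reasoning program defines a zero-argument function `solution()` returning the final answer and that every sample carries a gold numeric answer; the helper names `execute_program` and `filter_synthetic_data` are illustrative and not taken from the PaD codebase. A program is kept only if it runs and its result matches the gold answer.

```python
import math
from typing import List, Optional, Tuple


def execute_program(program: str) -> Optional[float]:
    """Run a synthesized reasoning program and return solution() as a float.

    Hypothetical convention: the program defines a zero-argument function
    named `solution`. Returns None if the program fails to compile or run,
    lacks `solution`, or returns a non-numeric value.
    """
    namespace: dict = {}
    try:
        exec(program, namespace)  # untrusted code: sandbox it in a real pipeline
        return float(namespace["solution"]())
    except Exception:
        return None


def filter_synthetic_data(
    samples: List[Tuple[str, str, float]], tol: float = 1e-4
) -> List[Tuple[str, str]]:
    """Keep (question, program) pairs whose executed answer matches the gold answer."""
    kept = []
    for question, program, gold in samples:
        predicted = execute_program(program)
        if predicted is not None and math.isclose(predicted, gold, rel_tol=tol, abs_tol=tol):
            kept.append((question, program))
    return kept


question = "Tom has 3 apples and buys 4 more. How many apples does he have?"
samples = [
    (question, "def solution():\n    apples = 3\n    bought = 4\n    return apples + bought\n", 7),
    (question, "def solution():\n    return 3 * 4\n", 7),  # runs, but the answer is wrong
    (question, "def solution(:\n    return 7\n", 7),  # syntax error
]
print(len(filter_synthetic_data(samples)))  # 1
```

Running model-generated code with `exec` is acceptable only in a sketch; a real pipeline would isolate execution (a separate process, timeouts, restricted builtins).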

Abstract

While large language models (LLMs) excel at a wide range of natural language processing tasks, their huge size and the inaccessibility of their parameters make practical deployment challenging. Previous studies have tried to distill task-specific abilities from LLMs into smaller models through data synthesis and chain-of-thought (CoT) fine-tuning. However, synthetic CoT data often contains faulty reasoning, which degrades the quality of distillation, especially for reasoning capabilities. In this work, we propose Program-aided Distillation (PaD), which introduces reasoning programs to suppress errors in the distilled data and thus achieves better distillation quality for reasoning tasks. In PaD, reasoning programs replace CoT, which allows synthetic data to be checked for errors automatically. Furthermore, by injecting errors and continuing training, the small distilled model can iteratively self-refine its reasoning. In addition, we perform a step-wise beam search with step-by-step verification to obtain more accurate reasoning chains. We evaluate PaD on arithmetic reasoning, symbolic reasoning, and general ability. Experimental results demonstrate that smaller models trained with PaD not only outperform certain LLMs (e.g., LLaMA-1 13B) but also improve substantially over baselines with a significantly smaller scale of parameters and data. The source code is publicly available at https://github.com/Xuekai-Zhu/pad.
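Self-refinement through error injection could be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' procedure: the error type (swapping a single arithmetic operator), the `Question / Incorrect program / Corrected program` input format, the `solution()` convention, and all function names are hypothetical. The sketch corrupts a verified-correct program, confirms with the interpreter that the corruption actually changes the answer, and pairs the faulty program with its correct version as a training target for the small model.

```python
import math
import random
import re
from typing import Dict, Optional

# Hypothetical error type for this sketch: swap one arithmetic operator.
OPERATOR_SWAPS = {"+": "-", "-": "+", "*": "/", "/": "*"}


def execute_program(program: str) -> Optional[float]:
    """Return solution() as a float, or None if the program fails to run."""
    namespace: dict = {}
    try:
        exec(program, namespace)  # untrusted code: sandbox it in a real pipeline
        return float(namespace["solution"]())
    except Exception:
        return None


def inject_operator_error(program: str, rng: random.Random) -> Optional[str]:
    """Swap one space-delimited binary operator; return None if there is none."""
    positions = [m.start() for m in re.finditer(r"(?<=\s)[+\-*/](?=\s)", program)]
    if not positions:
        return None
    pos = rng.choice(positions)
    return program[:pos] + OPERATOR_SWAPS[program[pos]] + program[pos + 1 :]


def build_refinement_example(
    question: str, program: str, gold: float, rng: random.Random
) -> Optional[Dict[str, str]]:
    """Pair an error-injected program with its correct version as a training example."""
    faulty = inject_operator_error(program, rng)
    if faulty is None:
        return None
    result = execute_program(faulty)
    # Discard injections that leave the answer unchanged (e.g. `x * 1` -> `x / 1`).
    if result is not None and math.isclose(result, gold, rel_tol=1e-4, abs_tol=1e-4):
        return None
    prompt = f"Question: {question}\nIncorrect program:\n{faulty}\nCorrected program:\n"
    return {"input": prompt, "target": program}


rng = random.Random(0)
program = "def solution():\n    apples = 3\n    bought = 4\n    return apples + bought\n"
example = build_refinement_example(
    "Tom has 3 apples and buys 4 more. How many apples does he have?", program, 7, rng
)
print(example["input"])  # the faulty program ends with `return apples - bought`
```

Checking the corrupted program with the interpreter before keeping it mirrors the filtering idea: only corruptions that verifiably break the answer become refinement targets.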

Paper Structure

This paper contains 37 sections, 5 equations, 14 figures, 16 tables.

Figures (14)

  • Figure 1: Comparing CoT with program-aided reasoning. CoT generated by LLMs can contain faulty reasoning even when the final answer is correct. Program-aided reasoning allows intermediate steps to be checked easily with an additional Python interpreter and reaches the correct answer.
  • Figure 2: A comparison of pre-trained large models and small models on GSM8K (Cobbe et al., 2021), a benchmark of math word problems. Small models trained with PaD can surpass some larger models (e.g., LLaMA-1 13B), reaching nearly 50% of GPT-4's performance.
  • Figure 3: Overview of program-aided distillation. I. Synthesizing data from LLMs: we provide in-context examples and a question to the LLM, which prompts it to generate reasoning programs; an additional Python interpreter then automatically filters the data. II. Fine-tuning small models: we fine-tune smaller models on the synthetic data. III. Self-refinement: incorrect reasoning programs are fed back through the smaller models for iterative refinement. IV. Step-by-step verification: we adopt step-wise beam search to generate more faithful intermediate steps, where $r_i^t$ denotes intermediate steps at time step $t$.
  • Figure 4: Ablation results on arithmetic reasoning and general-ability tasks. Compared with the base model and fine-tuning, PaD achieves a significant improvement in mathematical reasoning. Adding self-refinement and step-by-step verification brings further gains. As the model specializes in mathematical reasoning, its general capabilities tend to decline.
  • Figure 5: Training vanilla Transformer models on PaD and CoT datasets of the same size (7K), with early stopping. PaD yields consistently lower training and evaluation losses than CoT fine-tuning across small, base, and large models.
  • ...and 9 more figures
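Stage IV of Figure 3 (step-wise beam search over intermediate steps $r_i^t$) can be sketched as a generic beam search over reasoning steps. In this minimal sketch, `propose` stands in for sampling candidate next steps from the student model and `score_step` stands in for the step-level verifier; both, together with the additive chain score and the toy score table, are assumptions for illustration rather than PaD's actual scoring. The toy problem is "3 apples plus 4 more": the most confident first step uses the wrong operation, and the verifier penalizes the chain it leads to, so greedy decoding (`beam_width=1`) ends on the worse chain while a beam of 2 recovers.

```python
import heapq
from typing import Callable, List, Sequence, Tuple

Chain = Tuple[str, ...]  # a partial reasoning chain: one string per step


def stepwise_beam_search(
    propose: Callable[[Chain], Sequence[str]],
    score_step: Callable[[Chain, str], float],
    is_finished: Callable[[Chain], bool],
    beam_width: int = 2,
    max_steps: int = 8,
) -> Chain:
    """Expand reasoning chains one step at a time, keeping the top `beam_width`.

    A chain's score is the sum of its step scores (an assumption for this sketch).
    """
    beams: List[Tuple[float, Chain]] = [(0.0, ())]
    for _ in range(max_steps):
        candidates: List[Tuple[float, Chain]] = []
        for total, chain in beams:
            if is_finished(chain):
                candidates.append((total, chain))  # finished chains compete unchanged
                continue
            for step in propose(chain):
                candidates.append((total + score_step(chain, step), chain + (step,)))
        if not candidates:
            break
        beams = heapq.nlargest(beam_width, candidates, key=lambda c: c[0])
        if all(is_finished(chain) for _, chain in beams):
            break
    return max(beams, key=lambda c: c[0])[1]


# Toy stand-ins for the student model and verifier (hypothetical scores).
TOY_SCORES = {
    ((), "total = 3 * 4"): -0.1,  # most confident first step, but the wrong operation
    ((), "total = 3 + 4"): -0.5,
    (("total = 3 * 4",), "return total"): -2.0,  # verifier penalizes this chain
    (("total = 3 + 4",), "return total"): -0.1,
}


def toy_propose(chain: Chain) -> List[str]:
    return [step for (prefix, step) in TOY_SCORES if prefix == chain]


def toy_score(chain: Chain, step: str) -> float:
    return TOY_SCORES[(chain, step)]


def toy_finished(chain: Chain) -> bool:
    return bool(chain) and chain[-1] == "return total"


print(stepwise_beam_search(toy_propose, toy_score, toy_finished, beam_width=2))
# ('total = 3 + 4', 'return total')
print(stepwise_beam_search(toy_propose, toy_score, toy_finished, beam_width=1))
# ('total = 3 * 4', 'return total')
```

Scoring each step as it is added, rather than only the finished chain, is what lets the search discard an unpromising prefix early while still keeping a lower-ranked alternative alive in the beam.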