Beyond Imitation: Learning Key Reasoning Steps from Dual Chain-of-Thoughts in Reasoning Distillation

Chengwei Dai, Kun Li, Wei Zhou, Songlin Hu

TL;DR

This work targets distilling chain-of-thought (CoT) reasoning from LLMs into smaller language models (SLMs) by focusing on key reasoning steps rather than surface-level reasoning patterns. It introduces EDIT, which builds dual CoTs data and uses minimum edit distance to identify and weight the crucial steps, coupled with a mistake-driven key reasoning step learning (KRSL) objective. Empirical results on in-domain and out-of-domain benchmarks show that EDIT consistently outperforms standard CoT distillation and other baselines, with robust gains across model sizes and architectures. Analyses reveal that learning from logical mistakes in dual CoTs yields particularly strong benefits and that EDIT produces higher-quality CoTs with more correct key reasoning steps, supporting the practical value of learning key reasoning steps for scalable reasoning in SLMs.
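
To make the locating step concrete, here is a minimal Python sketch of a token-level minimum edit distance (Levenshtein) alignment between a correct CoT and its dual wrong CoT. It assumes whitespace tokenization and treats every token that one optimal alignment leaves unmatched (substituted, deleted, or inserted) as part of a key reasoning step. The function name `locate_key_tokens` is hypothetical, and the authors' actual granularity, tie-breaking, and any post-processing may differ.

```python
from typing import List, Set, Tuple


def locate_key_tokens(correct: List[str], wrong: List[str]) -> Tuple[Set[int], Set[int]]:
    """Hypothetical helper: indices of tokens in `correct` and `wrong` that are
    not matched by one minimum-edit-distance alignment."""
    n, m = len(correct), len(wrong)
    # dp[i][j] = edit distance between correct[:i] and wrong[:j]
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = i
    for j in range(m + 1):
        dp[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = 0 if correct[i - 1] == wrong[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,         # token only in the correct CoT
                dp[i][j - 1] + 1,         # token only in the wrong CoT
                dp[i - 1][j - 1] + cost,  # match or substitution
            )

    # Backtrace one optimal alignment, preferring match/substitution on ties.
    key_correct: Set[int] = set()
    key_wrong: Set[int] = set()
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0:
            cost = 0 if correct[i - 1] == wrong[j - 1] else 1
            if dp[i][j] == dp[i - 1][j - 1] + cost:
                if cost:  # substituted tokens belong to a key step in both CoTs
                    key_correct.add(i - 1)
                    key_wrong.add(j - 1)
                i, j = i - 1, j - 1
                continue
        if i > 0 and dp[i][j] == dp[i - 1][j] + 1:
            key_correct.add(i - 1)  # deleted from the correct CoT
            i -= 1
        else:
            key_wrong.add(j - 1)    # inserted in the wrong CoT
            j -= 1
    return key_correct, key_wrong
```

For example, aligning `"x = 2 so y = 4".split()` with `"x = 3 so y = 6".split()` marks positions 2 and 6 in both CoTs, which are exactly the tokens where the dual CoTs diverge. Preferring substitutions when several alignments are equally cheap is an arbitrary choice in this sketch.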

Abstract

As Large Language Models (LLMs) scale up and gain powerful Chain-of-Thoughts (CoTs) reasoning abilities, practical resource constraints drive efforts to distill these capabilities into more compact Smaller Language Models (SLMs). We find that CoTs consist mainly of simple reasoning forms, with only a small proportion (≈4.7%) of key reasoning steps that truly impact conclusions. However, previous distillation methods typically fine-tune student SLMs in a supervised manner only on correct CoTs data produced by teacher LLMs. As a result, students struggle to learn the key reasoning steps; instead, they imitate the teacher's reasoning forms and make errors or omissions on these steps. To address these issues, we draw an analogy to human learning, where analyzing mistakes against correct solutions often reveals the crucial steps that lead to success or failure, and propose mistakE-Driven key reasonIng step distillaTion (EDIT), a novel method that helps SLMs learn key reasoning steps rather than relying on simple fine-tuning alone. First, to expose these crucial steps in CoTs, we design specific prompts to generate dual CoTs data with similar reasoning paths but divergent conclusions. Then, we apply the minimum edit distance algorithm to the dual CoTs data to locate these key steps and optimize the likelihood of these steps. Extensive experiments validate the effectiveness of EDIT on both in-domain and out-of-domain benchmark reasoning datasets. Further analysis shows that EDIT can generate high-quality CoTs with more correct key reasoning steps. Notably, we also explore how different mistake patterns affect performance and find that EDIT benefits more from logical errors than from knowledge or mathematical calculation errors in dual CoTs. (Code can be found at https://github.com/C-W-D/EDIT.)
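
The abstract says EDIT optimizes the likelihood of the located key steps but does not spell out the objective. The sketch below is one plausible instantiation under our own assumptions, not the paper's KRSL formula: a masked negative log-likelihood that raises the student's probability on key tokens of the correct CoT, plus an unlikelihood-style penalty, -log(1 - p), that lowers its probability on key tokens of the wrong CoT. The function name `key_step_loss`, the weight `alpha`, and the per-token log-probability inputs are all hypothetical.

```python
import math
from typing import Sequence


def key_step_loss(
    logp_correct: Sequence[float],
    mask_correct: Sequence[int],
    logp_wrong: Sequence[float],
    mask_wrong: Sequence[int],
    alpha: float = 1.0,
    eps: float = 1e-6,
) -> float:
    """Hypothetical key-step objective; NOT the paper's exact KRSL formula.

    logp_*: student log-probabilities of each token in the correct / wrong CoT.
    mask_*: 1 marks a key token (e.g. from an edit-distance alignment), 0 otherwise.
    """
    # Likelihood term: average NLL over key tokens of the correct CoT.
    n_c = max(sum(mask_correct), 1)
    nll = -sum(lp for lp, m in zip(logp_correct, mask_correct) if m) / n_c

    # Unlikelihood term (assumed): average -log(1 - p) over key tokens of the
    # wrong CoT, which grows as the student puts more mass on mistaken tokens.
    n_w = max(sum(mask_wrong), 1)
    ul = -sum(
        math.log(max(1.0 - math.exp(lp), eps))
        for lp, m in zip(logp_wrong, mask_wrong) if m
    ) / n_w

    return nll + alpha * ul
```

In a real training loop the log-probabilities would come from the student model, and a term like this could be combined with the ordinary fine-tuning loss; plain floats are used here only to keep the sketch self-contained. Non-key tokens (mask 0) contribute nothing, so the gradient signal concentrates on the divergent steps.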

Paper Structure (35 sections, 7 equations, 6 figures, 25 tables)

Figures (6)

  • Figure 1: Examples of CoTs generated by teacher LLMs and student SLMs on our test dataset. Simple SFT produces an "unthinking" student that imitates the teacher's reasoning forms but makes errors and omissions in key reasoning steps. Imitated content is highlighted in red, and key steps are marked with boxes.
  • Figure 2: Overview of our mistake-driven key reasoning step distillation. (1) We first retain all CoTs data annotated by teacher LLMs, (2) then ask the teacher LLMs to generate dual CoTs data using two comprehensive prompts we designed. (3) Next, we fine-tune student SLMs on both the original correct CoTs and the rectified CoTs data. Finally, we apply key reasoning step learning to the pre-tuned student SLMs by identifying the minor differences between the dual CoTs.
  • Figure 3: Examples of locating key reasoning steps in dual CoTs, where the correct CoT and the wrong CoT are dual to each other. The key steps identified in the correct and wrong reasoning are marked in green and red, respectively.
  • Figure 4: Ablation results on model size for four OOD datasets. The dotted line indicates the performance of the teacher LLM under the Zero-shot-CoT setting. Due to space limitations, the results on the IND dataset are presented in the Appendix.
  • Figure 5: Left: Ablation results on key reasoning steps for the IND (BBH-test) and OOD (others) datasets. "w/o Correct" means students learn key reasoning steps only from wrong CoTs, and "w/o Wrong" means students learn key reasoning steps only from correct CoTs. Middle: Ablation results on different student models for the IND and OOD datasets. We compare EDIT with its variant w/o KRSL and with Std-CoT. Results are reported as IND-AVG and OOD-AVG, the average accuracy on the IND and OOD datasets, respectively. Right: Distribution of CoT quality scores assigned by GPT-4 on BBH-test, visualized with kernel density estimation.
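
The Figure 5 (right) caption mentions kernel density estimation over GPT-4 quality scores. As a rough illustration of that technique only, here is a minimal Gaussian KDE in plain Python. The function name `kde_density` is hypothetical, and the default bandwidth uses the normal-reference rule of thumb 1.06·σ·n^(-1/5), which is our assumption rather than the authors' stated setting.

```python
import math
from statistics import stdev
from typing import List, Optional, Sequence


def kde_density(
    scores: Sequence[float],
    grid: Sequence[float],
    bandwidth: Optional[float] = None,
) -> List[float]:
    """Hypothetical helper: Gaussian KDE of `scores` evaluated at each grid point."""
    n = len(scores)
    if bandwidth is None:
        # Assumed normal-reference rule of thumb (needs at least 2 distinct scores).
        bandwidth = 1.06 * stdev(scores) * n ** (-1 / 5)
    norm = 1.0 / (n * bandwidth * math.sqrt(2 * math.pi))
    # Sum one Gaussian bump per observed score at every grid point.
    return [
        norm * sum(math.exp(-0.5 * ((x - s) / bandwidth) ** 2) for s in scores)
        for x in grid
    ]
```

Plotting the returned densities against `grid` gives a smooth curve per method instead of a histogram, which makes overlapping score distributions easier to compare.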