Improve Student's Reasoning Generalizability through Cascading Decomposed CoTs Distillation

Chengwei Dai, Kun Li, Wei Zhou, Songlin Hu

TL;DR

CasCoD tackles the poor OOD generalization of standard CoT distillation by decomposing the learning objective into two cascaded steps: learning the rationale without the answer, and then learning the answer from the question-rationale pair. It defines dedicated losses for the rationale and the answer, combines them as $\mathcal{L}_{\text{CasCoD}}=(1-\alpha)\mathcal{L}_{\text{rationale}}+\alpha\mathcal{L}_{\text{answer}}$, and uses a two-stage inference pipeline that first generates a rationale and then predicts the final answer. Across in-domain and out-of-domain benchmarks, CasCoD consistently outperforms Std-CoT and other distillation baselines, remains robust across model sizes and training-data quantities, and improves reasoning faithfulness. By reducing reliance on spurious question–answer correlations, the approach enables more generalizable chain-of-thought reasoning in smaller models. CasCoD thus provides a principled, data-efficient path to transferring complex reasoning capabilities from LLMs to smaller, more affordable models.
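
As a rough illustration of the combined objective $(1-\alpha)\mathcal{L}_{\text{rationale}}+\alpha\mathcal{L}_{\text{answer}}$, the sketch below computes it from per-token log-probabilities. It assumes each loss is the mean token-level negative log-likelihood of its target span (rationale tokens conditioned on the question; answer tokens conditioned on the question and the rationale). The function names `mean_nll` and `cascod_loss` are hypothetical and not taken from the CasCoD repository, and the value of $\alpha$ in the example is arbitrary, not the paper's setting.

```python
import math
from typing import Sequence


def mean_nll(token_logprobs: Sequence[float]) -> float:
    """Mean negative log-likelihood over the target tokens of one sequence."""
    if not token_logprobs:
        raise ValueError("need at least one target token")
    return -sum(token_logprobs) / len(token_logprobs)


def cascod_loss(rationale_logprobs: Sequence[float],
                answer_logprobs: Sequence[float],
                alpha: float) -> float:
    """Weighted sum (1 - alpha) * L_rationale + alpha * L_answer.

    rationale_logprobs: log p(r_t | q, r_<t) for each rationale token.
    answer_logprobs:    log p(a_t | q, r, a_<t) for each answer token.
    Assumption: both losses are averaged per token (not summed).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    l_rationale = mean_nll(rationale_logprobs)
    l_answer = mean_nll(answer_logprobs)
    return (1.0 - alpha) * l_rationale + alpha * l_answer


if __name__ == "__main__":
    # Two rationale tokens at probability 0.5, one answer token at probability 1.
    loss = cascod_loss([math.log(0.5), math.log(0.5)], [math.log(1.0)], alpha=0.3)
    print(f"{loss:.4f}")  # prints 0.4852 (= 0.7 * ln 2)
```

Setting $\alpha = 0$ trains only on rationales and $\alpha = 1$ only on answers; intermediate values trade the two off within a single scalar objective.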

Abstract

Large language models (LLMs) exhibit enhanced reasoning at larger scales, driving efforts to distill these capabilities into smaller models via teacher-student learning. Previous works simply fine-tune student models on Chain-of-Thoughts (CoTs) generated by teachers. Although these methods enhance in-domain (IND) reasoning performance, they struggle to generalize to out-of-domain (OOD) tasks. We believe that widespread spurious correlations between questions and answers may lead the model to preset a specific answer, which restricts the diversity and generalizability of its reasoning process. In this paper, we propose Cascading Decomposed CoTs Distillation (CasCoD) to address these issues by decomposing the traditional single-step learning process into two cascaded learning steps. Specifically, by restructuring the training objectives (removing the answer from outputs and concatenating the question with the rationale as input), CasCoD's two-step learning process ensures that students focus on learning rationales without interference from the preset answers, thus improving reasoning generalizability. Extensive experiments demonstrate the effectiveness of CasCoD on both IND and OOD benchmark reasoning datasets. Code can be found at https://github.com/C-W-D/CasCoD.
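
The restructuring of the training objectives described in the abstract can be sketched as a data-preparation step: one teacher-generated (question, rationale, answer) triple becomes two training examples instead of one. The prompt markers (`Rationale:`, `Therefore, the answer is`) and the names `Example`, `std_cot`, and `decompose` are hypothetical placeholders, not the templates used in the CasCoD code; the assumption is that the loss is computed only on each example's `target` text.

```python
from dataclasses import dataclass

# Hypothetical prompt markers; the actual CasCoD templates may differ.
RATIONALE_PREFIX = "Rationale:"
ANSWER_PREFIX = "Therefore, the answer is"


@dataclass(frozen=True)
class Example:
    source: str  # text fed to the student as input (no loss on these tokens)
    target: str  # text the student is trained to generate (loss computed here)


def std_cot(question: str, rationale: str, answer: str) -> Example:
    """Standard CoT distillation: a single step whose target holds rationale and answer."""
    return Example(source=f"{question}\n{RATIONALE_PREFIX}",
                   target=f" {rationale}\n{ANSWER_PREFIX} {answer}")


def decompose(question: str, rationale: str, answer: str) -> tuple[Example, Example]:
    """Split one CoT triple into two cascaded examples."""
    # Step 1 (rationale learning): the answer is removed from the output.
    rationale_step = Example(source=f"{question}\n{RATIONALE_PREFIX}",
                             target=f" {rationale}")
    # Step 2 (answer learning): question and rationale are concatenated as input.
    answer_step = Example(
        source=f"{question}\n{RATIONALE_PREFIX} {rationale}\n{ANSWER_PREFIX}",
        target=f" {answer}")
    return rationale_step, answer_step
```

In `std_cot`, the answer shares a target with the rationale, so the student is trained on both together; in `decompose`, the rationale target never contains the answer, and the answer is learned only once the rationale is part of the input.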

Paper Structure

This paper contains 51 sections, 6 equations, 8 figures, and 17 tables.

Figures (8)

  • Figure 1: (a) Empirical results of standard CoT distillation (Std-CoT) and of directly fine-tuning on answer labels without CoTs (Answer SFT) on one in-domain (BBH-test) and four out-of-domain benchmark reasoning datasets. (b) In the given example, "swimsuit" in the question and "swim" in the answer are highly semantically similar, which could allow the model to predict the answer through simple keyword matching or other shallow rules.
  • Figure 2: Overview of our proposed method, Cascading Decomposed CoTs Distillation (CasCoD). Unlike standard CoT distillation, we decompose the single CoT learning step into two learning steps, a rationale learning step and an answer learning step, and learn them in a cascaded way.
  • Figure 3: Comparison between two-step and single-step training implementations of CasCoD.
  • Figure 4: Ablation study on model size for four OOD datasets. The dotted line indicates the performance of the teacher LLM under the Zero-shot-CoT setting. The results on the IND dataset can be found in the Appendix.
  • Figure 5: Ablation study on training data size for four OOD datasets. The dotted line indicates the performance of student models fine-tuned by standard CoT distillation on the full (100%) BBH-train dataset. The results on the IND dataset can be found in the Appendix.
  • ...and 3 more figures
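
The cascade shown in Figure 2 also shapes inference: the student first generates a rationale from the question and then predicts the final answer from the question plus that rationale. Below is a minimal sketch, assuming a hypothetical `generate` callable (prompt in, continuation out) that stands in for the student model's decoding, and the same hypothetical prompt markers `Rationale:` and `Therefore, the answer is`; `cascaded_inference` and `toy_generate` are illustrative names, not CasCoD APIs.

```python
from typing import Callable

# Hypothetical interface: takes a prompt, returns the model's continuation.
Generate = Callable[[str], str]

RATIONALE_PREFIX = "Rationale:"
ANSWER_PREFIX = "Therefore, the answer is"


def cascaded_inference(question: str, generate: Generate) -> tuple[str, str]:
    """Two-stage decoding: rationale first, then answer conditioned on it."""
    # Stage 1: generate a rationale from the question alone.
    rationale = generate(f"{question}\n{RATIONALE_PREFIX}").strip()
    # Stage 2: feed question + generated rationale back in to predict the answer.
    answer = generate(
        f"{question}\n{RATIONALE_PREFIX} {rationale}\n{ANSWER_PREFIX}").strip()
    return rationale, answer


def toy_generate(prompt: str) -> str:
    """Stand-in for a real student model, for demonstration only."""
    if prompt.endswith(RATIONALE_PREFIX):
        return " 2 apples plus 3 apples is 5 apples."
    return " 5"


if __name__ == "__main__":
    print(cascaded_inference("How many apples are there?", toy_generate))
```

Because the second call sees the generated rationale, the predicted answer is conditioned on the student's own reasoning rather than on the question alone.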