Improve Student's Reasoning Generalizability through Cascading Decomposed CoTs Distillation
Chengwei Dai, Kun Li, Wei Zhou, Songlin Hu
TL;DR
CasCoD tackles the poor OOD generalization of standard CoT distillation by decomposing the learning objective into two cascaded steps: learning rationales without the answer and then learning the answer from the question-rationale pair. It defines dedicated losses for rationale and answer, combining them as $\mathcal{L}_{\text{CasCoD}}=(1-\alpha)\mathcal{L}_{\text{rationale}}+\alpha\mathcal{L}_{\text{answer}}$, and uses a two-stage inference pipeline that first generates a rationale before predicting the final answer. Across in-domain and out-of-domain benchmarks, CasCoD consistently outperforms Std-CoT and other distillation baselines, demonstrating robustness to model size and training data quantity while improving reasoning faithfulness. The approach reduces reliance on spurious question–answer correlations, enabling more generalizable chain-of-thought reasoning in smaller models with practical implications for scalable AI alignment and robust reasoning. CasCoD thus provides a principled, data-efficient path to transferring complex reasoning capabilities from large LLMs to smaller, more affordable models.
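The weighted objective and the two-stage inference described above can be illustrated with a minimal Python sketch. It rests on assumptions not stated in this summary: each loss is taken to be a mean token-level negative log-likelihood, and the names `mean_nll`, `cascod_loss`, `two_stage_inference`, and the prompt templates are hypothetical, chosen only for illustration.

```python
import math
from typing import Callable, Sequence, Tuple


def mean_nll(token_log_probs: Sequence[float]) -> float:
    """Mean negative log-likelihood over target tokens (assumed loss form)."""
    if not token_log_probs:
        raise ValueError("need at least one target token")
    return -sum(token_log_probs) / len(token_log_probs)


def cascod_loss(
    rationale_log_probs: Sequence[float],
    answer_log_probs: Sequence[float],
    alpha: float,
) -> float:
    """Combine the two step losses as (1 - alpha) * L_rationale + alpha * L_answer."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    loss_rationale = mean_nll(rationale_log_probs)
    loss_answer = mean_nll(answer_log_probs)
    return (1.0 - alpha) * loss_rationale + alpha * loss_answer


def two_stage_inference(
    question: str, generate: Callable[[str], str]
) -> Tuple[str, str]:
    """Generate a rationale first, then an answer conditioned on it.

    `generate` is a hypothetical stand-in for the student model's decoding
    call; the prompt templates below are illustrative, not the paper's.
    """
    rationale = generate(f"Question: {question}\nRationale:").strip()
    answer = generate(
        f"Question: {question}\nRationale: {rationale}\nAnswer:"
    ).strip()
    return rationale, answer
```

Setting `alpha` to 0 trains only on the rationale loss and setting it to 1 trains only on the answer loss; intermediate values trade the two off as in the formula above.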
Abstract
Large language models (LLMs) exhibit enhanced reasoning at larger scales, driving efforts to distill these capabilities into smaller models via teacher-student learning. Previous works simply fine-tune student models on teacher-generated Chain-of-Thought (CoT) data. Although these methods enhance in-domain (IND) reasoning performance, they struggle to generalize to out-of-domain (OOD) tasks. We believe that the widespread spurious correlations between questions and answers may lead the model to preset a specific answer, which restricts the diversity and generalizability of its reasoning process. In this paper, we propose Cascading Decomposed CoTs Distillation (CasCoD) to address these issues by decomposing the traditional single-step learning process into two cascaded learning steps. Specifically, by restructuring the training objectives (removing the answer from outputs and concatenating the question with the rationale as input), CasCoD's two-step learning process ensures that students focus on learning rationales without interference from the preset answers, thus improving reasoning generalizability. Extensive experiments demonstrate the effectiveness of CasCoD on both IND and OOD benchmark reasoning datasets. Code can be found at https://github.com/C-W-D/CasCoD.
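As a rough illustration of the restructured training objectives, the sketch below turns one teacher-annotated sample into the two cascaded training examples: a rationale step whose output omits the answer, and an answer step whose input concatenates the question with the rationale. The `Example` class, the `build_cascod_examples` function, and the prompt templates are hypothetical and are not taken from the released code.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Example:
    """One (input, output) pair for supervised fine-tuning."""

    prompt: str
    target: str


def build_cascod_examples(
    question: str, rationale: str, answer: str
) -> Tuple[Example, Example]:
    """Split a (question, rationale, answer) triple into two training examples.

    The templates are illustrative assumptions; only the decomposition
    itself (answer removed from step 1, rationale moved into the input of
    step 2) follows the description in the abstract.
    """
    # Step 1: question -> rationale, with the answer removed from the output.
    rationale_step = Example(
        prompt=f"Question: {question}\nRationale:",
        target=f" {rationale}",
    )
    # Step 2: question + rationale -> answer.
    answer_step = Example(
        prompt=f"Question: {question}\nRationale: {rationale}\nAnswer:",
        target=f" {answer}",
    )
    return rationale_step, answer_step
```

By contrast, a standard CoT distillation example would place the rationale and the answer in a single output sequence, which is the single-step setup that the abstract argues lets the student preset an answer.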
