Enhancing Generalization in Chain of Thought Reasoning for Smaller Models
Maxwell J. Yin, Dingyi Jiang, Yongbing Chen, Boyu Wang, Charles Ling
TL;DR
This paper tackles the problem that CoT reasoning quality degrades when it is distilled into smaller language models, which shift from generalization toward memorization. It introduces PRADA, a three-part framework that combines (i) diverse CoT generation from a large teacher, (ii) prompt-learning adapters (P-Tuning) in the student, and (iii) domain-adversarial fine-tuning with a gradient-reversal mechanism to learn domain-invariant CoT features. The authors formalize an objective that jointly optimizes task accuracy and domain invariance, $E(\theta_f,\theta_y,\theta_d)=\frac{1}{N_s}\sum L_y^i(\theta_f,\theta_y)-\lambda\left(\frac{1}{N_s}\sum L_d^i(\theta_f,\theta_d)+\frac{1}{N_t}\sum L_d^i(\theta_f,\theta_d)\right)$, where $N_s$ and $N_t$ are the numbers of source and target examples. They also use a gradient-reversal layer $R(q)$ that is the identity in the forward pass, $R(q)=q$, and reverses the gradient in the backward pass, $\frac{dR}{dq}=-I$. Experimental results on 12 datasets across four reasoning categories show that PRADA consistently outperforms prior CoT distillation methods, with faster convergence and stronger cross-domain generalization, while aligning student reasoning with domain knowledge for better explainability. The work demonstrates that adversarially learned, domain-invariant CoT prompts can bridge the gap between teacher capability and compact deployment, enabling robust reasoning in smaller LLMs. These findings have practical impact for deploying explainable, generalizable CoT systems in real-world, resource-constrained settings.
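The objective $E$ and the gradient-reversal layer can be illustrated with a toy, scalar-parameter sketch in plain Python. It is not PRADA's implementation, which applies these ideas to P-Tuning prompts in an LLM student. The sketch assumes a hypothetical linear feature extractor $h=\theta_f x$, uses squared-error losses as stand-ins for the usual classification losses, and introduces illustrative names (`objective_E`, `grl_forward`, `grl_backward`, `sgd_step`) that do not come from the paper. What it shows is the sign mechanics. Minimizing $L_y+\lambda L_d$ with $R$ placed before the domain head updates $\theta_f$ with $\partial L_y/\partial\theta_f-\lambda\,\partial L_d/\partial\theta_f$, while $\theta_d$ still descends on its own domain loss.

```python
# Hypothetical toy sketch: scalar parameters, squared-error losses for both
# heads, linear features. Not PRADA's implementation; only the sign mechanics.


def objective_E(label_losses_src, domain_losses_src, domain_losses_tgt, lam):
    """E = mean_src(L_y) - lam * (mean_src(L_d) + mean_tgt(L_d))."""
    def mean(xs):
        return sum(xs) / len(xs)

    return mean(label_losses_src) - lam * (
        mean(domain_losses_src) + mean(domain_losses_tgt)
    )


def grl_forward(q):
    """R(q) = q: the gradient-reversal layer is the identity going forward."""
    return q


def grl_backward(upstream_grad):
    """dR/dq = -I: going backward, the incoming gradient is negated."""
    return -upstream_grad


def sgd_step(theta_f, theta_y, theta_d, x, y, d, lam, lr):
    """One SGD step on one example (hypothetical squared-error losses).

    Forward: h = theta_f * x, y_hat = theta_y * h, d_hat = theta_d * R(h).
    Backprop of L_y + lam * L_d through R yields, for theta_f,
    dL_y/dtheta_f - lam * dL_d/dtheta_f, while theta_d still descends on L_d.
    """
    h = theta_f * x                   # shared feature
    y_hat = theta_y * h               # label head
    d_hat = theta_d * grl_forward(h)  # domain head sees R(h)

    dLy_dyhat = 2.0 * (y_hat - y)     # L_y = (y_hat - y)^2
    dLd_ddhat = 2.0 * (d_hat - d)     # L_d = (d_hat - d)^2, d is a domain label

    g_theta_y = dLy_dyhat * h
    g_theta_d = lam * dLd_ddhat * h
    # Gradient arriving at h: label branch as-is, domain branch reversed by R.
    g_h = dLy_dyhat * theta_y + grl_backward(lam * dLd_ddhat * theta_d)
    g_theta_f = g_h * x

    return (
        theta_f - lr * g_theta_f,
        theta_y - lr * g_theta_y,
        theta_d - lr * g_theta_d,
    )


if __name__ == "__main__":
    print(objective_E([1.0, 3.0], [0.5, 0.5], [1.5], lam=0.5))  # 1.0
    print(sgd_step(1.0, 1.0, 1.0, x=1.0, y=0.0, d=0.0, lam=0.5, lr=0.1))
```

In the usage example, the step moves the parameters to approximately (0.9, 0.8, 0.9). Without $R$, the $\theta_f$ gradient would be $2+1=3$. With $R$, it is $2-1=1$. The feature extractor is therefore pushed to make the domains harder to distinguish, while the domain head keeps learning to separate them.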
Abstract
Chain-of-Thought (CoT) reasoning in smaller language models is a challenging natural language processing problem, yet it is highly desirable in many real-life applications. Existing CoT knowledge distillation methods often suffer from overly conservative memorization in smaller LLMs, leading to low generalization confidence. As fully preserving the CoT ability of the teacher model is impossible, we hypothesize that adversarial CoT fine-tuning is crucial for developing smaller LLMs with robust CoT generalization. To this end, we propose \textit{PRompt-Assisted Domain-Adversarial fine-tuning} (PRADA), a principled fine-tuning framework that integrates diverse CoT domains. Specifically, PRADA pioneers two CoT improvements in smaller LLMs: (1) recovering, through domain-adversarial fine-tuning, the domain-invariant feature insights that are typically lost during distillation; and (2) enhancing the domain adaptability of CoT prompt engineering by employing domain-adversarial approaches. We theoretically demonstrate the effectiveness of our approach and empirically show that it significantly outperforms the state of the art on a wide range of tasks. Moreover, our empirical findings reveal that the smaller LLM, when leveraging PRADA, aligns closely with domain knowledge, thereby improving the explainability of our approach.
