Enhancing Generalization in Chain of Thought Reasoning for Smaller Models

Maxwell J. Yin, Dingyi Jiang, Yongbing Chen, Boyu Wang, Charles Ling

TL;DR

This paper tackles the degradation of CoT reasoning quality when distilling to smaller language models, where the student shifts from generalization to memorization. It introduces PRADA, a three-part framework that combines (i) diverse CoT generation from a large teacher, (ii) prompt-learning adapters (P-Tuning) in the student, and (iii) domain-adversarial fine-tuning with a gradient-reversal mechanism to learn domain-invariant CoT features. The authors formalize an objective that jointly optimizes task accuracy and domain invariance, $E(\theta_f,\theta_y,\theta_d)=\frac{1}{N_s}\sum_{i=1}^{N_s} L_y^i(\theta_f,\theta_y)-\lambda\left(\frac{1}{N_s}\sum_{i=1}^{N_s} L_d^i(\theta_f,\theta_d)+\frac{1}{N_t}\sum_{i=1}^{N_t} L_d^i(\theta_f,\theta_d)\right)$, and implement it with a gradient-reversal layer $R(q)$ that is the identity in the forward pass, $R(q)=q$, and negates gradients in the backward pass, $\frac{dR}{dq}=-I$. Experimental results on 12 datasets across four reasoning categories show that PRADA consistently outperforms prior CoT distillation methods, with faster convergence and stronger cross-domain generalization, while aligning student reasoning with domain knowledge for better explainability. The work demonstrates that adversarially learned, domain-invariant CoT prompts can bridge the gap between teacher capability and compact deployment, enabling robust reasoning in smaller LLMs. These findings have practical impact for deploying explainable, generalizable CoT systems in real-world, resource-constrained settings.
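
The objective above has the same shape as the standard domain-adversarial training setup, so one way to read it is as a saddle-point problem: the feature and task parameters minimize $E$ while the domain classifier maximizes it. The block below is a sketch under that assumption, not a reproduction of the paper's own derivation; the learning rate $\mu$ is a hypothetical symbol introduced here for illustration.

```latex
\begin{aligned}
(\hat\theta_f,\hat\theta_y) &= \arg\min_{\theta_f,\theta_y} E(\theta_f,\theta_y,\hat\theta_d),
\qquad
\hat\theta_d = \arg\max_{\theta_d} E(\hat\theta_f,\hat\theta_y,\theta_d) \\
\theta_f &\leftarrow \theta_f - \mu\left(\frac{\partial L_y^i}{\partial\theta_f} - \lambda\,\frac{\partial L_d^i}{\partial\theta_f}\right) \\
\theta_y &\leftarrow \theta_y - \mu\,\frac{\partial L_y^i}{\partial\theta_y} \\
\theta_d &\leftarrow \theta_d - \mu\,\lambda\,\frac{\partial L_d^i}{\partial\theta_d}
\end{aligned}
```

Because $\frac{dR}{dq}=-I$, placing $R$ between the feature extractor and the domain classifier lets ordinary backpropagation on $L_y+\lambda L_d$ produce these updates: the domain-loss gradient reaches $\theta_f$ with its sign flipped, while $\theta_d$ still descends the domain loss.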

Abstract

Chain-of-Thought (CoT) reasoning in smaller language models is a challenging natural language processing problem, yet highly desirable in many real-life applications. Existing CoT knowledge distillation methods often suffer from overly conservative memorization in smaller LLMs, leading to low generalization confidence. As fully preserving the CoT ability of the teacher model is impossible, we hypothesize that adversarial CoT fine-tuning is crucial for developing smaller LLMs with robust CoT generalization. To this end, we propose PRompt-Assisted Domain-Adversarial fine-tuning (PRADA), a principled fine-tuning framework that integrates diverse CoT domains. Specifically, PRADA pioneers two CoT improvements in smaller LLMs: (1) recovering, through domain-adversarial fine-tuning, the domain-invariant feature insight that is typically lost during distillation; (2) enhancing the domain adaptability of CoT prompt engineering by employing domain-adversarial approaches. We theoretically demonstrate the effectiveness of our approach and empirically show that it significantly outperforms the state of the art on a wide range of tasks. Moreover, our empirical findings reveal that the smaller LLM, when leveraging PRADA, aligns closely with domain knowledge, thereby improving the explainability of our approach.

Paper Structure (23 sections, 11 equations, 6 figures, 2 tables, 1 algorithm)

Figures (6)

  • Figure 1: Transition from memorization to generalization for GPT-3 family models fine-tuned on four different domains (datasets detailed in the Experiments section). The source accuracy is tested on the original domain, and the target accuracy is tested on different domains.
  • Figure 2: Case study of the decay in CoT reasoning ability during knowledge distillation from GPT-3 175B to GPT-3 6.7B. The reduction in model parameters causes the student model's CoT generalization ability to diminish, shifting towards memorization. Consequently, the student model exhibits poor CoT generalization on unseen data.
  • Figure 3: Illustration of our method Prompt-Assisted Domain-Adversarial fine-tuning (PRADA). Firstly, the teacher model is prompted to generate diverse CoT reasoning responses using a Zero-Shot-CoT approach. Secondly, an adapter is inserted into the student LLM for P-Tuning, which refines the domain-agnostic knowledge of the source domain. Thirdly, the knowledge distillation process incorporates data from both the labeled source domain and the unlabeled target domain. The proposed architecture includes an LLM (green) and a language modeling head (blue), which together form a standard feed-forward architecture. Unsupervised domain adversarial fine-tuning is achieved by adding a domain classifier (red).
  • Figure 5: t-SNE visualization of Llama 3-8B embeddings with and without our method PRADA, using GSM8K as the source domain and 4 arithmetic target domains.
  • Figure 6: The generalization ability of models with and without PRADA (ours) fine-tuned on source domain GSM8K. For curves of the same color, the lower limit is the vanilla model's accuracy and the upper limit is PRADA's accuracy.
  • ...and 1 more figure
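
The three components named in the Figure 3 caption (an LLM feature extractor, a language modeling head, and a domain classifier) can be illustrated with a toy scalar model. The sketch below is a minimal pure-Python illustration, not the authors' implementation: each component is reduced to a single weight, and the function names, the squared-error task loss, and the 0 = source / 1 = target domain-label convention are all assumptions made for the example. It shows how a gradient-reversal layer, identity on the forward pass and sign flip on the backward pass, makes the feature weight descend the task loss while ascending the domain loss.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def label_loss(theta_f, theta_y, x, y):
    """Toy task loss L_y: squared error of the prediction theta_y * (theta_f * x)."""
    return 0.5 * (theta_y * theta_f * x - y) ** 2


def domain_loss(theta_f, theta_d, x, d):
    """Toy domain loss L_d: binary cross-entropy of the domain classifier."""
    p = sigmoid(theta_d * theta_f * x)
    return -(d * math.log(p) + (1 - d) * math.log(1 - p))


class GradientReversal:
    """R(q) = q on the forward pass; dR/dq = -1 on the backward pass."""

    @staticmethod
    def forward(q):
        return q

    @staticmethod
    def backward(grad_output):
        return -grad_output


def reversal_grads(theta_f, theta_y, theta_d, x, y, d, lam):
    """Backpropagate L_y + lam * L_d through the reversal layer for one example.

    Pass y=None for an unlabeled target example (no task loss).
    Returns the gradients for (theta_f, theta_y, theta_d).
    """
    f = theta_f * x  # feature extractor (assumed linear for the toy model)

    # Task branch: only labeled source examples contribute.
    if y is None:
        grad_y, dtask_df = 0.0, 0.0
    else:
        err = theta_y * f - y
        grad_y = err * f
        dtask_df = err * theta_y

    # Domain branch: features pass through the reversal layer first.
    r = GradientReversal.forward(f)
    p = sigmoid(theta_d * r)
    dlogit = lam * (p - d)  # gradient of lam * BCE w.r.t. the classifier logit
    grad_d = dlogit * r
    ddomain_df = GradientReversal.backward(dlogit * theta_d)  # sign flipped here

    grad_f = (dtask_df + ddomain_df) * x
    return grad_f, grad_y, grad_d


def sgd_step(params, example, lam=0.1, lr=0.05):
    """One plain gradient-descent step on all three weights."""
    theta_f, theta_y, theta_d = params
    x, y, d = example
    g_f, g_y, g_d = reversal_grads(theta_f, theta_y, theta_d, x, y, d, lam)
    return (theta_f - lr * g_f, theta_y - lr * g_y, theta_d - lr * g_d)


if __name__ == "__main__":
    params = (0.5, -1.0, 2.0)
    params = sgd_step(params, (1.5, 1.0, 0))   # labeled source example
    params = sgd_step(params, (0.8, None, 1))  # unlabeled target example
    print(params)
```

With plain gradient descent applied to all three weights, the domain classifier weight descends $\lambda L_d$ while the feature weight moves to increase it, which is the adversarial behaviour a gradient-reversal layer is meant to provide. In a real model the same effect would come from a custom autograd operation rather than hand-written derivatives.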