Keypoint-based Progressive Chain-of-Thought Distillation for LLMs

Kaituo Feng, Changsheng Li, Xiaolu Zhang, Jun Zhou, Ye Yuan, Guoren Wang

TL;DR

This paper tackles two weaknesses of existing chain-of-thought distillation: treating all rationale tokens as equally important, and training the student on every reasoning step at once rather than in a progressive, easy-to-hard order. It proposes KPOD, which combines a mask-learning-based token weighting module that emphasizes keypoint tokens with an in-rationale curriculum: the student first learns to generate the final reasoning steps and is gradually extended to the full rationale, with the schedule chosen by submodular optimization over step difficulty and question diversity. On GSM8K, ASDiv, SVAMP, and CommonsenseQA, KPOD outperforms previous distillation methods by a large margin, and it shows robust out-of-distribution (OOD) generalization along with interpretable visualizations of the learned token weights. Overall, KPOD offers a unified framework for transferring reasoning from LLMs to smaller student models.

Abstract

Chain-of-thought distillation is a powerful technique for transferring reasoning abilities from large language models (LLMs) to smaller student models. Previous methods typically require the student to mimic the step-by-step rationale produced by LLMs, and often face the following challenges: (i) Tokens within a rationale vary in significance, and treating them equally may fail to accurately mimic keypoint tokens, leading to reasoning errors. (ii) They usually distill knowledge by consistently predicting all the steps in a rationale, which does not establish any learning order among the steps. This diverges from the human cognitive progression of starting with easy tasks and advancing to harder ones, resulting in sub-optimal outcomes. To this end, we propose a unified framework, called KPOD, to address these issues. Specifically, we propose a token weighting module utilizing mask learning to encourage accurate mimicry of keypoint tokens by the student during distillation. In addition, we develop an in-rationale progressive distillation strategy, which starts by training the student to generate the final reasoning steps and gradually extends to cover the entire rationale. To accomplish this, a weighted token generation loss is proposed to assess step reasoning difficulty, and a value function is devised to schedule the progressive distillation by considering both step difficulty and question diversity. Extensive experiments on four reasoning benchmarks show that KPOD outperforms previous methods by a large margin.
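
To make the two ideas in the abstract concrete, the following is a minimal sketch of a weighted token generation loss restricted to the later steps of a rationale. It is an illustration under stated assumptions, not the paper's implementation: the function name `weighted_token_nll`, the `first_trained_step` parameter, and the normalization by the sum of weights are all hypothetical choices, and the weights here are hand-picked rather than produced by a learned mask.

```python
import math

def weighted_token_nll(token_log_probs, weights, step_ids, first_trained_step):
    """Weighted negative log-likelihood over the tokens of the later steps.

    token_log_probs: log-probability the student assigns to each rationale token
    weights: significance weight per token (larger means closer to a keypoint)
    step_ids: index of the reasoning step each token belongs to
    first_trained_step: steps before this index are treated as given context
        and contribute nothing to the loss (hypothetical parameter)
    """
    total, norm = 0.0, 0.0
    for lp, w, s in zip(token_log_probs, weights, step_ids):
        if s >= first_trained_step:
            total += -w * lp
            norm += w
    # Normalizing by the weight sum is an assumption made for this sketch.
    return total / norm if norm > 0 else 0.0

# Toy rationale with two steps of two tokens each.
log_probs = [math.log(0.5), math.log(0.25), math.log(0.5), math.log(0.125)]
weights = [1.0, 3.0, 1.0, 3.0]   # hand-picked; tokens 1 and 3 act as keypoints
step_ids = [0, 0, 1, 1]

early_stage = weighted_token_nll(log_probs, weights, step_ids, first_trained_step=1)
late_stage = weighted_token_nll(log_probs, weights, step_ids, first_trained_step=0)
```

Lowering `first_trained_step` from one training stage to the next mirrors the progression described above: the student is first trained on the final step only, then on the whole rationale.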

Paper Structure (20 sections, 2 theorems, 22 equations, 4 figures, 4 tables)

Key Result

Proposition 3.1

The optimization of $\max_{S(t)} F(S(t))$ subject to the knapsack constraint $\Delta H(S(t)) \leq \Delta D(t)$ can be approximately solved in $O(n\epsilon^{-1}\log\epsilon^{-1})$ time complexity with a $\frac{1}{2}-\epsilon$ approximation ratio guarantee, where $n$ represents the scale of the data.
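
For intuition about the kind of problem the proposition describes, here is a small sketch of knapsack-constrained selection using a plain cost-benefit greedy rule. This is a swapped-in, simpler heuristic: it is not the algorithm behind the stated $O(n\epsilon^{-1}\log\epsilon^{-1})$ bound and does not carry the $\frac{1}{2}-\epsilon$ guarantee. The value function below (a square root of per-cluster counts, which rewards diversity) is a hypothetical stand-in for $F$; `cost` stands in for each item's contribution to $\Delta H$, and `budget` for $\Delta D(t)$.

```python
import math

def greedy_knapsack(candidates, cost, value, budget):
    """Repeatedly add the affordable item with the best marginal gain per unit cost.

    candidates: iterable of item identifiers
    cost: dict mapping item -> positive cost
    value: function mapping a list of chosen items to a score
    budget: maximum total cost of the chosen set
    """
    chosen, spent = [], 0.0
    remaining = list(candidates)
    while True:
        base = value(chosen)
        best, best_ratio = None, 0.0
        for c in remaining:
            if spent + cost[c] > budget:
                continue  # item would violate the knapsack constraint
            ratio = (value(chosen + [c]) - base) / cost[c]
            if ratio > best_ratio:
                best, best_ratio = c, ratio
        if best is None:
            return chosen, spent
        chosen.append(best)
        spent += cost[best]
        remaining.remove(best)

# Hypothetical toy data: four questions in two clusters.
cluster = {"a": 0, "b": 0, "c": 1, "d": 1}
cost = {"a": 1, "b": 1, "c": 2, "d": 4}

def diversity_value(selected):
    # Concave function of per-cluster counts: repeated picks from one
    # cluster yield diminishing returns, which makes the score submodular.
    counts = {}
    for item in selected:
        counts[cluster[item]] = counts.get(cluster[item], 0) + 1
    return sum(math.sqrt(n) for n in counts.values())

chosen, spent = greedy_knapsack(["a", "b", "c", "d"], cost, diversity_value, budget=4)
```

In this toy run the greedy rule picks `a`, then `c` (a new cluster outweighs its higher cost), then `b`, using the whole budget and never selecting the expensive item `d`.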

Figures (4)

  • Figure 1: An illustration of our KPOD framework. KPOD first identifies the keypoint tokens for distillation using a rationale token weighting module based on mask learning. An in-rationale progressive distillation strategy then organizes the learning order within each rationale, so that the student acquires reasoning capabilities in an easy-to-hard manner.
  • Figure 2: Visualizations of token significance weights produced by the weight generator. The intensity of red corresponds to the significance weight assigned to each token, with a deeper red indicating higher weight.
  • Figure 3: Parameter sensitivity study of $\alpha$ and $\beta$.
  • Figure 4: Parameter sensitivity study of $p$, $K$ and $C_0$ on GSM8K.

Theorems & Definitions (4)

  • Proposition 3.1
  • Theorem 4.1
  • proof
  • proof