
Skip-Thinking: Chunk-wise Chain-of-Thought Distillation Enable Smaller Language Models to Reason Better and Faster

Xiao Chen, Sihang Zhou, Ke Liang, Xiaoyu Sun, Xinwang Liu

TL;DR

The paper tackles the inefficiencies of standard CoT distillation by introducing a chunk-wise training (CWT) strategy that divides long rationales into semantically coherent chunks, reducing the token-level batch size and mitigating gradient over-smoothing. Building on CWT, skip-thinking training (STT) teaches the model to automatically skip non-reasoning intermediate chunks during inference, shortening response times while preserving accuracy. Using a chunk data generator with two variants, average chunking (AC) and search-based chunking (SBC), together with a skip data generator, the approach improves reasoning performance and speed across multiple SLMs and reasoning tasks. Empirical results on seven benchmarks show that SBC usually outperforms AC and that STT yields notable inference speedups while maintaining accuracy, offering a practical path to faster and more reliable SLM reasoning. The work also discusses limitations (e.g., SBC can get stuck in local optima) and ethical considerations related to toxicity transfer from LLMs to SLMs.
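The TL;DR names two chunking strategies for the chunk data generator. The following is a minimal sketch under stated assumptions, not the authors' code: rationales are assumed to be pre-split into sentence-level steps, average chunking (AC) is taken to mean contiguous groups of near-equal size, and the search-based variant swaps in a toy objective (mean pairwise word overlap inside each chunk, found by exhaustive search over cut points) in place of the paper's heuristic search, whose exact objective is not given here. All function names are hypothetical.

```python
from itertools import combinations


def words(sentence: str) -> set[str]:
    # Crude bag-of-words proxy for semantic content (assumption, not the paper's measure).
    return set(sentence.lower().replace(",", " ").replace(".", " ").split())


def jaccard(a: str, b: str) -> float:
    wa, wb = words(a), words(b)
    union = wa | wb
    return len(wa & wb) / len(union) if union else 0.0


def average_chunking(steps: list[str], k: int) -> list[list[str]]:
    """Split steps into k contiguous chunks whose sizes differ by at most one."""
    k = max(1, min(k, len(steps)))
    base, extra = divmod(len(steps), k)
    chunks, start = [], 0
    for i in range(k):
        size = base + (1 if i < extra else 0)
        chunks.append(steps[start:start + size])
        start += size
    return chunks


def coherence(chunk: list[str]) -> float:
    # Mean pairwise similarity inside one chunk; single-step chunks score 0.
    pairs = list(combinations(chunk, 2))
    return sum(jaccard(a, b) for a, b in pairs) / len(pairs) if pairs else 0.0


def search_chunking(steps: list[str], k: int) -> list[list[str]]:
    """Try every placement of k-1 cut points and keep the most coherent split."""
    k = max(1, min(k, len(steps)))
    best, best_score = [steps], float("-inf")
    for cuts in combinations(range(1, len(steps)), k - 1):
        bounds = (0, *cuts, len(steps))
        chunks = [steps[bounds[i]:bounds[i + 1]] for i in range(k)]
        score = sum(coherence(c) for c in chunks)
        if score > best_score:
            best, best_score = chunks, score
    return best


steps = [
    "Tom has 3 apples.",
    "Tom buys 2 more apples.",
    "So Tom has 5 apples.",
    "Each apple costs 4 dollars.",
    "So 5 apples cost 20 dollars.",
    "The answer is 20.",
]
print(average_chunking(steps, 3))
print(search_chunking(steps, 2))
```

On this toy rationale, the search places the single cut between the apple-counting steps and the cost steps, while AC cuts purely by position. Exhaustive search grows combinatorially with rationale length, which is one reason a practical system would use a heuristic search instead.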

Abstract

Chain-of-thought (CoT) distillation allows a large language model (LLM) to guide a small language model (SLM) in reasoning tasks. Existing methods train the SLM to learn the entire long rationale in one iteration, which causes two issues: 1) Long rationales lead to a large token-level batch size during training, so the gradients of core reasoning tokens (i.e., tokens that directly affect the correctness of subsequent reasoning) are over-smoothed because they make up only a tiny fraction of the rationale. As a result, the SLM converges to sharp minima where it fails to grasp the reasoning logic. 2) The response is slow, because the SLM must generate a long rationale before reaching the answer. Therefore, we propose chunk-wise training (CWT), which uses a heuristic search to divide the rationale into internally semantically coherent chunks and focuses the SLM on learning from only one chunk per iteration. In this way, CWT naturally keeps non-reasoning chunks that contain no core reasoning tokens (e.g., summary and transitional chunks) out of the iterations in which the SLM learns reasoning chunks, increasing the fraction of core reasoning tokens in those iterations. Building on CWT, we propose skip-thinking training (STT), which lets the SLM automatically skip non-reasoning intermediate chunks to reach the answer, improving reasoning speed while maintaining accuracy. We validate our approach on a variety of SLMs and multiple reasoning tasks.
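To make the token-level batch size argument concrete, here is a minimal sketch, not the authors' implementation, of how one chunked rationale could be turned into chunk-wise training examples. It assumes a loss averaged over supervised target tokens (prompt tokens masked out), uses whitespace splitting as a stand-in for a real tokenizer, and the names `ChunkExample`, `chunkwise_examples`, and `per_token_weight` are hypothetical. It shows that when only one chunk is supervised, each of its tokens, including any core reasoning token, receives a larger share (1/N with a smaller N) of the loss than when the whole rationale is the target.

```python
from dataclasses import dataclass


@dataclass
class ChunkExample:
    prompt: str  # question + earlier chunks: conditioning context, excluded from the loss
    target: str  # the one chunk whose tokens are supervised in this iteration


def chunkwise_examples(question: str, chunks: list[str]) -> list[ChunkExample]:
    """Turn one (question, chunked rationale) pair into one example per chunk."""
    return [
        ChunkExample(prompt=" ".join([question, *chunks[:i]]), target=chunk)
        for i, chunk in enumerate(chunks)
    ]


def per_token_weight(supervised_text: str) -> float:
    """Share of a mean-reduced loss carried by each supervised token (1 / N).

    Whitespace tokens are a stand-in for a real tokenizer.
    """
    return 1.0 / len(supervised_text.split())


question = "Tom has 3 apples and buys 2 more. Each costs 4 dollars. What does he pay?"
chunks = [
    "Tom starts with 3 apples and buys 2 more, so he has 5 apples.",
    "Each apple costs 4 dollars, so 5 apples cost 20 dollars.",
    "The answer is 20.",
]

examples = chunkwise_examples(question, chunks)
full_rationale = " ".join(chunks)
print(per_token_weight(full_rationale))       # per-token share, whole rationale as target
print(per_token_weight(examples[1].target))   # per-token share, only chunk 2 as target
```

Here the full rationale has 29 whitespace tokens and the second chunk has 11, so a token such as "20" in that chunk carries 1/11 of the loss instead of 1/29. This only illustrates the proportion argument; it does not reproduce the sharp-minima analysis.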

Paper Structure

This paper contains 33 sections, 6 equations, 8 figures, 14 tables, and 1 algorithm.

Figures (8)

  • Figure 1: Illustration of CoT distillation. The batch size is set to 1 for illustration. A core reasoning token (like the yellow and green balls in rationale R) is one whose correctness determines the subsequent reasoning process. 1) Superficial understanding: the large token-level batch size causes the gradient of the core reasoning token to be over-smoothed during backpropagation by the many non-reasoning tokens (highlighted with a gray background in R) that are similar across different rationales, so the SLM converges to a sharp minimum where it often makes mistakes when generating the core reasoning token. 2) Time-consuming: generating the full R takes longer than outputting the answer A directly.
  • Figure 2: Illustration of the proposed methods. The flames indicate that the model is being trained, and [thought] is a special token indicating that the SLM is thinking internally.
  • Figure 3: A comprehensive comparison of the average inference speed and performance across different methods on all datasets using GPT2-base.
  • Figure 4: SLM performance trend as the number of chunks changes. The vertical dotted line marks the average number of reasoning steps.
  • Figure 5: GPT2-base's performance trend as the batch size changes. Batch size is proportional to token-level batch size. "Chunk" denotes CWT with SBC.
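Figure 2 mentions a special [thought] token standing for the SLM thinking internally. A minimal sketch of how a skip data generator might use it is shown here, assuming that non-reasoning chunks have already been flagged (the flagging rule is not described in this summary), that only intermediate chunks may be skipped, and that each run of skipped chunks collapses to a single [thought]. The function `skip_target` is hypothetical and is not the authors' generator.

```python
THOUGHT = "[thought]"  # special token from Figure 2; how it is used in training is assumed here


def skip_target(chunks: list[str], skip: list[bool]) -> str:
    """Build a target in which flagged intermediate chunks are replaced by [thought].

    The first and last chunks are always kept, and each run of consecutive
    skipped chunks collapses to a single [thought] marker.
    """
    if len(chunks) != len(skip):
        raise ValueError("chunks and skip must have the same length")
    out: list[str] = []
    last = len(chunks) - 1
    for i, (chunk, drop) in enumerate(zip(chunks, skip)):
        if drop and 0 < i < last:
            # Emit one marker per run of skipped chunks.
            if not out or out[-1] != THOUGHT:
                out.append(THOUGHT)
        else:
            out.append(chunk)
    return " ".join(out)


chunks = [
    "Tom has 3 apples and buys 2 more, so he has 5 apples.",
    "To summarize, we now know how many apples Tom has.",
    "Each apple costs 4 dollars, so the total is 20 dollars.",
    "The answer is 20.",
]
skip = [False, True, False, False]
print(skip_target(chunks, skip))
```

The summary chunk is replaced by one [thought] token, so the supervised target is shorter and the model learns to jump from one reasoning chunk to the next. In a real tokenizer, [thought] would need to be registered as a special token so that it is not split into sub-words.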