
SIKeD: Self-guided Iterative Knowledge Distillation for mathematical reasoning

Shivam Adarsh, Kumar Shridhar, Caglar Gulcehre, Nicholas Monath, Mrinmaya Sachan

TL;DR

This work proposes SIKeD (Self-guided Iterative Knowledge Distillation for mathematical reasoning), a distillation method in which a large language model (LLM) teaches a smaller model to approach a task using different strategies, and the smaller model uses its self-generated on-policy outputs to choose the most suitable strategy for the given task.

Abstract

Large Language Models (LLMs) can transfer their reasoning skills to smaller models by teaching them to generate the intermediate reasoning process required to solve multistep reasoning tasks. While LLMs can accurately solve reasoning tasks through a variety of strategies, even without fine-tuning, smaller models are not expressive enough to fit the LLM's distribution over all strategies when distilled and tend to prioritize one strategy over the others. This reliance on a single strategy poses a challenge for smaller models when they attempt reasoning tasks that may be difficult to solve with their preferred strategy. To address this, we propose SIKeD (Self-guided Iterative Knowledge Distillation for mathematical reasoning), a distillation method in which the LLM teaches the smaller model to approach a task using different strategies, and the smaller model uses its self-generated on-policy outputs to choose the most suitable strategy for the given task. Training continues in a self-guided, iterative manner: at each training iteration, a decision is made on how to combine the LLM data with the self-generated outputs. Unlike traditional distillation methods, SIKeD allows the smaller model to learn which strategy suits a given task while continuously learning to solve tasks using different strategies. Our experiments on various mathematical reasoning datasets show that SIKeD significantly outperforms traditional distillation techniques across smaller models of different sizes. Our code is available at: https://github.com/kumar-shridhar/SIKeD
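
As a rough illustration of the iterative loop described in the abstract, the Python sketch below mixes LLM-generated samples with the smaller model's own correct outputs at each iteration. Everything here is an assumption made for illustration, not the authors' implementation: the names (`Sample`, `filter_correct`, `mix_data`, `self_guided_distillation`), the answer-matching filter, the subsampling rule, the initial plain-distillation step, and the reading of `alpha` as the share of LLM data in the mixture. `generate` and `finetune` are hypothetical callables standing in for whatever sampling and fine-tuning code is actually used.

```python
import random
from dataclasses import dataclass
from typing import Callable, Optional

Model = object  # hypothetical placeholder for whatever model handle the trainer returns


@dataclass(frozen=True)
class Sample:
    question: str
    strategy: str  # e.g. "CoT", "L2M" or "PoT"
    answer: str    # final answer extracted from the generated solution


def filter_correct(samples: list[Sample], gold: dict[str, str]) -> list[Sample]:
    """Keep only samples whose final answer matches the reference answer (assumed filter)."""
    return [s for s in samples if gold.get(s.question) == s.answer]


def mix_data(llm_data: list[Sample], self_data: list[Sample],
             alpha: float, rng: random.Random) -> list[Sample]:
    """Return a training set in which roughly an `alpha` fraction is LLM data.

    Hypothetical rule: keep every self-generated sample and subsample the
    LLM data so that it makes up about `alpha` of the result.
    """
    if not self_data or alpha >= 1.0:
        return list(llm_data)
    n_llm = min(len(llm_data), round(alpha / (1.0 - alpha) * len(self_data)))
    return rng.sample(llm_data, n_llm) + list(self_data)


def self_guided_distillation(
    llm_data: list[Sample],
    gold: dict[str, str],
    generate: Callable[[Model], list[Sample]],
    finetune: Callable[[Optional[Model], list[Sample]], Model],
    alphas: list[float],
    seed: int = 0,
) -> Model:
    rng = random.Random(seed)
    # Iteration 0 (assumed): plain distillation on the multi-strategy LLM data.
    model = finetune(None, list(llm_data))
    for alpha in alphas:
        # On-policy step: the small model picks its own strategies; keep only correct outputs.
        self_data = filter_correct(generate(model), gold)
        model = finetune(model, mix_data(llm_data, self_data, alpha, rng))
    return model
```

Passing a decreasing schedule such as `alphas=[0.9, 0.5, 0.2]` shifts the training data from LLM-dominated toward self-generated; the intermediate value 0.5 is made up for the example. Keeping all self-generated samples and subsampling only the LLM data is one simple way to hit a target ratio, and other mixing rules would fit the description equally well.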

Paper Structure

This paper contains 31 sections, 6 equations, 11 figures, 1 table, 1 algorithm.

Figures (11)

  • Figure 1: Histogram of strategy choices for the LLM and the smaller model. The LLM tends to select several reasoning strategies, but the smaller model is biased towards one strategy. The comparison was done on 1000 data points randomly sampled from the GSM8K train set.
  • Figure 2: Alignment of the smaller model's strategy distribution with the LLM over iterations. Each subplot represents an iteration in the training process, showing the probability distributions over reasoning strategies: PoT, L2M, and CoT. The blue bars depict the LLM's distribution $P_{L}$, while the orange bars represent the smaller model's distribution $P_{SM}$, which is biased towards CoT. The green bars show the training data distribution $P_{\text{train}}^{(t)}$, a mixture of $P_{L}$ and $P_{SM}$ weighted by the mixing rate $\alpha$. As $\alpha$ decreases over iterations (from 0.90 to 0.20), $P_{\text{train}}^{(t)}$ shifts from being similar to the LLM's distribution towards the smaller model's distribution. The KL divergence between the training data and the smaller model distributions decreases accordingly, indicating increased similarity.
  • Figure 3: Accuracy comparison on the Gemma 7B model between single-strategy distillation (CoT, PoT, or L2M) and SIKeD training biased towards the same strategy.
  • Figure 4: Iterative accuracy comparison for the Gemma 2B model across all datasets. The process is stopped when the gains diminish or when it is no longer cost-effective to continue.
  • Figure 5: Strategy distribution over iterations for the GSM8K dataset using the SmolLM 1.7B model.
  • ...and 6 more figures
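
The mixture and divergence described in the Figure 2 caption can be computed directly. The Python sketch below assumes the mixture is the convex combination $P_{\text{train}}^{(t)} = \alpha P_{L} + (1-\alpha) P_{SM}$ and measures $D_{\mathrm{KL}}(P_{\text{train}}^{(t)} \,\|\, P_{SM})$. The caption does not state the direction of the KL divergence, so that choice is an assumption, and the example distributions are invented for illustration, not values from the paper.

```python
import math

STRATEGIES = ("PoT", "L2M", "CoT")


def mix(p_llm: dict[str, float], p_sm: dict[str, float], alpha: float) -> dict[str, float]:
    """Assumed mixture: P_train = alpha * P_L + (1 - alpha) * P_SM, per strategy."""
    return {s: alpha * p_llm[s] + (1.0 - alpha) * p_sm[s] for s in STRATEGIES}


def kl(p: dict[str, float], q: dict[str, float]) -> float:
    """KL(p || q) in nats; requires q[s] > 0 wherever p[s] > 0."""
    return sum(p[s] * math.log(p[s] / q[s]) for s in STRATEGIES if p[s] > 0)


# Illustrative distributions only (not taken from the paper): the small model favors CoT.
p_llm = {"PoT": 0.4, "L2M": 0.3, "CoT": 0.3}
p_sm = {"PoT": 0.1, "L2M": 0.1, "CoT": 0.8}

for alpha in (0.9, 0.5, 0.2):  # decreasing mixing rate across iterations
    p_train = mix(p_llm, p_sm, alpha)
    print(f"alpha={alpha:.2f}  KL(P_train || P_SM)={kl(p_train, p_sm):.4f}")
```

With these toy inputs the printed divergence shrinks as `alpha` decreases, matching the trend the Figure 2 caption describes: as more weight moves to $P_{SM}$, the training distribution becomes more similar to the smaller model's own.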