
Investigating Mysteries of CoT-Augmented Distillation

Somin Wadhwa, Silvio Amir, Byron C. Wallace

TL;DR

This work asks: why and how do chain-of-thought (CoT) rationales, used as an additional training signal, help in model distillation?

Abstract

Eliciting "chain of thought" (CoT) rationales (sequences of tokens that convey a "reasoning" process) has been shown to consistently improve LLM performance on tasks like question answering. More recent efforts have shown that such rationales can also be used for model distillation: including CoT sequences (elicited from a large "teacher" model) in addition to target labels when fine-tuning a small student model yields (often substantial) improvements. In this work we ask: why and how does this additional training signal help in model distillation? We perform ablations to interrogate this, and report some potentially surprising results. Specifically: (1) Placing CoT sequences after labels (rather than before) consistently yields better downstream performance; this means that no student "reasoning" is necessary at test time to realize gains. (2) When rationales are appended in this way, they need not be coherent reasoning sequences to yield improvements; performance increases are robust to permutations of CoT tokens, for example. In fact, (3) a small number of key tokens are sufficient to achieve improvements equivalent to those observed when full rationales are used in model distillation.
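The target formats compared in the abstract can be illustrated with a small sketch. The Python below is a hypothetical illustration, not the authors' code: it assumes whitespace tokenization, and the separator strings "Answer:" and "Rationale:" as well as the helper names build_target and extract_label are invented for this example. It builds a label-only target, a CoT-before-label target, a CoT-after-label target, and a variant whose rationale tokens are randomly permuted (ablation (2)).

```python
import random
from typing import Optional

# Hypothetical separators; the paper's actual formatting is not specified here.
ANSWER_SEP = " Answer: "
RATIONALE_SEP = " Rationale: "


def build_target(label: str, rationale: Optional[str] = None,
                 placement: str = "none", shuffle: bool = False,
                 seed: int = 0) -> str:
    """Build a student fine-tuning target string (illustrative sketch).

    placement: "none"   -> label only (standard fine-tuning)
               "before" -> rationale, then label (student generates CoT first)
               "after"  -> label, then rationale (label is generated first)
    shuffle:   if True, randomly permute the rationale's whitespace tokens,
               destroying coherence while keeping the same tokens.
    """
    if placement == "none" or rationale is None:
        return label
    tokens = rationale.split()
    if shuffle:
        random.Random(seed).shuffle(tokens)  # seeded, so reproducible
    cot = " ".join(tokens)
    if placement == "before":
        return f"{cot}{ANSWER_SEP}{label}"
    if placement == "after":
        return f"{label}{RATIONALE_SEP}{cot}"
    raise ValueError(f"unknown placement: {placement!r}")


def extract_label(generated: str) -> str:
    """For label-first targets, keep only the text before the rationale."""
    return generated.split(RATIONALE_SEP, 1)[0].strip()


rationale = "Birds have wings and most birds can fly"
print(build_target("yes", rationale, "after"))
# yes Rationale: Birds have wings and most birds can fly
```

With the label placed first, decoding can stop at the separator at test time, so the answer is read off without the student generating any rationale.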

Paper Structure (22 sections, 1 equation, 8 figures, 3 tables)

Figures (8)

  • Figure 1: For RQ1, we investigate appending CoT rationales, obtained from very large (teacher) language models such as Mistral, after the target labels. In doing so, we inject the same CoT reasoning ability during supervised fine-tuning (SFT) but do not condition generation of the target label on the CoT itself at inference time.
  • Figure 2: TunedLens (Belrose et al., 2023) visualizations on GPT-2 variants fine-tuned without CoT rationales (left), with rationales prepended (middle), and with rationales appended (right). Augmenting distillation with CoT results in models that are more confident in labels at earlier layers. Models trained with rationales following labels are especially confident.
  • Figure 3: Performance of GPT-2 with a constant number of <unk> tokens prepended to the target label.
  • Figure 4: Comparison of model performance as the amount of information available in a CoT rationale is successively reduced through masking.
  • Figure 5: Comparison of attribution methods. Left: words extracted automatically via Integrated Gradients. Right: words that human annotators manually marked as relevant.
  • ...and 3 more figures
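The masking, <unk>-prefix, and key-token ablations described in the abstract and figure captions can likewise be sketched in a few lines. This is an assumed, simplified version, not the authors' implementation: tokens are whitespace-split strings, masked positions are replaced by a literal "<unk>" string rather than a tokenizer's actual unknown-token id, and the per-token importance scores passed to keep_top_k are assumed to come from an attribution method such as Integrated Gradients computed elsewhere. All function names are hypothetical.

```python
import random
from typing import List, Sequence

# Placeholder string; a real tokenizer's unknown token may differ.
UNK = "<unk>"


def mask_rationale(tokens: Sequence[str], fraction: float,
                   seed: int = 0) -> List[str]:
    """Replace a random `fraction` of rationale tokens with UNK."""
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must be in [0, 1]")
    n_mask = round(fraction * len(tokens))
    rng = random.Random(seed)
    masked = set(rng.sample(range(len(tokens)), n_mask))
    return [UNK if i in masked else t for i, t in enumerate(tokens)]


def keep_top_k(tokens: Sequence[str], scores: Sequence[float],
               k: int) -> List[str]:
    """Keep only the k highest-scoring tokens, preserving original order.

    `scores` are assumed to be precomputed attributions (e.g. from
    Integrated Gradients); this function does not compute them.
    """
    if len(tokens) != len(scores):
        raise ValueError("tokens and scores must have the same length")
    top = sorted(range(len(tokens)), key=lambda i: scores[i], reverse=True)[:k]
    return [tokens[i] for i in sorted(top)]


def unk_prefixed_target(label: str, n_unk: int) -> str:
    """Prepend a constant number of UNK tokens to the label (a control
    that adds sequence length but no rationale content)."""
    return " ".join([UNK] * n_unk + [label])


toks = "the bird has wings".split()
print(keep_top_k(toks, [0.1, 0.9, 0.05, 0.7], k=2))
# ['bird', 'wings']
```

Sweeping `fraction` from 0 to 1 in mask_rationale gives a progressively less informative rationale, while unk_prefixed_target isolates the effect of extra tokens that carry no content.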