Towards Understanding Distilled Reasoning Models: A Representational Approach

David D. Baek, Max Tegmark

TL;DR

This work investigates how distillation reshapes reasoning features in large language models by training a sparse crosscoder on Qwen-series models and their fine-tuned (distilled) variants. It identifies reasoning feature directions unique to distilled models, including self-reflection and computation verification, and tests their causal role through ablation and steering experiments. The study also finds indications that larger distilled models develop more structured representations, as measured by a parallelogram-geometry loss, linking model size, distillation, and representation organization. By clarifying how distillation alters internal reasoning features and their geometry, the findings aim to improve the transparency and reliability of AI systems.
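
A crosscoder, as commonly formulated, is a sparse autoencoder that reads activations from several models at once, encodes them into one shared set of sparse features, and decodes a separate reconstruction for each model. The sketch below illustrates that generic formulation in NumPy for a base/distilled pair. The function names, the ReLU encoder, and the decoder-norm-weighted L1 penalty are assumptions for illustration, not the authors' exact architecture or training setup.

```python
import numpy as np

def crosscoder_forward(acts, W_enc, b_enc, W_dec, b_dec):
    """Encode activations from all models into shared sparse features, then decode per model.

    acts:  list of (batch, d_model) arrays, one per model (e.g. base, distilled)
    W_enc: (n_models, d_model, n_features); b_enc: (n_features,)
    W_dec: (n_models, n_features, d_model); b_dec: (n_models, d_model)
    """
    # Sum each model's encoder projection, then apply ReLU for non-negative sparse codes
    pre = sum(a @ W_enc[m] for m, a in enumerate(acts)) + b_enc
    f = np.maximum(pre, 0.0)
    # Each model gets its own reconstruction from the same shared features
    recons = [f @ W_dec[m] + b_dec[m] for m in range(len(acts))]
    return f, recons


def crosscoder_loss(acts, f, recons, W_dec, l1_coeff):
    """Reconstruction error summed over models plus an L1 penalty weighted by decoder norms."""
    mse = sum(np.sum((a - r) ** 2, axis=-1).mean() for a, r in zip(acts, recons))
    dec_norms = np.linalg.norm(W_dec, axis=-1)  # (n_models, n_features)
    sparsity = (f * dec_norms.sum(axis=0)).sum(axis=-1).mean()
    return mse + l1_coeff * sparsity


# Toy usage with random data: 2 models, d_model=8, 16 features, batch of 4
rng = np.random.default_rng(0)
acts = [rng.normal(size=(4, 8)) for _ in range(2)]
W_enc = 0.1 * rng.normal(size=(2, 8, 16))
b_enc = np.zeros(16)
W_dec = 0.1 * rng.normal(size=(2, 16, 8))
b_dec = np.zeros((2, 8))

f, recons = crosscoder_forward(acts, W_enc, b_enc, W_dec, b_dec)
loss = crosscoder_loss(acts, f, recons, W_dec, l1_coeff=1e-3)
print(f.shape, len(recons), float(loss))
```

Because the features are shared while the decoders are per model, comparing a feature's decoder norms across the two models is what lets a crosscoder flag directions that exist mainly in the distilled model.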

Abstract

In this paper, we investigate how model distillation affects the development of reasoning features in large language models (LLMs). To this end, we train a crosscoder on Qwen-series models and their fine-tuned variants. Our results suggest that the crosscoder learns features corresponding to various types of reasoning, including self-reflection and computation verification. Moreover, we observe that distilled models contain unique reasoning feature directions, which could be used to steer the model into an over-thinking or incisive-thinking mode. In particular, we analyze four specific reasoning categories: (a) self-reflection, (b) deductive reasoning, (c) alternative reasoning, and (d) contrastive reasoning. Finally, we examine how distillation changes feature geometry and find indications that larger distilled models may develop more structured representations, which correlate with stronger distillation performance. By providing insight into how distillation modifies the model, our study contributes to the transparency and reliability of AI systems.
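
Steering and ablation both act on a model's activations through a feature's decoder direction. The following minimal sketch assumes that steering adds a scaled unit vector along the distilled model's decoder direction and that ablation subtracts the feature's decoded contribution. The names `steer`, `ablate`, and the scale `alpha` are hypothetical, and which sign of `alpha` produces over-thinking versus incisive thinking is an empirical question not settled here.

```python
import numpy as np

def steer(h, direction, alpha):
    """Shift every activation row by alpha along the unit-normalized feature direction."""
    unit = direction / np.linalg.norm(direction)
    return h + alpha * unit


def ablate(h, feature_acts, direction):
    """Remove a feature's decoded contribution f_i(x) * d_i from each activation row.

    h: (batch, d_model); feature_acts: (batch,); direction: (d_model,)
    """
    return h - feature_acts[:, None] * direction[None, :]


h = np.ones((3, 4))
d = np.array([3.0, 0.0, 4.0, 0.0])  # toy decoder direction with norm 5
steered = steer(h, d, alpha=5.0)    # each row becomes [4, 1, 5, 1]
ablated = ablate(h, np.array([0.0, 1.0, 2.0]), d)
print(steered[0], ablated[2])
```

In a real experiment the edited activations would be written back into the model's forward pass (for example with a hook at the crosscoder's layer) and the resulting logits compared against the unedited run.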

Paper Structure

This paper contains 11 sections, 10 equations, 9 figures, 1 table.

Figures (9)

  • Figure 1: (Left) Average normalized relative norm across all features for base models of various sizes. (Right) Distribution of normalized relative norm for Qwen-14b crosscoder features.
  • Figure 2: Ablation Experiment: Histogram depicting the average logit change in both base and distilled models as a result of ablating features with (a) NRN $>0.5$, and (b) firing frequency in the top $k\%$, for $k\in \{0.5, 1, 2, 5, 10, 20\}$.
  • Figure 3: Distilled Model's behavior steered into incisive thinking mode.
  • Figure 4: Cumulative fraction as a function of parallelogram loss for different models and function classes. Distilled models' representations tend to become more structured as the model scales.
  • Figure 5: Parallelogram loss with activations projected into 2D by PCA.
  • ...and 4 more figures
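
The quantities plotted in these figures can be approximated as follows. This sketch assumes the relative norm is the distilled model's share of a feature's total decoder norm (a common crosscoder convention; the paper's exact normalization may differ), and it uses a hypothetical scale-normalized parallelogram loss that is zero when four activations form an exact parallelogram, i.e. a - b = c - d. The names `relative_norm`, `pca_2d`, `parallelogram_loss`, and `cumulative_fraction` are illustrative, not taken from the paper.

```python
import numpy as np

def relative_norm(W_dec_base, W_dec_dist):
    """Distilled model's share of each feature's decoder norm, in [0, 1].

    W_dec_base, W_dec_dist: (n_features, d_model) decoder matrices.
    0.5 means equal norm in both models; values near 1 suggest a distilled-only feature.
    """
    nb = np.linalg.norm(W_dec_base, axis=-1)
    nd = np.linalg.norm(W_dec_dist, axis=-1)
    return nd / (nb + nd)


def pca_2d(X):
    """Project rows of X onto their top two principal components."""
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:2].T


def parallelogram_loss(xa, xb, xc, xd):
    """Hypothetical loss: squared gap between (a - b) and (c - d), normalized by edge lengths."""
    gap = (xa - xb) - (xc - xd)
    scale = np.sum((xa - xb) ** 2) + np.sum((xc - xd) ** 2)
    return np.sum(gap ** 2) / scale


def cumulative_fraction(losses, thresholds):
    """Fraction of losses at or below each threshold (the y-axis of a cumulative plot)."""
    losses = np.asarray(losses)
    return np.array([(losses <= t).mean() for t in thresholds])


rn = relative_norm(np.array([[1.0, 0.0], [0.0, 3.0]]),
                   np.array([[1.0, 0.0], [0.0, 1.0]]))  # -> [0.5, 0.25]

a, b, c = np.array([1.0, 1.0]), np.array([0.0, 1.0]), np.array([1.0, 0.0])
exact = parallelogram_loss(a, b, c, np.array([0.0, 0.0]))   # a - b == c - d -> 0.0
skewed = parallelogram_loss(a, b, c, np.array([0.0, 1.0]))  # -> 1/3
frac = cumulative_fraction([exact, skewed, 0.9], thresholds=[0.5])  # -> [2/3]

rng = np.random.default_rng(0)
proj = pca_2d(rng.normal(size=(5, 4)))  # 2D coordinates, shape (5, 2)
```

Normalizing by the edge lengths keeps the loss comparable across models whose activations differ in overall scale, which matters when comparing base and distilled models of several sizes.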