Knowledge distillation through geometry-aware representational alignment

Prajjwal Bhattarai, Mohammad Amjad, Dmytro Zhylko, Tuka Alhanai

TL;DR

The paper addresses geometry preservation in knowledge distillation by critically evaluating projection-based and CKA-based feature alignment. It introduces geometry-centric losses, Procrustes distance and Feature Gram Matrix distance, and proves theoretical connections showing $\mathcal{D}_{P} = 0 \Leftrightarrow \mathcal{D}_{FG} = 0$, while highlighting the limitations of $\mathcal{D}_{CKA}$ and unrestricted linear projections. Through synthetic experiments and real language-model tasks (classification with BERT and instruction-following with OPT), the authors demonstrate that Procrustes- and FG-based distillation reliably preserve geometric structure and yield performance gains, often outperforming CKA baselines by notable margins. The findings suggest that incorporating representational geometry into KD can improve knowledge transfer efficiency and robustness across model sizes and tasks, with practical impact for scalable language-model deployment.
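The equivalence between zero Procrustes distance and zero Feature Gram distance can be checked numerically. The sketch below is illustrative and rests on our own assumptions, not the paper's code: activations are n × d matrices (rows are samples), centered over samples and scaled to unit Frobenius norm; $\mathcal{D}_{FG}$ is taken as the Frobenius norm of the difference between the n × n Gram matrices $RR^\top$; and the Procrustes distance is computed through the nuclear norm, which handles $d_t \neq d_s$ as if the narrower matrix were zero-padded. The helper names (`center_and_normalize`, `procrustes_distance`, `feature_gram_distance`) are hypothetical.

```python
import numpy as np

def center_and_normalize(r: np.ndarray) -> np.ndarray:
    """Center each feature over the n samples (rows), then scale to unit Frobenius norm."""
    r = r - r.mean(axis=0, keepdims=True)
    return r / np.linalg.norm(r)

def procrustes_distance(r_t: np.ndarray, r_s: np.ndarray) -> float:
    """Orthogonal Procrustes distance between n x d_t and n x d_s matrices.

    Uses d_P^2 = ||R_t||_F^2 + ||R_s||_F^2 - 2 ||R_t^T R_s||_*, which equals the
    best orthogonal alignment after zero-padding the narrower matrix.
    """
    nuclear = np.linalg.norm(r_t.T @ r_s, ord="nuc")
    sq = np.linalg.norm(r_t) ** 2 + np.linalg.norm(r_s) ** 2 - 2.0 * nuclear
    return float(np.sqrt(max(sq, 0.0)))  # clip tiny negative rounding error

def feature_gram_distance(r_t: np.ndarray, r_s: np.ndarray) -> float:
    """Frobenius norm of the difference between the n x n Gram matrices R R^T."""
    return float(np.linalg.norm(r_t @ r_t.T - r_s @ r_s.T))

rng = np.random.default_rng(0)
n, d_t, d_s = 50, 8, 12

r_t = center_and_normalize(rng.normal(size=(n, d_t)))

# Student = teacher embedded isometrically into d_s dims (q.T has orthonormal rows),
# so the sample Gram matrix is unchanged.
q, _ = np.linalg.qr(rng.normal(size=(d_s, d_t)))  # q: d_s x d_t, orthonormal columns
r_s_same = r_t @ q.T

# Student whose geometry genuinely differs from the teacher's.
r_s_noisy = center_and_normalize(r_s_same + 0.05 * rng.normal(size=(n, d_s)))

print(procrustes_distance(r_t, r_s_same), feature_gram_distance(r_t, r_s_same))    # both ~0
print(procrustes_distance(r_t, r_s_noisy), feature_gram_distance(r_t, r_s_noisy))  # both > 0
```

Both distances vanish together for the isometric embedding and become positive together once the geometry is perturbed, which is the behavior the equivalence predicts.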

Abstract

Knowledge distillation is a common paradigm for transferring capabilities from larger models to smaller ones. While traditional distillation methods leverage a probabilistic divergence over the outputs of the teacher and student models, feature-based distillation methods often minimize variants of Euclidean norms between hidden-layer representations. The main goal is for the student to mimic the structure of the teacher's feature space. In this work, we theoretically show that existing feature distillation methods, such as projection-based mean squared loss or Centered Kernel Alignment (CKA), cannot capture the feature structure, even at zero loss. We then motivate the use of Procrustes distance and the Frobenius norm of the Feature Gram Matrix difference, distances already common for measuring representational alignment, as distillation losses. We show that feature distillation with these losses yields statistically significant improvements in distillation performance across language model families (BERT and OPT) on classification and instruction-following tasks, by up to 2 percentage points, showcasing the potential of integrating feature geometry into existing distillation methods.
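A toy construction helps show why a zero projection-based loss need not preserve feature structure: if the teacher features are any linear image of the student features, an unrestricted learned projection reaches zero error even when that map stretches the space. This is our own illustration, not the paper's proof; the dimensions, the stretched direction, and the use of `np.linalg.lstsq` as a stand-in for the learned projection are assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d_s, d_t = 100, 16, 32

def center_and_normalize(r):
    """Center features over samples, then scale to unit Frobenius norm."""
    r = r - r.mean(axis=0, keepdims=True)
    return r / np.linalg.norm(r)

# Student features (n x d_s).
r_s = center_and_normalize(rng.normal(size=(n, d_s)))

# Teacher features: a non-orthogonal linear image of the student that
# strongly stretches one student direction (assumption for illustration).
a = rng.normal(size=(d_s, d_t))
a[0] *= 10.0
r_t = center_and_normalize(r_s @ a)

# Unrestricted linear projection W (d_s x d_t) fit by least squares:
# minimize ||R_s W - R_t||_F, i.e. the projection-based MSE objective.
w, *_ = np.linalg.lstsq(r_s, r_t, rcond=None)
projection_mse = float(np.mean((r_s @ w - r_t) ** 2))

# Compare the sample geometry through the n x n Gram matrices.
gram_gap = float(np.linalg.norm(r_s @ r_s.T - r_t @ r_t.T))

print(projection_mse)  # essentially zero: the projection loss is satisfied
print(gram_gap)        # clearly positive: the geometries still differ
```

Because `r_s` has full column rank and `r_t` lies in its column space, least squares recovers the map exactly, while the stretch in `a` leaves the student and teacher Gram matrices far apart.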

Paper Structure

This paper contains 26 sections, 7 theorems, 20 equations, 10 figures, 3 tables, 1 algorithm.

Key Result

Theorem 1

Let $\mathbf{R_t}$ and $\mathbf{R_s}$ be centered, unit-norm matrices of feature activations such that $\mathcal{D}_{FG} = 0$ and $\mathcal{D}_{CKA} = 0$. For any $\epsilon \in [0,1]$, we can construct another set of points $\mathbf{\Tilde{R}_t}$ such that $\mathcal{D}_{CKA} (\mathbf{\Tilde{R}_t}, \math
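To make the quantities in Theorem 1 concrete, here is a minimal sketch of $\mathcal{D}_{CKA}$ and $\mathcal{D}_{FG}$. We assume linear CKA, $\mathrm{CKA}(X, Y) = \|Y^\top X\|_F^2 / (\|X^\top X\|_F \, \|Y^\top Y\|_F)$ on centered activations, and $\mathcal{D}_{CKA} = 1 - \mathrm{CKA}$; the paper's exact definitions and normalization may differ, and the function names are hypothetical. The example also shows why the theorem fixes unit norm: without it, $\mathcal{D}_{CKA}$ is blind to isotropic rescaling, which $\mathcal{D}_{FG}$ detects.

```python
import numpy as np

def center(r):
    """Subtract the per-feature mean over samples (rows)."""
    return r - r.mean(axis=0, keepdims=True)

def linear_cka(x, y):
    """Linear CKA between n x p and n x q activations (centered inside)."""
    x, y = center(x), center(y)
    cross = np.linalg.norm(y.T @ x) ** 2
    return float(cross / (np.linalg.norm(x.T @ x) * np.linalg.norm(y.T @ y)))

def d_cka(x, y):
    """Assumed CKA distance: 1 - linear CKA."""
    return 1.0 - linear_cka(x, y)

def d_fg(x, y):
    """Frobenius norm of the difference between centered n x n Gram matrices."""
    x, y = center(x), center(y)
    return float(np.linalg.norm(x @ x.T - y @ y.T))

rng = np.random.default_rng(0)
x = center(rng.normal(size=(40, 6)))
x /= np.linalg.norm(x)                       # centered, unit Frobenius norm
q, _ = np.linalg.qr(rng.normal(size=(6, 6)))  # random orthogonal matrix

rotated = x @ q      # same geometry, different basis
rescaled = 3.0 * x   # same shape, but no longer unit norm

print(d_cka(x, rotated), d_fg(x, rotated))    # both ~0
print(d_cka(x, rescaled), d_fg(x, rescaled))  # CKA distance ~0, FG distance > 0
```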

Figures (10)

  • Figure 1: A simplified illustration of the phenomenon described by Theorem 1. (a): $n$ vectors in $d_t$ dimensions lie in exactly two configurations that are antiparallel to each other. (b): A subset of those $n$ vectors from (a) are perturbed along a distinct orthogonal direction among the $d_t$ possible ones. (c): An exact replication of (a) in $d_s < d_t$ dimensions. Although the feature geometries differ, CKA computed with respect to (c) fails to differentiate between (a) and (b).
  • Figure 2: Dynamics of the Procrustes distance throughout the synthetic training process when student vectors are initialized from a random projection.
  • Figure 3: Dynamics of the norm of the difference in Feature Gram matrices throughout the synthetic training process when student vectors are initialized from a random projection.
  • Figure 4: Dynamics of the learned linear projection loss throughout the synthetic training process when student vectors are initialized from a random projection.
  • Figure 5: Dynamics of the number of approximately orthogonal vectors throughout the synthetic training process when student vectors are randomly initialized.
  • ...and 5 more figures
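Figures 2 to 4 track losses during a synthetic training process in which student vectors start from a random projection of the teacher vectors. A minimal sketch of such a loop for the Feature Gram loss alone follows; it assumes plain gradient descent directly on the student vectors (no network), the squared loss $\|SS^\top - K\|_F^2$ with gradient $4(SS^\top - K)S$, and arbitrary sizes and learning rate. The paper's actual setup and optimizer are not specified here.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_t, d_s = 32, 16, 8

def center_and_normalize(r):
    """Center features over samples, then scale to unit Frobenius norm."""
    r = r - r.mean(axis=0, keepdims=True)
    return r / np.linalg.norm(r)

def fg_loss(s, k):
    """Squared Frobenius gap between the student Gram matrix S S^T and target K."""
    return float(np.sum((s @ s.T - k) ** 2))

def fg_grad(s, k):
    """Gradient of ||S S^T - K||_F^2 w.r.t. S for symmetric K: 4 (S S^T - K) S."""
    return 4.0 * (s @ s.T - k) @ s

teacher = center_and_normalize(rng.normal(size=(n, d_t)))
k_teacher = teacher @ teacher.T

# Initialize the student from a random projection of the teacher vectors.
proj = rng.normal(size=(d_t, d_s))
student = center_and_normalize(teacher @ proj)

lr = 0.1
history = [fg_loss(student, k_teacher)]
for _ in range(1000):
    student -= lr * fg_grad(student, k_teacher)
    history.append(fg_loss(student, k_teacher))

print(history[0], history[-1])  # the loss decreases over training
```

The gradient preserves centering (its column sums are zero when $S$ and the teacher are centered). Since $SS^\top$ has rank at most $d_s < d_t$, the loss generally plateaus above zero rather than reaching it.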

Theorems & Definitions (17)

  • Theorem 1
  • proof
  • Remark
  • Theorem 2
  • proof
  • Remark
  • Theorem 3
  • proof
  • Lemma 1
  • proof
  • ...and 7 more