Knowledge distillation through geometry-aware representational alignment
Prajjwal Bhattarai, Mohammad Amjad, Dmytro Zhylko, Tuka Alhanai
TL;DR
The paper addresses geometry preservation in knowledge distillation (KD) by critically evaluating projection-based and CKA-based feature alignment. It introduces two geometry-centric losses, the Procrustes distance and the Feature Gram Matrix (FG) distance, proves that they vanish together ($\mathcal{D}_{P} = 0 \Leftrightarrow \mathcal{D}_{FG} = 0$), and highlights the limitations of $\mathcal{D}_{CKA}$ and of unrestricted linear projections. Through synthetic experiments and real language-model tasks (classification with BERT and instruction following with OPT), the authors demonstrate that Procrustes- and FG-based distillation reliably preserve geometric structure and yield statistically significant performance gains of up to 2 percentage points, often outperforming CKA baselines. The findings suggest that incorporating representational geometry into KD can make knowledge transfer more efficient and robust across model sizes and tasks, with practical implications for deploying language models at scale.
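The two geometry-centric losses can be made concrete with a minimal NumPy sketch. It assumes features are stored as $n \times d$ matrices (one row per input), that both matrices are column-centered and scaled to unit Frobenius norm before comparison, and that student and teacher share the feature width $d$; the helpers `_normalize`, `procrustes_distance`, and `feature_gram_distance` are hypothetical names, and the paper's exact normalization may differ. Rotating the teacher's features by a random orthogonal matrix drives both distances to round-off level, illustrating the direction $\mathcal{D}_{P} = 0 \Rightarrow \mathcal{D}_{FG} = 0$.

```python
import numpy as np


def _normalize(X):
    """Center each feature column, then scale the matrix to unit Frobenius norm."""
    X = X - X.mean(axis=0, keepdims=True)
    return X / np.linalg.norm(X)


def procrustes_distance(X, Y):
    """Procrustes distance: sqrt(min_Q ||X - Y Q||_F^2) over orthogonal Q.

    For unit-norm X and Y the minimum equals 2 - 2 * ||X^T Y||_*, where
    ||.||_* is the nuclear norm (sum of singular values).
    """
    X, Y = _normalize(X), _normalize(Y)
    nuclear = np.linalg.svd(X.T @ Y, compute_uv=False).sum()
    return np.sqrt(max(2.0 - 2.0 * nuclear, 0.0))  # clip round-off below zero


def feature_gram_distance(X, Y):
    """Frobenius norm of the difference between the n x n Gram matrices."""
    X, Y = _normalize(X), _normalize(Y)
    return np.linalg.norm(X @ X.T - Y @ Y.T)


rng = np.random.default_rng(0)
teacher = rng.standard_normal((64, 16))              # 64 inputs, 16-dim features
Q, _ = np.linalg.qr(rng.standard_normal((16, 16)))   # random orthogonal matrix
student_rotated = teacher @ Q                        # same geometry, rotated basis
student_random = rng.standard_normal((64, 16))       # unrelated geometry

d_p_rot = procrustes_distance(teacher, student_rotated)
d_fg_rot = feature_gram_distance(teacher, student_rotated)
d_p_rand = procrustes_distance(teacher, student_random)
d_fg_rand = feature_gram_distance(teacher, student_random)
print(d_p_rot, d_fg_rot)    # both at floating-point round-off level
print(d_p_rand, d_fg_rand)  # both clearly positive
```

Using the nuclear-norm identity avoids an explicit optimization over $Q$. Only one direction of the equivalence is exercised here; the converse (equal Gram matrices imply the features differ by an orthogonal map) is the other half of the stated result.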
Abstract
Knowledge distillation is a common paradigm for transferring capabilities from larger models to smaller ones. While traditional distillation methods minimize a probabilistic divergence between the outputs of the teacher and student models, feature-based distillation methods often minimize variants of Euclidean norms between their hidden-layer representations, with the goal of making the student mimic the structure of the teacher's feature space. In this work, we theoretically show that existing feature distillation methods, such as projection-based mean squared loss or Centered Kernel Alignment (CKA), cannot capture the feature structure, even under zero loss. We then motivate the use of the Procrustes distance and the Frobenius norm of the difference between Feature Gram Matrices, distances already common for measuring representational alignment, as distillation losses. We show that feature distillation with these losses yields statistically significant improvements in distillation performance, of up to 2 percentage points, across language model families (BERT and OPT) on classification and instruction-following tasks, demonstrating the potential of integrating feature geometry into existing distillation methods.
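The claim that a projection-based loss can reach zero without preserving feature structure can be illustrated with a small NumPy sketch; it is an illustration under stated assumptions, not the paper's proof. It assumes the projection-based loss is the mean squared error between an unrestricted linear map of the student's features and the teacher's features, and it uses closed-form least squares as a stand-in for a projection layer trained to convergence. The helpers `projection_mse`, `feature_gram_distance`, and `_normalize` are hypothetical names. A student whose features are an arbitrary invertible linear warp of the teacher's reaches zero projection loss, while the Feature Gram Matrix distance stays clearly positive.

```python
import numpy as np


def _normalize(X):
    """Center each feature column, then scale the matrix to unit Frobenius norm."""
    X = X - X.mean(axis=0, keepdims=True)
    return X / np.linalg.norm(X)


def projection_mse(student, teacher):
    """MSE after the best unrestricted linear projection: min_W mean((S W - T)^2).

    Closed-form least squares stands in for a projection layer trained to
    convergence (an assumption of this sketch).
    """
    W, *_ = np.linalg.lstsq(student, teacher, rcond=None)
    return np.mean((student @ W - teacher) ** 2)


def feature_gram_distance(X, Y):
    """Frobenius norm of the difference between normalized Gram matrices."""
    X, Y = _normalize(X), _normalize(Y)
    return np.linalg.norm(X @ X.T - Y @ Y.T)


rng = np.random.default_rng(0)
teacher = rng.standard_normal((64, 16))   # 64 inputs, 16-dim teacher features
A = rng.standard_normal((16, 16))         # random, non-orthogonal distortion
student = teacher @ A                     # almost surely invertible linear warp

proj_loss = projection_mse(student, teacher)       # near zero: W = A^{-1} undoes the warp
fg_dist = feature_gram_distance(student, teacher)  # positive: geometry is not preserved
print(proj_loss, fg_dist)
```

Restricting $W$ to be orthogonal turns the fit into an orthogonal Procrustes problem, which can reach zero here only if $A$ itself is orthogonal, so the non-orthogonal warp would no longer be absorbed.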
