Rethinking Centered Kernel Alignment in Knowledge Distillation

Zikai Zhou, Yunhang Shen, Shitong Shao, Linrui Gong, Shaohui Lin

TL;DR

This work reexamines Centered Kernel Alignment (CKA) for knowledge distillation by revealing that CKA approximates the upper bound of the Maximum Mean Discrepancy (MMD) plus a constant term, which motivates simpler and more effective distillation mechanisms. It introduces Relation-based Centered Kernel Alignment (RCKA) for robust, scalable alignment of high-order representations across teacher and student networks, and Patch-based CKA (PCKA) to adapt distillation to object detection by operating on patch-level Gram matrices. The authors validate their approach on CIFAR-100, ImageNet-1k, and MS-COCO, achieving state-of-the-art or competitive results while reducing computational overhead relative to prior CKA-centric methods. They also provide extensive ablations and visualizations to support theoretical claims and demonstrate the benefits of patching and channel-wise averaging, with code released for reproducibility and extension. Overall, the paper offers a principled, scalable pathway to leverage CKA in KD across diverse vision tasks.

Abstract

Knowledge distillation has emerged as a highly effective method for bridging the representation discrepancy between large-scale models and lightweight models. Prevalent approaches use an appropriate metric to minimize the divergence or distance between the knowledge extracted from the teacher model and the knowledge learned by the student model. Centered Kernel Alignment (CKA) is widely used to measure representation similarity and has been applied in several knowledge distillation methods. However, these methods are complex and fail to uncover the essence of CKA, so they do not answer the question of how to use CKA properly to achieve simple and effective distillation. This paper first provides a theoretical perspective on the effectiveness of CKA, decomposing CKA into the upper bound of Maximum Mean Discrepancy (MMD) and a constant term. Drawing on this, we propose a novel Relation-Centered Kernel Alignment (RCKA) framework, which practically establishes a connection between CKA and MMD. Furthermore, we dynamically customize the application of CKA based on the characteristics of each task, using fewer computational resources than previous methods while achieving comparable performance. Extensive experiments on CIFAR-100, ImageNet-1k, and MS-COCO demonstrate that our method achieves state-of-the-art performance on almost all teacher-student pairs for image classification and object detection, validating the effectiveness of our approaches. Our code is available at https://github.com/Klayand/PCKA

Paper Structure (45 sections, 2 theorems, 10 equations, 7 figures, 17 tables)

Key Result

Theorem 1

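Let $X$ and $Y$ be $N \times P$ matrices. The CKA similarity $\frac{\|Y^\top X\|_F^2}{\|X^\top X\|_F \|Y^\top Y\|_F}$ is equivalent to the cosine similarity of $XX^\top$ and $YY^\top$, which denote the Gram matrices of $X$ and $Y$, respectively. In other words,

$$\frac{\|Y^\top X\|_F^2}{\|X^\top X\|_F \, \|Y^\top Y\|_F} = \frac{\mathrm{vec}(XX^\top)^\top \, \mathrm{vec}(YY^\top)}{\|\mathrm{vec}(XX^\top)\|_2 \, \|\mathrm{vec}(YY^\top)\|_2},$$

where the $\mathrm{vec}$ operator reshapes a matrix into a vector.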
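The equivalence stated in Theorem 1 can be checked numerically. The following is a minimal NumPy sketch, not the authors' implementation: the function names `linear_cka` and `gram_cosine` are hypothetical, and the column centering step is an assumption taken from the usual definition of linear CKA rather than from the theorem statement above.

```python
import numpy as np


def center_columns(a):
    """Subtract the per-column mean (assumed preprocessing for linear CKA)."""
    return a - a.mean(axis=0, keepdims=True)


def linear_cka(x, y):
    """||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F) on column-centered inputs."""
    x, y = center_columns(x), center_columns(y)
    numerator = np.linalg.norm(y.T @ x, "fro") ** 2
    denominator = np.linalg.norm(x.T @ x, "fro") * np.linalg.norm(y.T @ y, "fro")
    return numerator / denominator


def gram_cosine(x, y):
    """Cosine similarity between vec(X X^T) and vec(Y Y^T)."""
    x, y = center_columns(x), center_columns(y)
    kx = (x @ x.T).ravel()  # vec of the N x N Gram matrix of X
    ky = (y @ y.T).ravel()  # vec of the N x N Gram matrix of Y
    return kx @ ky / (np.linalg.norm(kx) * np.linalg.norm(ky))


rng = np.random.default_rng(0)
X = rng.normal(size=(8, 5))  # N = 8 samples, P = 5 features
Y = rng.normal(size=(8, 5))

# Both expressions should agree up to floating-point error.
print(linear_cka(X, Y), gram_cosine(X, Y))
```

The two values agree because $\|Y^\top X\|_F^2 = \mathrm{tr}(XX^\top YY^\top)$ and $\|X^\top X\|_F = \|XX^\top\|_F$, so the feature-space and Gram-space forms are the same quantity.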

Figures (7)

  • Figure 1: The overall framework of the proposed Relation-based Centered Kernel Alignment (RCKA). We first transform the feature map from the shape of $(B, C, HW)$ into $(B, CHW)$ and then compute the CKA similarity of feature maps between the teacher and the student. Besides, we compute the inter-class and intra-class CKA similarity of logits between teacher and student. Here, $N$ refers to the number of samples, and $P$ refers to the corresponding probability of class to which this sample belongs.
  • Figure 2: The overall framework of PCKA. We dynamically customize the framework of the proposed method based on the characteristics of object detection. In this framework, we first split the feature maps of the teacher and student into patches of size $(P_H, P_W)$, then transform the feature maps to obtain the Gram matrix between patches. Finally, we calculate the loss $L_{PCKA}$ and average it over the dimension $C$. Here, $B, C, H, W$ refer to the batch size, channels, height, and width of the feature map, respectively. $N_{P_H}, N_{P_W}$ denote the number of patches cut along the height and width, respectively.
  • Figure 3: Training process visualization of all experimented detectors.
  • Figure 4: The effect of patching on different dimensions. This experiment is conducted on the RetinaNet-X101-RetinaNet-R50 pair.
  • Figure 5: CKA curve in training phase on CIFAR-100. We visualize the CKA similarities in the training for six teacher-student pairs.
  • ...and 2 more figures
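The tensor manipulations described in the captions of Figures 1 and 2 can be sketched as follows. This is an illustrative NumPy sketch under stated assumptions, not the released implementation: the function names are hypothetical, the teacher and student are assumed to share the channel count $C$ for the patch-based variant, and the exact patch layout (rows are patches, columns are the pixels within a patch, one CKA value per channel) is a guess at what the Figure 2 caption describes. The form of the loss $L_{PCKA}$ is not given here, so the sketch returns only the averaged similarity.

```python
import numpy as np


def linear_cka(x, y):
    """Linear CKA between an (N, P) and an (N, Q) matrix after column centering."""
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    numerator = np.linalg.norm(y.T @ x, "fro") ** 2
    denominator = np.linalg.norm(x.T @ x, "fro") * np.linalg.norm(y.T @ y, "fro")
    return numerator / denominator


def flattened_cka(feat_t, feat_s):
    """Figure 1 style: flatten (B, C, H, W) features to (B, C*H*W), then CKA."""
    b = feat_t.shape[0]
    return linear_cka(feat_t.reshape(b, -1), feat_s.reshape(b, -1))


def to_patches(feat, ph, pw):
    """(B, C, H, W) -> (C, B * N_PH * N_PW, ph * pw); needs H % ph == W % pw == 0."""
    b, c, h, w = feat.shape
    nph, npw = h // ph, w // pw
    x = feat.reshape(b, c, nph, ph, npw, pw)
    x = x.transpose(1, 0, 2, 4, 3, 5)  # (C, B, N_PH, N_PW, ph, pw)
    return x.reshape(c, b * nph * npw, ph * pw)


def patch_cka(feat_t, feat_s, ph, pw):
    """Assumed Figure 2 style: per-channel CKA over patches, averaged over C."""
    patches_t = to_patches(feat_t, ph, pw)
    patches_s = to_patches(feat_s, ph, pw)
    per_channel = [linear_cka(t, s) for t, s in zip(patches_t, patches_s)]
    return float(np.mean(per_channel))


rng = np.random.default_rng(0)
teacher = rng.normal(size=(4, 3, 8, 8))                  # B=4, C=3, H=W=8
student = teacher + 0.1 * rng.normal(size=teacher.shape)  # a noisy copy

print(flattened_cka(teacher, student))
print(patch_cka(teacher, student, ph=4, pw=4))
```

Splitting into patches increases the number of rows in each Gram matrix from $B$ to $B \cdot N_{P_H} \cdot N_{P_W}$, which is one plausible reason the patch-based form suits detection, where the batch size is typically small.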

Theorems & Definitions (3)

  • Theorem 1: Proof in the appendix of https://arxiv.org/abs/2401.11824
  • Theorem 2: Proof in the appendix of https://arxiv.org/abs/2401.11824
  • proof