The Role of Masking for Efficient Supervised Knowledge Distillation of Vision Transformers

Seungwoo Son, Jegwang Ryu, Namhoon Lee, Jaeho Lee

TL;DR

A simple framework that reduces the supervision cost of ViT distillation by masking out a fraction of the input tokens given to the teacher. Analyses reveal that student-guided masking provides a good curriculum for the student: teacher supervision is easy to follow early in training and becomes more challenging later.

Abstract

Knowledge distillation is an effective method for training lightweight vision models. However, acquiring teacher supervision for training samples is often costly, especially from large-scale models like vision transformers (ViTs). In this paper, we develop a simple framework to reduce the supervision cost of ViT distillation: masking out a fraction of the input tokens given to the teacher. By masking input tokens, one can skip the computations associated with the masked tokens without any change to the teacher's parameters or architecture. We find that masking the patches with the lowest student attention scores is highly effective, saving up to 50% of teacher FLOPs without any drop in student accuracy, while other masking criteria yield smaller efficiency gains. Through in-depth analyses, we reveal that student-guided masking provides a good curriculum to the student, making teacher supervision easier to follow during the early stage of training and more challenging in the later stage.
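Why skipping masked tokens saves roughly half of the teacher's compute can be checked with a back-of-the-envelope count. The sketch below is an approximation under stated assumptions, not the paper's FLOPs accounting: it counts only the matmuls inside standard ViT blocks (MLP ratio 4), ignores the patch embedding, LayerNorms, softmax and head, and uses illustrative ViT-B/16-sized numbers (768-dim, 12 blocks, 196 patches plus one class token). The function names are hypothetical.

```python
def vit_encoder_macs(num_tokens: int, dim: int, depth: int, mlp_ratio: int = 4) -> int:
    """Approximate multiply-accumulates (MACs) of a stack of standard ViT blocks.

    Per block: QKV and output projections (4*N*D^2), the two MLP layers
    (2*mlp_ratio*N*D^2), and the attention matmuls QK^T and A.V (2*N^2*D).
    Patch embedding, LayerNorms, softmax and the classifier are ignored.
    """
    n, d = num_tokens, dim
    linear = (4 + 2 * mlp_ratio) * n * d * d  # grows linearly with token count
    attention = 2 * n * n * d                 # grows quadratically with token count
    return depth * (linear + attention)


def teacher_cost_ratio(num_patches: int, kept_patches: int, dim: int, depth: int) -> float:
    """Masked-to-full encoder cost when the teacher only sees `kept_patches` patches.

    Both sequences include one class token.
    """
    full = vit_encoder_macs(num_patches + 1, dim, depth)
    masked = vit_encoder_macs(kept_patches + 1, dim, depth)
    return masked / full


# Illustrative setting: ViT-B/16-sized teacher on 224x224 inputs, half the patches masked.
ratio = teacher_cost_ratio(num_patches=196, kept_patches=98, dim=768, depth=12)
print(f"masked/full teacher encoder cost: {ratio:.3f}")  # prints 0.492
```

Because the linear layers dominate at this sequence length, the cost falls almost proportionally with the number of kept tokens, and the quadratic attention term makes the saving slightly larger than the masking ratio.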

Paper Structure (51 sections, 1 theorem, 7 equations, 9 figures, 12 tables)

Key Result

Proposition

Let $\mathbf{c} \in \mathbb{R}^d$ and $W_q, W_k \in \mathbb{R}^{d \times d}$ be the class token and the query/key weight matrices, whose entries are i.i.d. initialized as $\mathcal{N}(0, 1/d)$. Let $f(\cdot)$ be the pre-softmax attention of the class token to another token, i.e., $f(\mathbf{x}) := (W_q \mathbf{c} \ldots$ […] where $C_0$ is a constant independent of $d$.

Figures (9)

  • Figure 1: ($\Leftarrow$) Supervision cost is expensive. We compare the per-step supervision cost of a teacher ViT that sees full or masked images with the training FLOPs of the student. The teacher input is masked up to the point where student accuracy does not drop. The supervision cost exceeds the student's training FLOPs, and masking saves a substantial part of it. ($\Rightarrow$) MaskedKD illustrated. MaskedKD works in four steps: (1) The student predicts on the full image. (2) The image is masked using the student's attention scores. (3) The teacher predicts on the masked image. (4) The teacher and student logits are matched.
  • Figure 2: ($\Leftarrow$) Student-guided masking illustrated. We mask the teacher's input tokens based on the student's attention scores of the class-token query in its final layer. ($\Rightarrow$) Accuracy vs. number of patches seen. Masking substantially degrades the teacher's own accuracy. The student accuracy, however, first increases slightly and then starts decreasing once more than 50% of the patches are masked.
  • Figure 3: ($\Leftarrow$) Student accuracy gain from distillation, with various masks. The student-guided masking achieves the most rapid student accuracy increase in the early phase (averaged over three seeds). ($\Rightarrow$) Accuracy of masked teachers. The student-guided masking leads to the lowest teacher accuracy during the early phase.
  • Figure 4: Student-guided masking lets the teacher supervise on diverse views. We visualize how often each patch is used throughout training. Student-guided masking lets the teacher predict on diverse patches (unlike "Teacher" or "DINO") while still conveying the core semantic information of the image (unlike "Random").
  • Figure 5: ($\Uparrow$) Randomly initialized students can mask similar patches at once. Randomly initialized students tend to mask all similar patches at the same time, often removing all foreground or all background objects at once. We provide additional examples in Appendix B. ($\Downarrow$) The student shifts its attention from the periphery to the center. We visualize the patch selection frequency of MaskedKD at different stages of training. Early in training, the student attends more to peripheral patches; as training proceeds, it shifts its attention to the central region.
  • ...and 4 more figures
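The four steps of Figure 1, combined with the class-token-attention masking of Figure 2, can be sketched as a single forward pass. This is a hedged NumPy illustration, not the authors' implementation: `toy_student`, `toy_teacher`, `masked_kd_step`, averaging the attention over heads, and a KL-divergence loss on temperature-softened logits are all assumptions made for illustration. The toy models only stand in for ViTs; the point is the data flow, in which the student sees every patch, its final-layer class-token attention decides which patches to keep, and the teacher runs only on the kept tokens.

```python
import numpy as np


def log_softmax(x, axis=-1):
    """Numerically stable log-softmax."""
    z = x - x.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))


def student_guided_keep_indices(cls_attn, num_keep):
    """Return the `num_keep` patch indices with the highest student attention.

    cls_attn: (batch, heads, num_patches) attention of the class-token query to
    each patch in the student's final layer. Averaging over heads is an
    assumption of this sketch. Patches outside the returned set, i.e. those with
    the lowest scores, are the ones masked out for the teacher.
    """
    scores = cls_attn.mean(axis=1)                     # (B, P)
    keep = np.argsort(-scores, axis=-1)[:, :num_keep]  # top-k per image
    return np.sort(keep, axis=-1)                      # restore spatial order


def kl_distillation_loss(student_logits, teacher_logits, temperature=1.0):
    """Batch-mean KL(teacher || student) on temperature-softened class probabilities."""
    log_p_s = log_softmax(student_logits / temperature)
    log_p_t = log_softmax(teacher_logits / temperature)
    return float((np.exp(log_p_t) * (log_p_t - log_p_s)).sum(axis=-1).mean())


# Hypothetical toy stand-ins for the student and teacher ViTs.
rng = np.random.default_rng(0)
B, P, D, H, C = 2, 16, 8, 4, 10  # batch, patches, embed dim, heads, classes
W_student = rng.normal(size=(D, C))
W_teacher = rng.normal(size=(D, C))
W_attn = rng.normal(size=(H, D))


def toy_student(patches):
    """(B, P, D) patch tokens -> logits (B, C) and class-token attention (B, H, P)."""
    logits = patches.mean(axis=1) @ W_student
    cls_attn = np.exp(log_softmax(np.einsum("bpd,hd->bhp", patches, W_attn)))
    return logits, cls_attn


def toy_teacher(patches):
    """Accepts any number of tokens, so masked inputs need no architecture change."""
    return patches.mean(axis=1) @ W_teacher


def masked_kd_step(patches, keep_ratio=0.5):
    # (1) Student predicts on the full image.
    s_logits, s_cls_attn = toy_student(patches)
    # (2) Mask the image: keep the patches the student attends to most.
    num_keep = max(1, round(keep_ratio * patches.shape[1]))
    keep = student_guided_keep_indices(s_cls_attn, num_keep)
    visible = np.take_along_axis(patches, keep[:, :, None], axis=1)  # (B, K, D)
    # (3) Teacher predicts on the masked image; dropped tokens are never computed.
    t_logits = toy_teacher(visible)
    # (4) Match teacher and student logits.
    return kl_distillation_loss(s_logits, t_logits), keep


images = rng.normal(size=(B, P, D))  # already patchified and embedded
loss, keep = masked_kd_step(images, keep_ratio=0.5)
print(keep.shape)  # (2, 8)
print(f"distillation loss: {loss:.4f}")
```

In a real ViT, positional embeddings would typically be added before tokens are dropped, so the kept tokens retain their spatial positions. Gradients, the optimizer step, and any ground-truth label loss are omitted here; only the forward data flow is shown.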

Theorems & Definitions (2)

  • Proposition
  • Proof