Using Early Readouts to Mediate Featural Bias in Distillation

Rishabh Tiwari, Durga Sivasubramanian, Anmol Mekala, Ganesh Ramakrishnan, Pradeep Shenoy

TL;DR

This work proposes a novel early readout mechanism that attempts to predict the label from representations in earlier network layers, and shows that these early readouts automatically identify problem instances or groups in the form of confident, incorrect predictions.

Abstract

Deep networks tend to learn spurious feature-label correlations in real-world supervised learning tasks. This vulnerability is aggravated in distillation, where a student model may have less representational capacity than the corresponding teacher model. Often, knowledge of specific spurious correlations is used to reweight instances and rebalance the learning process. We propose a novel early readout mechanism whereby we attempt to predict the label using representations from earlier network layers. We show that these early readouts automatically identify problem instances or groups in the form of confident, incorrect predictions. Leveraging these signals to modulate the distillation loss on an instance level allows us to substantially improve not only group fairness measures across benchmark datasets, but also the overall accuracy of the student model. We also provide secondary analyses that bring insight into the role of feature learning in supervision and distillation.
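The abstract describes the mechanism only at a high level. The sketch below shows one way an early readout could modulate a per-instance distillation loss, under stated assumptions: the confidence margin is taken to be the gap between the top two readout probabilities, and the rule `1 + alpha * exp(beta * margin)` for misclassified instances is a hypothetical stand-in, not the paper's equation (Figure 3 only indicates that the weight depends on readout confidence through parameters $\alpha$ and $\beta$). All function names are illustrative.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax with an optional distillation temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def confidence_margin(probs: Sequence[float]) -> float:
    """Top-1 minus top-2 probability (an ASSUMED definition of 'confidence margin')."""
    top = sorted(probs, reverse=True)
    return top[0] - top[1]


def readout_weight(readout_logits: Sequence[float], label: int,
                   alpha: float = 0.1, beta: float = 4.0) -> float:
    """HYPOTHETICAL weighting rule, not the paper's formula.

    Weight is 1 when the early readout predicts the label correctly; otherwise
    it is 1 + alpha * exp(beta * margin), so confident readout errors are
    up-weighted more than hesitant ones.
    """
    probs = softmax(readout_logits)
    pred = max(range(len(probs)), key=probs.__getitem__)
    if pred == label:
        return 1.0
    return 1.0 + alpha * math.exp(beta * confidence_margin(probs))


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def weighted_distillation_loss(teacher_logits: Sequence[Sequence[float]],
                               student_logits: Sequence[Sequence[float]],
                               readout_logits: Sequence[Sequence[float]],
                               labels: Sequence[int],
                               temperature: float = 2.0,
                               alpha: float = 0.1,
                               beta: float = 4.0) -> float:
    """Weighted mean of per-instance KL(teacher || student) at temperature T.

    Each term is scaled by T^2 (a common knowledge-distillation convention)
    and by the hypothetical early-readout weight of that instance.
    """
    weighted_sum, weight_total = 0.0, 0.0
    for t, s, r, y in zip(teacher_logits, student_logits, readout_logits, labels):
        w = readout_weight(r, y, alpha, beta)
        kd = kl_divergence(softmax(t, temperature), softmax(s, temperature))
        weighted_sum += w * kd * temperature ** 2
        weight_total += w
    return weighted_sum / weight_total
```

For instance, `readout_weight([5.0, 0.0], 1)` (a confident readout error) yields a much larger weight than `readout_weight([0.1, 0.0], 1)` (a hesitant one), while a correctly read-out instance keeps weight 1. Dividing by the total weight, rather than the batch size, is a design choice that keeps the loss on a scale comparable to an unweighted mean.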

Paper Structure (28 sections, 4 equations, 7 figures, 7 tables, 1 algorithm)

Figures (7)

  • Figure 1: (a) We use predictions from an auxiliary layer applied on top of early features to determine the weights for the distillation loss. Errors from the readouts stem disproportionately from learned spurious features. (b) Comparison of Worst Group Accuracies (WGA) relative to the Teacher's WGA. DeDiER comes closest to matching the Teacher's WGA.
  • Figure 2: Early readout errors recall worst group instances (left), and worst group readouts are more confident across layers (right). We measure linear decoding error and confidence margins at each layer after 1 epoch. (a, c): Nearly all worst group instances are misclassified ($\sim 100\%$ recall), with higher recall in earlier layers. (b, d): Error instances from minority groups have a significantly higher confidence margin than those from other groups. See text for more details.
  • Figure 3: Weighting scheme as a function of the confidence of instances mispredicted by early readouts, for different values of $\alpha$ and $\beta$. (a) Weighting scheme for different values of $\alpha$ with $\beta = 4$. (b) Weighting scheme for different values of $\beta$ with $\alpha = 0.1$.
  • Figure 4: Groups in the four datasets. Groups which follow the correlation are in green and ones in conflict with the correlation are in red.
  • Figure 5: Evolution of reweighting during distillation (Waterbirds dataset). The top row shows the error rate, and the confidence of error instances, at the early readout, broken down by group. As expected, conflicting groups have high error rates due to spurious features; over the course of distillation, this overconfidence decreases. The bottom row shows the average weight for each group in the final distillation loss, and the error rate at the final layer. By the end of training, the groupwise accuracies are reconciled.
  • ...and 2 more figures
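Figures 2 and 5 report, per group, the early readout's error rate and the confidence of its errors. The following is a minimal sketch of that bookkeeping, assuming readout logits, labels, and group ids are available as plain lists, and assuming the confidence margin is the top-1 minus top-2 softmax probability (the captions do not define it here). The function name `group_readout_stats` is hypothetical. Per the Figure 2 caption, the worst group's error rate is the "recall" quantity: the fraction of that group's instances the readout misclassifies.

```python
import math
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def group_readout_stats(readout_logits: Sequence[Sequence[float]],
                        labels: Sequence[int],
                        groups: Sequence[Hashable]) -> Dict[Hashable, Tuple[float, float]]:
    """Return {group: (readout error rate, mean confidence margin of its errors)}.

    The mean margin is NaN for a group with no readout errors.
    """
    counts: Dict[Hashable, int] = defaultdict(int)
    errors: Dict[Hashable, int] = defaultdict(int)
    margin_sums: Dict[Hashable, float] = defaultdict(float)
    for logits, y, g in zip(readout_logits, labels, groups):
        probs = softmax(logits)
        pred = max(range(len(probs)), key=probs.__getitem__)
        counts[g] += 1
        if pred != y:
            top = sorted(probs, reverse=True)
            errors[g] += 1
            margin_sums[g] += top[0] - top[1]  # assumed margin definition
    stats = {}
    for g in counts:
        err_rate = errors[g] / counts[g]
        mean_margin = margin_sums[g] / errors[g] if errors[g] else float("nan")
        stats[g] = (err_rate, mean_margin)
    return stats
```

Tracking these two numbers per group across epochs would give curves of the kind the top row of Figure 5 describes.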