Channel-Selective Normalization for Label-Shift Robust Test-Time Adaptation

Pedro Vianna, Muawiz Chaudhary, Paria Mehrbod, An Tang, Guy Cloutier, Guy Wolf, Michael Eickenberg, Eugene Belilovsky

TL;DR

This work tackles the vulnerability of test-time adaptation (TTA) via BatchNorm statistics (TTN) to label distribution shifts by introducing Hybrid-TTN, a channel-wise, depth-aware adaptation strategy. It computes per-channel, class-aware sensitivity scores from Wasserstein distances between source and target BN statistics, weights them by a target class prior, and selects the top channels to adapt using a depth-decaying threshold, forming hybrid BN statistics. The method is validated on CIFAR-10-C, ImageNet-1K-C, and liver ultrasound data, where it remains robust under label shifts while retaining the benefits of adaptation under covariate shift, and it achieves favorable median rankings compared to existing TTA methods. This approach offers a practical, hyperparameter-light path to safer, more reliable deployment of TTA in real-world domains, including biomedical imaging.
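
The selection pipeline summarized above can be sketched in NumPy. This is a hypothetical illustration, not the authors' implementation, and it rests on several assumptions of ours: BN statistics are treated as per-channel 1-D Gaussians (so the closed-form 2-Wasserstein distance applies); class-aware sensitivity is measured against statistics computed on single-class batches (`cls_mu`, `cls_var`); a higher score means a more label-sensitive channel, so the least sensitive channels are ranked first for adaptation; and the depth-decaying budget follows the made-up schedule `base_frac * decay ** (depth / (num_layers - 1))`. All function and parameter names are invented for this sketch.

```python
import numpy as np


def gaussian_w2(mu_a, var_a, mu_b, var_b):
    """Closed-form 2-Wasserstein distance between 1-D Gaussians (elementwise)."""
    return np.sqrt((mu_a - mu_b) ** 2 + (np.sqrt(var_a) - np.sqrt(var_b)) ** 2)


def prior_weighted_scores(src_mu, src_var, cls_mu, cls_var, prior):
    """Per-channel label sensitivity of one BN layer, weighted by a class prior.

    src_mu, src_var: (C,) source BN statistics.
    cls_mu, cls_var: (K, C) statistics from single-class batches (assumption).
    prior: (K,) estimated target class prior, summing to 1.
    """
    w = gaussian_w2(src_mu[None, :], src_var[None, :], cls_mu, cls_var)  # (K, C)
    return prior @ w  # (C,) expected per-channel shift under the prior


def adapt_mask(scores, depth, num_layers, base_frac=0.9, decay=0.5):
    """Boolean mask of channels to adapt; deeper layers get a smaller budget.

    The schedule base_frac * decay ** (depth / (num_layers - 1)) is hypothetical.
    """
    frac = base_frac * decay ** (depth / max(num_layers - 1, 1))
    k = int(round(frac * scores.size))
    order = np.argsort(scores, kind="stable")  # least label-sensitive first
    mask = np.zeros(scores.size, dtype=bool)
    mask[order[:k]] = True
    return mask


def hybrid_stats(src_mu, src_var, tgt_mu, tgt_var, mask):
    """Adapted channels take test-batch statistics; the rest keep source ones."""
    return np.where(mask, tgt_mu, src_mu), np.where(mask, tgt_var, src_var)
```

For example, with channel scores `[0, 1, 1, 3]`, `adapt_mask(scores, depth=0, num_layers=3, base_frac=0.5)` adapts the two least sensitive channels, while the same call at `depth=2` adapts only one, reflecting the premise that later layers are more sensitive to label shift.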

Abstract

Deep neural networks are useful in many different tasks; however, their performance can be severely affected by changes in the data distribution. In the biomedical field, for example, performance can degrade when training and test data differ (e.g., different acquisition machines or patient populations). To ensure robustness and generalization to real-world scenarios, test-time adaptation (TTA) has recently been studied as a way to adjust models to a new data distribution during inference. Test-time batch normalization (TTN) is a simple and popular method that has achieved compelling performance on domain shift benchmarks. It is implemented by recalculating batch normalization statistics on test batches. Prior work has analyzed it with test data that has the same label distribution as the training data. However, in many practical applications this technique is vulnerable to label distribution shifts, sometimes failing catastrophically. This presents a risk when applying test-time adaptation methods in deployment. We propose to tackle this challenge by selectively adapting only some of the channels in a deep network, minimizing drastic adaptation that is sensitive to label shifts. Our selection scheme is based on two principles that we empirically motivate: (1) later layers of networks are more sensitive to label shift, and (2) individual features can be sensitive to specific classes. We apply the proposed technique to three classification tasks (CIFAR-10-C, ImageNet-C, and diagnosis of fatty liver), where we explore both covariate and label distribution shifts. We find that our method brings the benefits of TTA while significantly reducing the risk of failure common in other methods, and that it is robust to the choice of hyperparameters.
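
The TTN mechanism described in the abstract (recalculating BN statistics on each test batch) can be sketched for a single BN layer in NumPy. The function names are hypothetical and the toy input is ours: a batch drawn entirely from one "class" whose features sit around a mean of 3. Normalizing with the stored source statistics keeps that offset, while TTN re-centers it to `beta`, erasing the class-specific signal. This is a simplified illustration of why a one-class test batch can cause collapse, not the paper's experimental setup.

```python
import numpy as np


def batch_norm(x, mean, var, gamma, beta, eps=1e-5):
    """Normalize an (N, C, H, W) array with per-channel statistics of shape (C,)."""
    s = (1, -1, 1, 1)  # broadcast (C,) vectors over batch and spatial axes
    return gamma.reshape(s) * (x - mean.reshape(s)) / np.sqrt(var.reshape(s) + eps) + beta.reshape(s)


def ttn_forward(x, running_mean, running_var, gamma, beta, adapt=True, eps=1e-5):
    """Source BN (adapt=False) uses stored training statistics;
    TTN (adapt=True) recomputes them from the current test batch."""
    if adapt:
        mean = x.mean(axis=(0, 2, 3))
        var = x.var(axis=(0, 2, 3))  # biased variance, as BN uses when normalizing
        return batch_norm(x, mean, var, gamma, beta, eps)
    return batch_norm(x, running_mean, running_var, gamma, beta, eps)


# Toy one-class test batch: both channels carry a class-specific offset of 3.
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=1.0, size=(8, 2, 4, 4))
gamma, beta = np.ones(2), np.zeros(2)
src_out = ttn_forward(x, np.zeros(2), np.ones(2), gamma, beta, adapt=False)
ttn_out = ttn_forward(x, np.zeros(2), np.ones(2), gamma, beta, adapt=True)
print(src_out.mean(axis=(0, 2, 3)))  # near 3: class offset preserved
print(ttn_out.mean(axis=(0, 2, 3)))  # near 0: offset removed by batch statistics
```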

Paper Structure (18 sections, 2 equations, 7 figures, 10 tables, 2 algorithms)

Figures (7)

  • Figure 1: Effect of extreme class imbalance and covariate shift on the performance of test-time batch normalization (TTN): TTN mitigates distribution shift but suffers greatly from class imbalance. We show per-class accuracy plots for a Source model (blue), a class-balanced TTN model (orange), and a class-imbalanced (1-class) TTN model (green). Observe that the class-imbalanced TTN model performs very poorly on the most prevalent class label of the imbalanced adaptation set (label 7, red rectangle).
  • Figure 2: Illustration explaining the depth-wise behaviour under label distribution shift. We consider one class mean (green) that is shifted towards the data mean, as would happen in a highly imbalanced setting. In early layers, classes are not well separated, so shifts in any class mean are relatively small and non-intrusive. In later layers, classes are well separated, and a large shift of points from one class mean towards the data mean is likely to cross a decision boundary. Data points of other classes, which move away from the data mean, are less likely to cross a decision boundary.
  • Figure 3: On CIFAR-10, we adapt models only up to the layer shown on the x-axis; the y-axis shows the accuracy on the target data. We consider label distributions with all 10 classes, as well as with 5, 3, and 1 randomly selected classes (balanced among the selected classes). We observe that adapting later layers plays an outsized role in the catastrophic collapse caused by TTN.
  • Figure 4: We compute the BN statistics for the all-class case and for multiple 1-class cases and compare their agreement on the channels that are most adapted (measured as the top 10% by Wasserstein distance). Blue bars show the percentage of overlapping channels, i.e., channels that have the highest Wasserstein distance (between source and adapted models) in both cases; orange bars show the percentage of channels among those with the highest Wasserstein distance that differ between the two evaluated cases.
  • Figure 5: CIFAR-10 evaluations on multiple label-shifted distributions and covariate shifts (all corruptions) with different degrees of label imbalance. We show the source model accuracy and the improvement (or degradation) obtained with our proposed method and the baseline. We observe that the proposed method provides benefits when there is no covariate shift, while avoiding catastrophic failures and allowing benefits over the source model when there are label distribution shifts.
  • ...and 2 more figures
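
The channel-agreement measure described for Figure 4 can be sketched in NumPy. This is a hypothetical illustration under our own assumptions: each case is summarized by a per-channel vector of Wasserstein distances between source and adapted BN statistics, "top 10%" is rounded to the nearest whole channel (at least one), and the two bar heights are the shared and non-shared percentages of the top set. Function names are invented for this sketch.

```python
import numpy as np


def top_channels(w, frac=0.1):
    """Indices of the top `frac` fraction of channels by Wasserstein distance."""
    k = max(1, int(round(frac * w.size)))
    return set(np.argsort(w)[::-1][:k].tolist())


def channel_agreement(w_all, w_one, frac=0.1):
    """Return (overlapping %, non-overlapping %) of the top-`frac` channel sets.

    w_all, w_one: (C,) per-channel source-vs-adapted Wasserstein distances for
    the all-class case and one 1-class case (hypothetical inputs).
    """
    top_all, top_one = top_channels(w_all, frac), top_channels(w_one, frac)
    shared = 100.0 * len(top_all & top_one) / len(top_all)
    return shared, 100.0 - shared


w_all = np.arange(20.0)   # toy distances for 20 channels
w_one = w_all.copy()
w_one[18] = -1.0          # the 1-class case disagrees on one top channel
print(channel_agreement(w_all, w_one))
```

With 20 channels, the top 10% is two channels per case; here the cases share one of them, so the overlapping and non-overlapping percentages are both 50.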