Label Distribution Shift-Aware Prediction Refinement for Test-Time Adaptation

Minguk Jang, Hye Won Chung

TL;DR

DART is a novel TTA method that refines predictions by focusing on class-wise confusion patterns. The refined predictions improve the performance of existing TTA methods, making DART a valuable plug-in tool.

Abstract

Test-time adaptation (TTA) is an effective approach to mitigate performance degradation of trained models when encountering input distribution shifts at test time. However, existing TTA methods often suffer significant performance drops when facing additional class distribution shifts. We first analyze TTA methods under label distribution shifts and identify the presence of class-wise confusion patterns commonly observed across different covariate shifts. Based on this observation, we introduce label Distribution shift-Aware prediction Refinement for Test-time adaptation (DART), a novel TTA method that refines the predictions by focusing on class-wise confusion patterns. DART trains a prediction refinement module during an intermediate time by exposing it to several batches with diverse class distributions using the training dataset. This module is then used during test time to detect and correct class distribution shifts, significantly improving pseudo-label accuracy for test data. Our method exhibits 5-18% gains in accuracy under label distribution shifts on CIFAR-10C, without any performance degradation when there is no label distribution shift. Extensive experiments on CIFAR, PACS, OfficeHome, and ImageNet benchmarks demonstrate DART's ability to correct inaccurate predictions caused by test-time distribution shifts. This improvement leads to enhanced performance in existing TTA methods, making DART a valuable plug-in tool.
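
As a rough illustration of "batches with diverse class distributions", the sketch below draws class proportions from a Dirichlet distribution (the sampling family named in the Figure 2 caption) and then picks training indices to match them. The function names, the concentration parameter `alpha`, the batch size, and the toy labels are all assumptions made for illustration; they are not taken from the paper.

```python
import random
from collections import defaultdict


def sample_class_proportions(num_classes, alpha, rng):
    """Draw class proportions from a symmetric Dirichlet(alpha) distribution."""
    # A Dirichlet sample is a vector of independent Gamma(alpha, 1) draws,
    # normalized so that the entries sum to one.
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_classes)]
    total = sum(draws)
    return [d / total for d in draws]


def sample_batch_indices(labels, batch_size, alpha, rng):
    """Pick training indices whose class mix follows one Dirichlet draw."""
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    classes = sorted(by_class)
    proportions = sample_class_proportions(len(classes), alpha, rng)
    # Choose a class for every slot in the batch, then a random example of
    # that class (sampling with replacement, an assumption of this sketch).
    chosen = rng.choices(classes, weights=proportions, k=batch_size)
    return [rng.choice(by_class[c]) for c in chosen], proportions


rng = random.Random(0)
labels = [i % 10 for i in range(1000)]  # toy class-balanced labels, 10 classes
batch, proportions = sample_batch_indices(labels, batch_size=64, alpha=0.5, rng=rng)
```

Smaller values of `alpha` tend to concentrate a batch on a few classes, while larger values give batches closer to class-balanced, so repeated draws expose a module to a range of label distributions.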

Paper Structure

This paper contains 70 sections, 22 equations, 16 figures, 31 tables, 4 algorithms.

Figures (16)

  • Figure 1: Confusion patterns of BN-adapted classifier due to test-time label distribution shifts. We present class-wise confusion matrices of BN-adapted classifiers, initially trained on class-balanced CIFAR-10 and then tested on CIFAR-10C with two long-tailed distributions (first column). The second column shows confusion patterns on the CIFAR-10 test dataset with only label shifts, while the third to fifth columns display patterns on CIFAR-10C under three types of corruptions combined with label shifts. Class pairs where the confusion rate exceeds 11% are highlighted in red. There is significant accuracy degradation in head classes (e.g., class 0 in the 1st row and class 9 in the 2nd row). Additionally, similar class-wise confusion patterns are observed across different types of corruption for each label distribution shift (each row). Confusion matrices for other 15 types of corruption and class imbalance ratios of $\rho=1,10,100$ are also reported in Figures \ref{fig: conf_matrices_all_rho1}--\ref{fig: conf_matrices_all_rho100}.
  • Figure 2: Intermediate time training of DART. At the intermediate time, the period between the training and test times, DART trains a prediction refinement module $g_\phi$ to correct the inaccurate predictions caused by class distribution shifts. (left) By sampling the training data from Dirichlet distributions, we generate batches with diverse class distributions. (right) The prediction refinement module $g_\phi$ takes the averaged pseudo label distribution $\bar{p}_{\mathcal{B}}$ and the prediction deviation $d_{\mathcal{B}}$ for each batch ${\mathcal{B}}$, and outputs a $K \times K$ square matrix $W_{{\mathcal{B}}}$ and a vector $b_{\mathcal{B}}$ of size $K$, where $K$ is the number of classes. Using labels of the training data, we optimize $g_\phi$ to minimize the cross-entropy loss between the labels $y$ and the refined prediction $q=\text{softmax}(f_{\bar{\theta}}(x) W_{\mathcal{B}} + b_{\mathcal{B}})$ for samples $(x,y) \in {\mathcal{B}}$, where $f_{\bar{\theta}}$ is the BN-adapted classifier.
  • Figure 3: We observe performance degradation of BNAdapt (orange) as the class imbalance ratio $\rho$ increases on long-tailed CIFAR-10C. DART-applied BNAdapt (green) shows consistently improved performance regardless of class imbalance.
  • Figure 4: We show, for each batch ${\mathcal{B}}$, the relationship between (left) the test accuracy and the difference between the averaged pseudo label distribution, $\frac{1}{|{\mathcal{B}}|}\sum_{x_i \in {\mathcal{B}}} \text{softmax}(f_{\bar{\theta}}(x_i))$, and the uniform distribution $u$, and (right) the test accuracy and the prediction deviation $d_{\mathcal{B}}$. Results use the BN-adapted classifier on CIFAR-10C-imb under Gaussian noise with 5 different imbalance ratios. Each point represents a single batch within the test dataset.
  • Figure 5: Class distribution of PACS and OfficeHome
  • ...and 11 more figures
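
The refinement step described in the Figure 2 caption, $q=\text{softmax}(f_{\bar{\theta}}(x) W_{\mathcal{B}} + b_{\mathcal{B}})$, and the averaged pseudo label distribution from the Figure 4 caption can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: the module $g_\phi$ itself is not modeled, the matrix `W` and vector `b` below are hand-picked stand-ins for its outputs, and the toy logits and function names are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def averaged_pseudo_label_distribution(batch_logits):
    """Mean of the per-sample softmax outputs over a batch."""
    probs = [softmax(z) for z in batch_logits]
    k = len(batch_logits[0])
    return [sum(p[c] for p in probs) / len(probs) for c in range(k)]


def refine(logits, W, b):
    """q = softmax(logits @ W + b) for one sample (logits as a row vector)."""
    k = len(b)
    shifted = [
        sum(logits[i] * W[i][j] for i in range(len(logits))) + b[j]
        for j in range(k)
    ]
    return softmax(shifted)


# Toy logits from a hypothetical 3-class classifier.
batch_logits = [[2.0, 0.5, -1.0], [1.5, 1.0, 0.0], [0.2, 0.1, 3.0]]
p_bar = averaged_pseudo_label_distribution(batch_logits)

# Stand-ins for the outputs of g_phi (assumed values, not learned):
# an identity W with a zero b leaves the prediction unchanged.
identity = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
unchanged = refine(batch_logits[0], identity, [0.0, 0.0, 0.0])

# A bias that favors class 2 moves probability mass toward that class.
boosted = refine(batch_logits[0], identity, [0.0, 0.0, 2.0])
```

With an identity matrix and zero bias the refinement is a no-op, which is one natural reading of how an affine correction of this form could leave predictions untouched when no label distribution shift is detected.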