
Test-Time Training with Self-Supervision for Generalization under Distribution Shifts

Yu Sun, Xiaolong Wang, Zhuang Liu, John Miller, Alexei A. Efros, Moritz Hardt

TL;DR

This work introduces Test-Time Training (TTT), a principled approach to improving model generalization under distribution shifts by performing self-supervised updates on each unlabeled test sample before predicting. TTT uses a shared feature extractor with two task branches (main and self-supervised) and optionally supports an online variant that updates sequentially over a data stream. Empirically, TTT yields substantial robustness gains on CIFAR-10-C, ImageNet-C, VID-Robust, and CIFAR-10.1, while preserving or marginally improving performance on the original distribution. Theory shows that positive correlation between the gradients of the two losses guarantees improvement in convex settings, and empirical evidence suggests the relationship extends to deep networks. The results point to a practical, deployment-time adaptation paradigm and motivate further work on task design and efficiency for broader robustness applications.
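Why gradient correlation matters can be seen from a first-order expansion. This is an intuition sketch under our own notation, not the paper's theorem (which additionally relies on convexity and smoothness assumptions): $\theta_e$ denotes the shared feature-extractor parameters, $l_m$ the main-task loss, $l_s$ the self-supervised loss (labeled $l_e$ in the Figure 4 caption), and $\eta$ a small step size.

```latex
\begin{aligned}
\theta_e' &= \theta_e - \eta\, \nabla_{\theta_e} l_s(x; \theta_e), \\
l_m(x, y; \theta_e') &= l_m(x, y; \theta_e)
  - \eta\, \big\langle \nabla_{\theta_e} l_m(x, y; \theta_e),\, \nabla_{\theta_e} l_s(x; \theta_e) \big\rangle
  + O(\eta^2).
\end{aligned}
```

For a small enough step, a positive inner product means the label-free self-supervised update also lowers the main-task loss, even though $y$ is never observed at test time.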

Abstract

In this paper, we propose Test-Time Training, a general approach for improving the performance of predictive models when training and test data come from different distributions. We turn a single unlabeled test sample into a self-supervised learning problem, on which we update the model parameters before making a prediction. This also extends naturally to data in an online stream. Our simple approach leads to improvements on diverse image classification benchmarks aimed at evaluating robustness to distribution shifts.
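The per-sample procedure in the abstract (adapt on a self-supervised loss for one unlabeled input, then predict) can be illustrated with a deliberately tiny toy. Everything below is a hypothetical stand-in, not the paper's architecture or its rotation-prediction task: `extract`, `ss_loss_grad`, `main_head`, `ttt_predict`, and `run_ttt` are illustrative names, the "shared extractor" only subtracts a learnable offset `theta_e`, and the self-supervised task simply asks extracted features to be zero-mean.

```python
# Toy sketch of Test-Time Training (TTT) and TTT-Online on a hypothetical
# 2-feature problem. All modeling choices here are illustrative assumptions:
# the shared extractor subtracts a learnable offset theta_e, the
# self-supervised task asks extracted features to be zero-mean (true on the
# training distribution), and the main head reads feature 0.
from typing import List, Sequence, Tuple


def extract(x: Sequence[float], theta_e: float) -> List[float]:
    """Shared feature extractor: subtract the learned offset theta_e."""
    return [xi - theta_e for xi in x]


def ss_loss_grad(x: Sequence[float], theta_e: float) -> float:
    """d/d(theta_e) of l_s = mean(extract(x, theta_e)) ** 2 (no label needed)."""
    z = extract(x, theta_e)
    m = sum(z) / len(z)
    return -2.0 * m


def main_head(z: Sequence[float]) -> int:
    """Main-task head: predict class 1 if the first feature is positive."""
    return 1 if z[0] > 0 else 0


def ttt_predict(x: Sequence[float], theta_e: float,
                lr: float = 0.25, steps: int = 10) -> Tuple[int, float]:
    """Take `steps` gradient steps on the self-supervised loss for this one
    test input, then predict with the main head on the adapted features.
    Returns (prediction, adapted theta_e)."""
    for _ in range(steps):
        theta_e -= lr * ss_loss_grad(x, theta_e)
    return main_head(extract(x, theta_e)), theta_e


def run_ttt(stream: Sequence[Sequence[float]], theta_e: float = 0.0,
            online: bool = False) -> List[int]:
    """Standard TTT restarts from the trained theta_e for every sample;
    TTT-Online carries the adapted theta_e forward along the stream."""
    preds: List[int] = []
    current = theta_e
    for x in stream:
        pred, adapted = ttt_predict(x, current)
        preds.append(pred)
        if online:
            current = adapted  # keep the update for the next sample
    return preds


# Training inputs look like [s, -s] with s = -1 (class 0) or +1 (class 1),
# so theta_e = 0.0 already minimizes l_s. The test stream adds +2 to every
# feature; the true labels are [0, 1].
shifted = [[-1.0 + 2.0, 1.0 + 2.0], [1.0 + 2.0, -1.0 + 2.0]]
print([main_head(extract(x, 0.0)) for x in shifted])  # fixed model: [1, 1]
print(run_ttt(shifted))                               # TTT: [0, 1]
print(run_ttt(shifted, online=True))                  # TTT-Online: [0, 1]
```

The fixed model misclassifies the first shifted input, while both TTT variants recover it by pulling `theta_e` toward the shift. Only the online variant keeps that offset, so its second sample starts already adapted, which is the mechanism TTT-Online relies on when a shift persists along the stream.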

Paper Structure

This paper contains 37 sections, 25 equations, 11 figures, 11 tables.

Figures (11)

  • Figure 1: Test error (%) on CIFAR-10-C with level 5 corruptions. We compare our approaches, Test-Time Training (TTT) and its online version (TTT-Online), with two baselines: object recognition without self-supervision, and joint training with self-supervision but keeping the model fixed at test time. TTT improves over the baselines and TTT-Online improves even further.
  • Figure 2: Test accuracy (%) on ImageNet-C with level 5 corruptions. Upper panel: our approaches, TTT and TTT-Online, show significant improvements on all corruption types over the two baselines. Lower panel: accuracy of TTT-Online, averaged over a sliding window of 100 samples; TTT-Online generalizes better as more samples are evaluated (x-axis), without hurting accuracy on the original distribution. We report accuracy instead of error here because baseline performance is very low for most corruptions.
  • Figure 3: Test error (%) on CIFAR-10-C for the three noise types, with a gradually changing distribution. The shifts are created by increasing the standard deviation of each noise type from small to large along the x-axis. As samples get noisier further into the test set, all methods incur higher error, but online Test-Time Training (TTT-Online) degrades with gentler slopes than joint training. For the first two noise types, TTT-Online also outperforms unsupervised domain adaptation by self-supervision (UDA-SS) sun2019uda.
  • Figure 4: Scatter plot of the inner product between the gradients (on the shared feature extractor $\bm{\theta}_e$) of the main task $l_m$ and the self-supervised task $l_e$, against the improvement in test error (%) from Test-Time Training, for the standard (left) and online (right) versions. Each point is the average over a test set, and each scatter plot has 75 test sets, from all 15 types of corruptions over five levels as described in Section \ref{results_cc}. The blue lines and bands are the best linear fits and the 99% confidence intervals. The linear correlation coefficients are $0.93$ and $0.89$ respectively, indicating strong positive correlation between the two quantities, as suggested by Theorem \ref{main_theorem}.
  • Figure A1: Sample images from the Common Corruptions Benchmark, taken from the original paper by hendrycks2019benchmarking.
  • ...and 6 more figures