
Connect Later: Improving Fine-tuning for Robustness with Targeted Augmentations

Helen Qu, Sang Michael Xie

TL;DR

On real-world tasks, standard fine-tuning after pretraining does not consistently improve OOD error over simply training from scratch on labeled source data; Connect Later improves average OOD error over both standard fine-tuning and supervised learning with targeted augmentations on 4 real-world datasets.

Abstract

Models trained on a labeled source domain (e.g., labeled images from wildlife camera traps) often generalize poorly when deployed on an out-of-distribution (OOD) target domain (e.g., images from new camera trap locations). In the domain adaptation setting where unlabeled target data is available, self-supervised pretraining (e.g., masked autoencoding or contrastive learning) is a promising method to mitigate this performance drop. Pretraining improves OOD error when the generic data augmentations used (e.g., masking or cropping) connect the source and target domains, which may be far apart in the input space. In this paper, we show on real-world tasks that standard fine-tuning after pretraining does not consistently improve OOD error over simply training from scratch on labeled source data. To better leverage pretraining for distribution shifts, we propose Connect Later: after pretraining with generic augmentations, fine-tune with targeted augmentations designed with knowledge of the distribution shift. Pretraining learns good representations within the source and target domains, while targeted augmentations connect the domains better during fine-tuning. Connect Later improves average OOD error over standard fine-tuning and supervised learning with targeted augmentations on 4 real-world datasets: Connect Later achieves the state-of-the-art on astronomical time-series classification (AstroClassification) by 2.5%, wildlife species identification (iWildCam-WILDS) with ResNet-50 by 0.9%, and tumor identification (Camelyon17-WILDS) with DenseNet121 by 1.1%; as well as best performance on a new dataset for astronomical time-series redshift prediction (Redshifts) by 0.03 RMSE (11% relative). Code and datasets are available at https://github.com/helenqu/connect-later.
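The two-stage recipe described in the abstract can be sketched as follows. This is a minimal sketch under loud assumptions, not the authors' code (their implementation is in the linked repository). PCA on generically augmented unlabeled source and target data stands in for self-supervised pretraining. A ridge-regression head fit on frozen features of targeted-augmented source data stands in for fine-tuning, so the second stage here is closer to linear probing than to full fine-tuning. The augmentation functions are hypothetical, caller-supplied placeholders.

```python
import numpy as np


def pretrain_pca(x_unlabeled, generic_aug, k, rng):
    """Stage 1 stand-in: fit a linear encoder (top-k PCA directions) on
    generically augmented unlabeled data pooled from source and target."""
    x_aug = generic_aug(x_unlabeled, rng)
    mean = x_aug.mean(axis=0)
    # Rows of vt are principal directions, sorted by singular value.
    _, _, vt = np.linalg.svd(x_aug - mean, full_matrices=False)
    components = vt[:k]
    return lambda x: (x - mean) @ components.T


def finetune_ridge(encoder, x_src, y_src, targeted_aug, rng,
                   lam=1e-3, n_copies=4):
    """Stage 2 stand-in: fit a ridge-regression head on frozen features of
    labeled source data plus `n_copies` targeted-augmented copies of it."""
    xs = [x_src] + [targeted_aug(x_src, rng) for _ in range(n_copies)]
    y = np.concatenate([y_src] * (n_copies + 1))
    feats = encoder(np.concatenate(xs))
    feats = np.hstack([feats, np.ones((len(feats), 1))])  # bias column
    w = np.linalg.solve(feats.T @ feats + lam * np.eye(feats.shape[1]),
                        feats.T @ y)
    return lambda x: np.hstack([encoder(x), np.ones((len(x), 1))]) @ w


# Toy usage with random data and hypothetical augmentations.
rng = np.random.default_rng(0)
mask_60 = lambda x, r: x * (r.random(x.shape) >= 0.6)        # generic: zero ~60% of entries
jitter = lambda x, r: x + r.normal(scale=0.1, size=x.shape)  # placeholder for a targeted aug
x_unl = rng.normal(size=(200, 10))
x_src = rng.normal(size=(50, 10))
y_src = x_src[:, 0]
encoder = pretrain_pca(x_unl, mask_60, k=5, rng=rng)
model = finetune_ridge(encoder, x_src, y_src, jitter, rng)
preds = model(x_src)
```

The masking rate in `mask_60` loosely mirrors the best-performing masking percentage reported in Figure 3. The Gaussian `jitter` is only a placeholder: a real targeted augmentation would encode knowledge of the specific shift, as the redshifting augmentation in Figure 2 does.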

Paper Structure (63 sections, 1 theorem, 10 equations, 5 figures, 8 tables)

Key Result

Proposition 1

With the above construction for the input space $\mathcal{X}$, unlabeled distribution $P_U$, and data augmentation $\mathcal{A}_{\text{prop}}$, for some feature dimension $k \in \mathbb{Z}^+$, a linear probe trained on contrastive pretrained features achieves 0 target error, i.e., its 0-1 loss on the target domain satisfies $\mathcal{L}_{0-1} = 0$.
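To make this kind of statement concrete on a finite augmentation graph, the sketch below assumes the spectral contrastive loss of HaoChen et al. (2021). For that loss, optimal $k$-dimensional features are (up to rotation) the top-$k$ eigenvectors of the degree-normalized adjacency matrix, scaled by the square roots of their eigenvalues and rescaled per node by $1/\sqrt{\deg(x)}$. The four-node graph is a hypothetical example, not the construction used in the proposition. Each class's source node is connected by augmentation to its target node, so a linear probe fit only on source nodes labels both target nodes correctly, giving 0 target error in this toy case.

```python
import numpy as np


def spectral_features(adjacency, k):
    """Minimizer (up to rotation) of the spectral contrastive loss on a finite
    augmentation graph: f(x) = sqrt(lambda) * v(x) / sqrt(deg(x)), using the
    top-k eigenpairs of D^{-1/2} A D^{-1/2}."""
    degree = adjacency.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(degree)
    normalized = d_inv_sqrt[:, None] * adjacency * d_inv_sqrt[None, :]
    eigvals, eigvecs = np.linalg.eigh(normalized)  # eigenvalues in ascending order
    top_vals = np.clip(eigvals[-k:], 0.0, None)    # guard against tiny negative round-off
    top_vecs = eigvecs[:, -k:]
    return top_vecs * np.sqrt(top_vals)[None, :] * d_inv_sqrt[:, None]


def fit_linear_probe(features, labels):
    """Least-squares linear probe (no bias) on +/-1 labels; predicts by sign."""
    w, *_ = np.linalg.lstsq(features, labels, rcond=None)
    return lambda f: np.sign(f @ w)


# Hypothetical graph. Nodes: 0 = source/class +1, 1 = target/class +1,
# 2 = source/class -1, 3 = target/class -1. Weights include self-loops.
A = np.array([[1.0, 1.0, 0.0, 0.0],
              [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [0.0, 0.0, 1.0, 1.0]])
F = spectral_features(A, k=2)
probe = fit_linear_probe(F[[0, 2]], np.array([1.0, -1.0]))
target_preds = probe(F[[1, 3]])  # [1., -1.]: 0 target error on this toy graph
```

If the source-target edges (0-1 and 2-3) are removed, the domains become disconnected and this toy guarantee no longer holds. Connecting the domains is precisely the role the augmentations play in the argument.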

Figures (5)

  • Figure 1: Overview of the Connect Later framework applied to a toy binary classification problem with two domains (filled and unfilled points), showing the representations from contrastive pretraining in $\mathbb{R}^2$. (Left) After contrastive pretraining with generic augmentations, the classes within each domain are linearly separable in representation space. Since the domains are far apart in input space, generic augmentations may misalign the pretrained representations. In this case, a classifier (with a linearly extrapolating decision boundary, dashed and solid line) learned on labeled source data will misclassify the target data. (Right) Connect Later employs targeted augmentations (filled points with black border) designed with knowledge of the distribution shift to connect the domains better, resulting in a classifier that generalizes well to the target domain.
  • Figure 2: Examples from the source dataset (left), an augmented version of the source example (middle), and the target dataset (right) for each of our tasks. (Top row) The AstroClassification and Redshifts tasks focus on time-varying astronomical objects observed in multiple wavelength ranges, plotted here as multicolored time series in which each color corresponds to the wavelength range of the measurement. The redshifting augmentation simulates placing source objects at higher redshift to better match the redshift distribution of the target dataset; the flux values and flux errors of the augmented example (middle) resemble the target example much more closely. (Middle row) We randomize the habitat background by applying the Copy-Paste Same Y augmentation for iWildCam-WILDS (iWildCam-WILDS image examples shown here are from gao2023targeted). (Bottom row) Stain Color Jitter alters the overall color of source images in Camelyon17-WILDS to improve performance on images from unseen hospitals.
  • Figure 3: On the AstroClassification task, Connect Later is relatively robust to pretraining masking percentage both ID and OOD, but 60% masking performs best out of the percentages we tested.
  • Figure 4: Redshift distributions of source, augmented, and target datasets for the AstroClassification and Redshifts tasks.
  • Figure 5: Example distribution of data and augmentations for contrastive learning in which Connect Later improves OOD performance over both contrastive pretraining + standard fine-tuning and ERM + targeted augmentations. The augmentation graph is similar to that of shen2022connect, except that the edge weights connecting inputs 1 and 2 and inputs 3 and 4 are swapped. The shapes represent classes, and the labeled data is shaded in green. The generic augmentation probabilities are marked as edge weights, where we assume that $\alpha > \gamma + \beta$. Here, targeted augmentations that first swap inputs 1 and 2 before applying a generic augmentation help align the source and target. However, some target inputs are not reachable via augmentations from source inputs. Standard fine-tuning can generalize throughout the target domain, but only in conjunction with targeted augmentations that align the source and target. The orange dotted lines on the far ends connect to each other (the graph wraps around).
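Two of the targeted augmentations named in Figure 2 can be sketched on NumPy arrays as below. Both are illustrative assumptions, not the exact implementations from gao2023targeted or the authors' repository. `copy_paste_same_y` assumes a binary foreground mask for the animal and assumes each candidate background is tagged with a single species label, standing in for "a background from a camera that has observed the same species". `stain_color_jitter` assumes the common HED-space formulation: convert RGB to optical density, unmix with the Ruifrok and Johnston hematoxylin/eosin/DAB stain vectors, then randomly scale and shift each stain channel by an amount controlled by the hypothetical parameter `sigma`.

```python
import numpy as np

# Ruifrok-Johnston stain vectors: each row is the RGB optical density of one
# stain (hematoxylin, eosin, DAB). Values as commonly quoted; an assumption here.
RGB_FROM_HED = np.array([[0.65, 0.70, 0.29],
                         [0.07, 0.99, 0.11],
                         [0.27, 0.57, 0.78]])
HED_FROM_RGB = np.linalg.inv(RGB_FROM_HED)


def copy_paste_same_y(image, fg_mask, label, backgrounds, bg_labels, rng):
    """Paste the masked foreground of `image` (H, W, C) onto a randomly chosen
    background (H, W, C) whose tag matches `label`."""
    candidates = np.flatnonzero(bg_labels == label)
    if candidates.size == 0:
        return image.copy()  # no same-label background: leave the image unchanged
    bg = backgrounds[rng.choice(candidates)]
    mask = fg_mask[..., None].astype(float)  # (H, W, 1), broadcasts over channels
    return mask * image + (1.0 - mask) * bg


def stain_color_jitter(image, sigma, rng, eps=1e-6):
    """Randomly perturb per-stain concentrations of an RGB image in [0, 1]."""
    od = -np.log(np.clip(image, eps, 1.0))                  # Beer-Lambert optical density
    hed = od @ HED_FROM_RGB                                  # unmix into stain concentrations
    alpha = rng.uniform(1.0 - sigma, 1.0 + sigma, size=3)    # per-stain scale
    beta = rng.uniform(-sigma, sigma, size=3)                # per-stain shift
    hed = hed * alpha + beta
    return np.clip(np.exp(-(hed @ RGB_FROM_HED)), 0.0, 1.0)  # remix back to RGB
```

Setting `sigma=0` returns the input unchanged (up to clipping), which is a quick sanity check of the unmix/remix round trip.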

Theorems & Definitions (1)

  • Proposition 1: shen2022connect