Views Can Be Deceiving: Improved SSL Through Feature Space Augmentation

Kimia Hamidieh, Haoran Zhang, Swami Sankaranarayanan, Marzyeh Ghassemi

TL;DR

This work empirically shows that commonly used SSL augmentations can cause undesired invariances in the image space, and proposes LateTVG, which removes spurious information from the learned representations during pre-training by regularizing the later layers of the encoder via pruning.

Abstract

Supervised learning methods have been found to exhibit inductive biases favoring simpler features. When such features are spuriously correlated with the label, this can result in suboptimal performance on minority subgroups. Despite the growing popularity of methods that learn from unlabeled data, the extent to which their representations rely on spurious features for prediction is unclear. In this work, we explore the impact of spurious features on Self-Supervised Learning (SSL) for visual representation learning. We first empirically show that commonly used augmentations in SSL can cause undesired invariances in the image space, and illustrate this with a simple example. We further show that classical approaches to combating spurious correlations, such as dataset re-sampling during SSL, do not consistently lead to invariant representations. Motivated by these findings, we propose LateTVG to remove spurious information from these representations during pre-training, by regularizing the later layers of the encoder via pruning. We find that our method produces representations that outperform the baselines on several benchmarks, without the need for group or label information during SSL.
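
The pruning-based view generation described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation. This section does not give LateTVG's pruning criterion, how many layers count as "later", or the exact SSL objective. The sketch therefore assumes three things. First, per-layer magnitude pruning of the last `k` layers of a toy ReLU MLP at a hypothetical `ratio`. Second, that the pruned encoder's output is treated as an extra view of the same input. Third, that this view is contrasted with the full encoder's output via an InfoNCE loss. The helpers `prune_late_layers`, `encode`, and `info_nce` are hypothetical names. Only the forward computation is shown, with no training loop.

```python
import math
import random


def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def encode(x, layers):
    """Toy ReLU MLP encoder (hypothetical): ReLU after every layer but the last."""
    h = x
    for i, W in enumerate(layers):
        h = matvec(W, h)
        if i < len(layers) - 1:
            h = [max(0.0, v) for v in h]
    return h


def prune_late_layers(layers, k, ratio):
    """Copy `layers` and zero the `ratio` fraction of smallest-magnitude
    weights in each of the last `k` layers (ASSUMED pruning criterion)."""
    pruned = [[row[:] for row in W] for W in layers]  # originals stay intact
    for W in pruned[len(pruned) - k:]:
        entries = sorted((abs(W[r][c]), r, c)
                         for r in range(len(W)) for c in range(len(W[r])))
        for _, r, c in entries[:int(ratio * len(entries))]:
            W[r][c] = 0.0
    return pruned


def normalize(v):
    n = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / n for x in v]


def info_nce(z1, z2, tau=0.5):
    """One-directional InfoNCE: (z1[i], z2[i]) are positives, the other z2[j]
    in the batch are negatives; logits are cosine similarity / tau."""
    z1 = [normalize(z) for z in z1]
    z2 = [normalize(z) for z in z2]
    total = 0.0
    for i, a in enumerate(z1):
        logits = [sum(p * q for p, q in zip(a, b)) / tau for b in z2]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_denom = m + math.log(sum(math.exp(s - m) for s in logits))
        total += log_denom - logits[i]
    return total / len(z1)


random.seed(0)
dims = [8, 16, 16, 4]  # input -> hidden -> hidden -> representation
layers = [[[random.gauss(0.0, 0.5) for _ in range(d_in)] for _ in range(d_out)]
          for d_in, d_out in zip(dims[:-1], dims[1:])]
batch = [[random.gauss(0.0, 1.0) for _ in range(dims[0])] for _ in range(6)]

# Hypothetical hyperparameters: prune the last 2 layers at 50%.
pruned = prune_late_layers(layers, k=2, ratio=0.5)
z_orig = [encode(x, layers) for x in batch]  # view from the full encoder
z_late = [encode(x, pruned) for x in batch]  # view from the late-pruned encoder
loss = info_nce(z_orig, z_late)
print(f"InfoNCE between full and late-pruned views: {loss:.4f}")
```

Zeroing weights only in the final layers leaves early, low-level features untouched while perturbing whatever the late layers encode. Asking the representation to agree across the two encoders is one way to encourage invariance to those late-layer features. Whether the paper uses InfoNCE or another SSL objective for this pairing is not stated in this section.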

Paper Structure

This paper contains 54 sections, 2 theorems, 7 equations, 7 figures, 11 tables, 1 algorithm.

Key Result

Lemma 3.3

Consider the set of (unlabeled) population data $\mathcal{X}$ in a binary-class setting with a binary spurious attribute, consisting of $|\mathcal{G}| = 4$ groups with the same number of examples per group. Consider a simplified augmentation graph with parameters $\alpha$, $\beta$, …
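
The statement above is cut off, so the sketch below does not reproduce the lemma's $\alpha$ and $\beta$. It is instead a toy illustration of the intuition behind Figure 1(b). It assumes, as in spectral analyses of contrastive learning, that learned representations align with the top eigenvectors of the augmentation graph. The graph is collapsed to the four groups $(y, a)$, where $y$ is the invariant label and $a$ is the spurious attribute. The edge weights are hypothetical: `w_spurious` between groups sharing only $a$, `w_invariant` between groups sharing only $y$, and `w_neither` otherwise. Every row has the same sum, so degree normalization only rescales the matrix. The four $\pm 1$ vectors over groups are exact eigenvectors, so eigenvalues can be read off directly.

```python
# Toy 4-group augmentation graph (hypothetical weights, NOT Lemma 3.3's alpha/beta).
# Groups are indexed as (y, a): y = invariant label, a = spurious attribute.
GROUPS = [(0, 0), (0, 1), (1, 0), (1, 1)]


def group_graph(w_self, w_spurious, w_invariant, w_neither):
    """Symmetric 4x4 adjacency between groups."""
    A = []
    for yi, ai in GROUPS:
        row = []
        for yj, aj in GROUPS:
            if (yi, ai) == (yj, aj):
                row.append(w_self)
            elif ai == aj:  # same spurious attribute, different label
                row.append(w_spurious)
            elif yi == yj:  # same label, different spurious attribute
                row.append(w_invariant)
            else:  # differ in both
                row.append(w_neither)
        A.append(row)
    return A


# The four +/-1 vectors over groups; for this symmetric structure they are
# exact eigenvectors of A.
DIRECTIONS = {
    "constant": [1, 1, 1, 1],
    "invariant (y)": [1 - 2 * y for y, _ in GROUPS],
    "spurious (a)": [1 - 2 * a for _, a in GROUPS],
    "interaction": [(1 - 2 * y) * (1 - 2 * a) for y, a in GROUPS],
}


def eigenvalue(A, v, tol=1e-9):
    """Return lambda with A v = lambda v, checking that v is an eigenvector."""
    Av = [sum(a_ij * v_j for a_ij, v_j in zip(row, v)) for row in A]
    lam = Av[0] / v[0]
    if any(abs(Av[i] - lam * v[i]) > tol for i in range(len(v))):
        raise ValueError("v is not an eigenvector of A")
    return lam


def top_nontrivial_direction(A):
    """Direction with the largest eigenvalue after dropping the constant one."""
    eigs = {name: eigenvalue(A, v) for name, v in DIRECTIONS.items()
            if name != "constant"}
    return max(eigs, key=eigs.get), eigs


# Augmentations connect groups sharing the spurious attribute more strongly.
print(top_nontrivial_direction(group_graph(1.0, 0.6, 0.2, 0.05))[0])  # spurious (a)
# Swapping the two weights makes the invariant direction dominate.
print(top_nontrivial_direction(group_graph(1.0, 0.2, 0.6, 0.05))[0])  # invariant (y)
```

In this toy graph the spurious and invariant eigenvalues differ by exactly $2(w_{\text{spurious}} - w_{\text{invariant}})$. Whenever augmentations connect same-attribute groups more strongly than same-label groups, the leading non-trivial direction splits the data by $a$. A downstream linear model on such a representation would then follow the red dashed line of Figure 1(b).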

Figures (7)

  • Figure 1: Analysing SSL augmentations. (a) Images generated from a latent space with correlated features. (b) If the connectivity induced by SSL augmentations between subgroups sharing the same spurious feature is higher than that between subgroups sharing the same invariant feature, the learned representations lead a downstream linear model to separate the data by the spurious feature (red dashed line) instead of the invariant feature (green dashed line). Our empirical evaluation in Table \ref{tab:connectivity} shows that this is indeed the case across the datasets considered in this work.
  • Figure 2: We use model transformation modules to create new views of training examples in the representation space. The introduced set of transformations removes the features learned in the final few layers, and provides final representations invariant to such transformations.
  • Figure 3: Downstream worst-group accuracy of SSL-LateTVG on the MetaShift dataset as we vary the percentage of the minority group in the downstream training set. In all cases except the most extreme reduction of the minority group, SSL-LateTVG outperforms the baseline.
  • Figure 4: Downstream worst-group accuracy of SSL-LateTVG on the MetaShift (left) and CelebA (right) datasets as we vary the model pruning hyperparameters.
  • Figure 5: We use Grad-CAM to explain the ResNet-18 SSL-base (top) and SSL-LateTVG (bottom) models for majority examples.
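
Figures 3 and 4 report downstream worst-group accuracy. The sketch below uses the standard definition: the minimum, over groups, of per-group accuracy. Groups are taken to be (label, spurious attribute) pairs, matching the four-group setting of Lemma 3.3. The helper names and the toy predictions are illustrative only and are not results from the paper.

```python
from collections import defaultdict


def group_accuracies(y_true, y_pred, groups):
    """Accuracy within each group; groups[i] is any hashable group id."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for yt, yp, g in zip(y_true, y_pred, groups):
        total[g] += 1
        correct[g] += int(yt == yp)
    return {g: correct[g] / total[g] for g in total}


def worst_group_accuracy(y_true, y_pred, groups):
    """Minimum per-group accuracy (groups with no examples are ignored)."""
    return min(group_accuracies(y_true, y_pred, groups).values())


# Toy data: groups are (label y, spurious attribute a); (0, 1) and (1, 0)
# are the minority groups.
y_true = [0, 0, 0, 0, 1, 1, 1, 1, 0, 1]
groups = [(0, 0)] * 4 + [(1, 1)] * 4 + [(0, 1), (1, 0)]
# A classifier that predicts from the spurious attribute alone (y_pred = a).
y_pred = [a for _, a in groups]

avg_acc = sum(int(t == p) for t, p in zip(y_true, y_pred)) / len(y_true)
print(avg_acc)                                       # 0.8
print(worst_group_accuracy(y_true, y_pred, groups))  # 0.0
```

The toy classifier looks reasonable on average but fails every minority example. Worst-group accuracy exposes exactly this kind of reliance on a spurious feature.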

Theorems & Definitions (4)

  • Definition 3.1
  • Lemma 3.3
  • Corollary 3.4
  • Proof