Investigating the Benefits of Projection Head for Representation Learning

Yihao Xue, Eric Gan, Jiayi Ni, Siddharth Joshi, Baharan Mirzasoleiman

TL;DR

This paper addresses why adding a projection head during training improves representation quality across self-supervised, supervised contrastive, and standard supervised learning. By analyzing a two-layer linear model with a spectral contrastive loss, it reveals layer-wise progressive feature weighting where deeper layers emphasize a subset of features. It shows that non-linearities enable lower layers to acquire features absent from higher layers, and that the projection head can boost robustness under misalignment between pretraining and downstream objectives. The authors validate the theory with controlled experiments and real-data tasks, and propose a fixed reweighting head as an interpretable alternative that achieves comparable gains, with practical implications for robust transfer.
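To make the setup concrete, here is a minimal NumPy sketch, not the authors' code, of a two-layer linear encoder with a linear projection head scored by a spectral contrastive loss. It assumes the loss takes the standard form $-2\,\mathbb{E}[f(x)^\top f(x^+)] + \mathbb{E}[(f(x)^\top f(x'))^2]$ with $f(x) = \pmb{W}_2 \pmb{W}_1 x$, where $\pmb{W}_1 x$ is the pre-projection representation. The function names, the shapes, the off-diagonal estimate of the second term, and the additive-noise "augmentation" are illustrative assumptions.

```python
import numpy as np


def encode(W1: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Pre-projection representation W1 x for each row of x."""
    return x @ W1.T


def project(W1: np.ndarray, W2: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Post-projection output f(x) = W2 W1 x, the quantity the loss sees."""
    return encode(W1, x) @ W2.T


def spectral_contrastive_loss(W1, W2, x_a, x_b) -> float:
    """Empirical spectral contrastive loss on a batch of positive pairs.

    x_a[i] and x_b[i] are two augmented views of sample i; views of
    different samples (i != j) play the role of independent pairs.
    """
    z_a, z_b = project(W1, W2, x_a), project(W1, W2, x_b)
    sim = z_a @ z_b.T                        # sim[i, j] = f(x_a[i])^T f(x_b[j])
    positive = np.mean(np.diag(sim))         # estimates E[f(x)^T f(x+)]
    off_diag = sim[~np.eye(len(sim), dtype=bool)]
    negative = np.mean(off_diag ** 2)        # estimates E[(f(x)^T f(x'))^2]
    return float(-2.0 * positive + negative)


# Usage with random data; the additive-noise "augmentation" is a placeholder.
rng = np.random.default_rng(0)
d, p, k, n = 8, 4, 4, 32   # input dim, pre-projection dim, output dim, batch size
W1 = rng.normal(size=(p, d)) / np.sqrt(d)
W2 = rng.normal(size=(k, p)) / np.sqrt(p)
x = rng.normal(size=(n, d))
x_a = x + 0.1 * rng.normal(size=x.shape)
x_b = x + 0.1 * rng.normal(size=x.shape)
loss = spectral_contrastive_loss(W1, W2, x_a, x_b)
h = encode(W1, x)          # kept for downstream use once the head W2 is discarded
```

Only the product $\pmb{W}_2\pmb{W}_1$ enters the loss. This is why it is not obvious what shapes $\pmb{W}_1 x$ on its own.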

Abstract

An effective technique for obtaining high-quality representations is adding a projection head on top of the encoder during training, then discarding it and using the pre-projection representations. Despite its proven practical effectiveness, the reason behind the success of this technique is poorly understood. The pre-projection representations are not directly optimized by the loss function, raising the question: what makes them better? In this work, we provide a rigorous theoretical answer to this question. We start by examining linear models trained with self-supervised contrastive loss. We reveal that the implicit bias of training algorithms leads to layer-wise progressive feature weighting, where features become increasingly unequal as we go deeper into the layers. Consequently, lower layers tend to have more normalized and less specialized representations. We theoretically characterize scenarios where such representations are more beneficial, highlighting the intricate interplay between data augmentation and input features. Additionally, we demonstrate that introducing non-linearity into the network allows lower layers to learn features that are completely absent in higher layers. Finally, we show how this mechanism improves robustness in supervised contrastive learning and supervised learning. We empirically validate our results through various experiments on CIFAR-10/100, UrbanCars and shifted versions of ImageNet. We also introduce a potential alternative to the projection head, which offers a more interpretable and controllable design.
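The phrase "features become increasingly unequal as we go deeper" can be illustrated with a balanced factorization of a fixed end-to-end linear map. This rests on a simplifying assumption of our own and is not the paper's construction. The NumPy sketch below, with the hypothetical helpers `balanced_factorization` and `feature_weights_by_depth`, splits $\pmb{M} = \pmb{U}\pmb{\Sigma}\pmb{V}^\top$ across $L$ layers so that each layer carries $\pmb{\Sigma}^{1/L}$. The representation after layer $k$ then weights the $i$-th singular direction by $\sigma_i^{k/L}$, so the ratio between the strongest and weakest weight grows with depth.

```python
import numpy as np


def balanced_factorization(M: np.ndarray, depth: int) -> list[np.ndarray]:
    """Split a square matrix M into `depth` balanced factors, first layer first.

    Each factor carries Sigma^(1/depth); the first also applies V^T and the
    last also applies U, so the product (last @ ... @ first) equals M.
    """
    U, s, Vt = np.linalg.svd(M)
    root = np.diag(s ** (1.0 / depth))
    layers = [root.copy() for _ in range(depth)]
    layers[0] = root @ Vt          # first layer rotates inputs into singular directions
    layers[-1] = U @ layers[-1]    # last layer rotates back to the output space
    return layers


def feature_weights_by_depth(layers: list[np.ndarray]) -> list[np.ndarray]:
    """Singular values of the partial products W_k ... W_1 for k = 1..L."""
    weights, partial = [], np.eye(layers[0].shape[1])
    for W in layers:
        partial = W @ partial
        weights.append(np.linalg.svd(partial, compute_uv=False))
    return weights


rng = np.random.default_rng(0)
M = rng.normal(size=(4, 4))                      # hypothetical end-to-end map
layers = balanced_factorization(M, depth=3)
product = layers[2] @ layers[1] @ layers[0]      # should reproduce M
# max/min feature weight after layers 1, 2, 3: kappa^(1/3), kappa^(2/3), kappa
ratios = [w.max() / w.min() for w in feature_weights_by_depth(layers)]
```

The layers closest to the input spread weight most evenly across directions. This is the sense in which lower-layer representations are "more normalized".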

Paper Structure (29 sections, 13 theorems, 39 equations, 10 figures, 2 tables)

Key Result

Theorem 3.3

The global minimizer of the CL loss $\mathcal{L}_{CL}$ with the smallest norm, defined as $\|\pmb{W}_1^\top \pmb{W}_1\|_F^2 + \|\pmb{W}_2^{\top} \pmb{W}_2\|_F^2$, satisfies $\pmb{W}_1 \pmb{W}_1^{\top} = \pmb{W}_2^{\top} \pmb{W}_2$.
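One consequence of this balancedness condition is that both layers carry the singular values of the end-to-end map in square-root form. The following is our own short derivation, not quoted from the paper. Write $\pmb{P} = \pmb{W}_1\pmb{W}_1^\top = \pmb{W}_2^\top\pmb{W}_2$ (positive semidefinite) and $\pmb{M} = \pmb{W}_2\pmb{W}_1$, and use the fact that $\pmb{A}\pmb{B}$ and $\pmb{B}\pmb{A}$ share their nonzero eigenvalues:

```latex
\begin{aligned}
\sigma_i(\pmb{M})^2 = \lambda_i(\pmb{M}\pmb{M}^\top)
  &= \lambda_i(\pmb{W}_2 \pmb{P} \pmb{W}_2^\top)
   = \lambda_i(\pmb{P}\,\pmb{W}_2^\top \pmb{W}_2)
   = \lambda_i(\pmb{P})^2, \\
\sigma_i(\pmb{W}_1)^2 = \lambda_i(\pmb{W}_1 \pmb{W}_1^\top)
  &= \lambda_i(\pmb{P}) = \sigma_i(\pmb{M})
  \quad\Longrightarrow\quad \sigma_i(\pmb{W}_1) = \sqrt{\sigma_i(\pmb{M})}.
\end{aligned}
```

Assume features align with the singular directions. A feature weighted by $\sigma_i$ in the post-projection output is then weighted by $\sqrt{\sigma_i}$ in the pre-projection representation $\pmb{W}_1 x$, which compresses the gaps between features. For example, weights 4 and 1 after the head become 2 and 1 before it.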

Figures (10)

  • Figure 1: Weights of features in a two-layer fully connected ReLU network trained with CL. Left: With all features having equal strength, those that are more disrupted by augmentation have smaller or zero weights. Right: With augmentation treating all features equally, features with the largest or smallest strength are weighted less than those with intermediate strength. In both panels, weights are more equal pre-projection.
  • Figure 2: Left: Weights of features in a two-layer fully connected ReLU network trained using SCL. The subclass feature is not represented post-projection but is represented pre-projection. Right: As a result, the four subclasses are only separable in pre-projection representations.
  • Figure 3: Results on MNIST-on-CIFAR-10. (a) The data augmentation keeps the digit for one image in the positive pair and randomly drops the digit for the other with probability $p_{\text{drop}}$. Pre-projection representations are more beneficial (b) with more inappropriate augmentation (larger $p_{\text{drop}}$) during pretraining, (c) when digits are very weak or very strong during pretraining, and (d) when weight decay is smaller.
  • Figure 4: Performance of few-shot adaptation.
  • Figure 5: The value of the sample complexity indicator at different layers. Legends show the weight that the pretrained model assigns to the downstream-relevant feature. We observe that the optimal layer shifts lower when the pretrained model assigns less weight to the downstream-relevant features.
  • ...and 5 more figures
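The augmentation described for Figure 3(a) can be sketched as follows. This is a hypothetical NumPy illustration, not the authors' pipeline. The pixel-wise-maximum overlay, the array shapes, the stand-in images, and the names `overlay_digit` and `make_positive_pair` are our own assumptions. Only the rule itself comes from the caption: keep the digit in one view and drop it from the other with probability $p_{\text{drop}}$. Other standard augmentations are omitted.

```python
import numpy as np


def overlay_digit(background: np.ndarray, digit: np.ndarray) -> np.ndarray:
    """Hypothetical compositing: paste a grayscale digit onto every color channel."""
    return np.maximum(background, digit[..., None])


def make_positive_pair(background, digit, p_drop, rng):
    """Return (view_a, view_b); only view_a is guaranteed to contain the digit."""
    view_a = overlay_digit(background, digit)       # digit always kept
    if rng.random() < p_drop:
        view_b = background.copy()                  # digit dropped from this view
    else:
        view_b = overlay_digit(background, digit)   # digit kept in both views
    return view_a, view_b


rng = np.random.default_rng(0)
background = rng.random((32, 32, 3))   # stand-in for a CIFAR-10 image in [0, 1)
digit = np.zeros((32, 32))
digit[8:24, 14:18] = 1.0               # stand-in for an upscaled MNIST "1"
pairs = [make_positive_pair(background, digit, p_drop=0.5, rng=rng) for _ in range(1000)]
# Fraction of pairs whose second view lost the digit; should be close to p_drop.
drop_rate = np.mean([np.array_equal(b, background) for _, b in pairs])
```

A larger $p_{\text{drop}}$ makes the digit less consistent across the pair. Per Figure 3(b), this is the setting in which pre-projection representations are more beneficial.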

Theorems & Definitions (23)

  • Definition 3.1: Input distribution of pretraining data
  • Definition 3.2: Data augmentation
  • Theorem 3.3: Weights of the minimum norm minimizer
  • Theorem 3.4: Weights of the model trained with gradient flow, proved in Arora et al. (2018)
  • Theorem 3.5
  • Theorem 3.6
  • Corollary 3.7
  • Theorem 3.8
  • Definition 4.1
  • Theorem 4.2
  • ...and 13 more