Investigating the Benefits of Projection Head for Representation Learning
Yihao Xue, Eric Gan, Jiayi Ni, Siddharth Joshi, Baharan Mirzasoleiman
TL;DR
This paper examines why adding a projection head during training, and then discarding it, improves the pre-projection representations in self-supervised contrastive, supervised contrastive, and standard supervised learning. It analyzes a two-layer linear model trained with a spectral contrastive loss and identifies layer-wise progressive feature weighting: deeper layers increasingly emphasize a subset of features, while lower layers keep features more evenly weighted. It further shows that non-linearities let lower layers acquire features that are absent from higher layers. It also shows that the projection head can improve robustness when the pretraining objective is misaligned with the downstream task. The authors validate the theory with controlled experiments and real-data tasks. They also propose a fixed reweighting head as an interpretable alternative to the projection head that achieves comparable gains, with implications for robust transfer.
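The setting described above can be sketched in NumPy: a two-layer linear model whose first layer is the encoder (kept) and whose second layer is the projection head (discarded), trained with a minibatch spectral contrastive loss. This is a minimal sketch under stated assumptions. The loss follows the commonly used form from HaoChen et al. (2021), -2 E[f(x)^T f(x+)] + E[(f(x)^T f(x'))^2], and may differ from the paper's exact objective or regularization. The function names and the noise "augmentation" are hypothetical.

```python
import numpy as np

def two_layer_linear(x, W1, W2):
    """Return (pre-projection, post-projection) representations.

    Assumed model: h = W1 x is the encoder output (kept after training),
    z = W2 h is the projection-head output (seen by the loss, then discarded).
    x has shape (n, d), one example per row.
    """
    h = x @ W1.T
    z = h @ W2.T
    return h, z

def spectral_contrastive_loss(z1, z2):
    """Minibatch spectral contrastive loss (assumed HaoChen et al. form).

    z1[i] and z2[i] embed two augmented views of example i.
    Positive term: -2 * mean_i <z1[i], z2[i]>.
    Negative term: mean over pairs i != j of <z1[i], z2[j]>**2.
    """
    n = z1.shape[0]
    sim = z1 @ z2.T                    # (n, n) matrix of inner products
    pos = -2.0 * np.mean(np.diag(sim))
    off_diag = ~np.eye(n, dtype=bool)  # mask selecting the i != j pairs
    neg = np.mean(sim[off_diag] ** 2)
    return pos + neg

# Usage on random data with a hypothetical noise augmentation.
rng = np.random.default_rng(0)
x = rng.normal(size=(8, 5))
x1 = x + 0.1 * rng.normal(size=x.shape)  # view 1
x2 = x + 0.1 * rng.normal(size=x.shape)  # view 2
W1 = 0.1 * rng.normal(size=(4, 5))       # encoder layer
W2 = 0.1 * rng.normal(size=(3, 4))       # projection head
h1, z1 = two_layer_linear(x1, W1, W2)
h2, z2 = two_layer_linear(x2, W1, W2)
print(spectral_contrastive_loss(z1, z2))
```

Only `z1` and `z2` enter the loss. The downstream representation is `h1`/`h2`, which is the reason the abstract asks why these unoptimized outputs end up being better.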
Abstract
An effective technique for obtaining high-quality representations is adding a projection head on top of the encoder during training, then discarding it and using the pre-projection representations. Despite its proven practical effectiveness, the reason behind the success of this technique is poorly understood. The pre-projection representations are not directly optimized by the loss function, raising the question: what makes them better? In this work, we provide a rigorous theoretical answer to this question. We start by examining linear models trained with self-supervised contrastive loss. We reveal that the implicit bias of training algorithms leads to layer-wise progressive feature weighting, where features become increasingly unequal as we go deeper into the layers. Consequently, lower layers tend to have more normalized and less specialized representations. We theoretically characterize scenarios where such representations are more beneficial, highlighting the intricate interplay between data augmentation and input features. Additionally, we demonstrate that introducing non-linearity into the network allows lower layers to learn features that are completely absent in higher layers. Finally, we show how this mechanism improves robustness in supervised contrastive learning and supervised learning. We empirically validate our results through various experiments on CIFAR-10/100, UrbanCars and shifted versions of ImageNet. We also introduce a potential alternative to the projection head, which offers a more interpretable and controllable design.
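The idea that "features become increasingly unequal as we go deeper" can be illustrated numerically. The illustration relies on one assumption that comes from the deep-linear-network literature and is not stated in this abstract. For a two-layer linear model f(x) = W2 W1 x, gradient flow conserves W1 W1^T - W2^T W2, so training from a balanced (for example, small) initialization keeps the layers balanced. In that case both layers split the end-to-end singular values evenly. A feature direction with end-to-end weight s is then scaled by sqrt(s) at the pre-projection layer and by s after the head. The sketch below is a hypothetical illustration of that effect, not the paper's derivation.

```python
import numpy as np

def balanced_factorization(W):
    """Split an end-to-end linear map W = W2 @ W1 into balanced factors.

    Assumes the implicit-bias condition W1 @ W1.T == W2.T @ W2. Under it,
    with W = U diag(s) V^T, both layers carry singular values sqrt(s).
    """
    U, s, Vt = np.linalg.svd(W, full_matrices=False)
    root = np.sqrt(s)
    W1 = root[:, None] * Vt  # diag(sqrt(s)) @ V^T : encoder layer
    W2 = U * root[None, :]   # U @ diag(sqrt(s))   : projection head
    return W1, W2

def feature_weights(M, directions):
    """Gain ||M v|| of map M along each unit input direction v (rows)."""
    return np.linalg.norm(directions @ M.T, axis=1)

# Hypothetical end-to-end map that weights two features by 4 and 1.
W = np.array([[4.0, 0.0], [0.0, 1.0]])
W1, W2 = balanced_factorization(W)
_, _, Vt = np.linalg.svd(W)
pre = feature_weights(W1, Vt)         # pre-projection gains, ≈ [2, 1]
post = feature_weights(W2 @ W1, Vt)   # post-projection gains, ≈ [4, 1]
print(pre.max() / pre.min(), post.max() / post.min())
```

The max/min gain ratio is the square root of the end-to-end ratio at the lower layer. This matches the abstract's observation that lower layers hold more normalized, less specialized representations than the layer the loss directly shapes.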
