Deep Fusion: Capturing Dependencies in Contrastive Learning via Transformer Projection Heads
Huanran Li, Daniel Pimentel-Alarcón
TL;DR
This work investigates replacing the standard feed-forward (FFN) projection head in contrastive learning with a Transformer-based projection head that captures long-range dependencies among embeddings. It identifies Deep Fusion, a phenomenon that emerges without label supervision, in which successive attention layers progressively group samples from the same class, and it provides a theoretical framework for this behavior. Empirically, Transformer projection heads improve over FFN heads on CIFAR-10, CIFAR-100, and ImageNet-200 in both supervised and unsupervised evaluations, and ablations examine the choice of batch size, temperature, and weight decay. The results suggest that attention-based projection heads are a promising direction for self-supervised representation learning in vision.
Abstract
Contrastive Learning (CL) has emerged as a powerful method for training feature extraction models using unlabeled data. Recent studies suggest that adding a linear projection head after the backbone significantly enhances model performance. In this work, we investigate the use of a transformer model as the projection head within the CL framework, aiming to exploit the transformer's capacity for capturing long-range dependencies across embeddings to further improve performance. Our key contributions are fourfold. First, we introduce the use of a transformer as the projection head in contrastive learning, to our knowledge the first work to do so. Second, our experiments reveal a "Deep Fusion" phenomenon, in which the attention mechanism captures the correct relational dependencies among samples from the same class progressively more accurately in deeper layers. Third, we provide a theoretical framework that explains and supports this "Deep Fusion" behavior. Finally, we demonstrate experimentally that our model outperforms the existing approach of using a feed-forward layer as the projection head.
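To make the idea of a projection head that attends across embeddings concrete, the following is a minimal sketch in plain Python, not the authors' implementation. It assumes that the batch of backbone embeddings is treated as the attention "sequence", so each sample can attend to every other sample in the batch. The function names, the single-head design, the random weights, and the omission of layer normalization and feed-forward sublayers are all simplifying assumptions made for illustration.

```python
import math
import random


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix, both as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def batch_attention_layer(embeddings, w_q, w_k, w_v):
    """One single-head self-attention layer whose 'sequence' is the batch.

    Every sample's embedding attends to every sample in the batch, and a
    residual connection adds the attended mixture back to the input.
    Returns (new_embeddings, attention_matrix).
    """
    q = matmul(embeddings, w_q)
    k = matmul(embeddings, w_k)
    v = matmul(embeddings, w_v)
    scale = math.sqrt(len(q[0]))
    scores = [[sum(a * b for a, b in zip(qi, kj)) / scale for kj in k] for qi in q]
    attn = [softmax(row) for row in scores]  # row i: how sample i weights each sample
    mixed = matmul(attn, v)
    out = [[e + m for e, m in zip(e_row, m_row)] for e_row, m_row in zip(embeddings, mixed)]
    return out, attn


def transformer_projection_head(embeddings, layers):
    """Apply a stack of attention layers and keep each layer's attention matrix."""
    attn_maps = []
    for w_q, w_k, w_v in layers:
        embeddings, attn = batch_attention_layer(embeddings, w_q, w_k, w_v)
        attn_maps.append(attn)
    return embeddings, attn_maps


def random_matrix(rows, cols, rng, scale=0.1):
    """Hypothetical helper: Gaussian-initialized matrix as nested lists."""
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
batch_size, dim, depth = 6, 4, 3  # illustrative sizes only
backbone_out = random_matrix(batch_size, dim, rng, scale=1.0)  # stand-in for backbone features
layers = [tuple(random_matrix(dim, dim, rng) for _ in range(3)) for _ in range(depth)]
projected, attn_maps = transformer_projection_head(backbone_out, layers)
```

Here `projected` would be passed to the contrastive loss in place of the output of a feed-forward head, and `attn_maps[l]` is the sample-by-sample attention matrix at layer `l`. With trained weights, inspecting these matrices layer by layer is one way to look for the grouping of same-class samples that the abstract calls Deep Fusion. With the random weights used here, no such structure should be expected.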
