BiSSL: Enhancing the Alignment Between Self-Supervised Pretraining and Downstream Fine-Tuning via Bilevel Optimization
Gustav Wagner Zakarias, Lars Kai Hansen, Zheng-Hua Tan
TL;DR
BiSSL tackles the misalignment between self-supervised pretraining and downstream fine-tuning by embedding both stages in a bilevel optimization framework. The lower level optimizes the pretext objective together with a regularizer that ties the pretext backbone to the downstream backbone. The upper level optimizes downstream performance. Its gradient is computed through the implicit Jacobian of the lower-level solution, and the required inverse-Hessian-vector products are approximated with conjugate gradients. Experiments on 12 downstream image classification datasets and on object detection show that BiSSL yields consistent, sometimes substantial, performance gains and improved stability over conventional SSL pipelines, while remaining compatible with standard pretext methods such as SimCLR and BYOL. The work demonstrates that learning a pretrained initialization explicitly aligned with downstream objectives can meaningfully improve transfer learning, especially under distribution shifts between pretraining and downstream data.
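The implicit-Jacobian and conjugate-gradient mechanism described above can be illustrated on a toy problem. The sketch below is hypothetical and is not the authors' implementation. Quadratic losses stand in for the pretext loss L_P and the downstream loss L_D. The regularizer is assumed to be (λ/2)‖θ_P − θ_D‖², and the upper-level objective is simplified to L_D(θ_P*(θ_D)). All names (A, B, TARGET, LAM, solve_lower, hypergradient, upper_loss) are illustrative. The lower level is solved approximately by gradient descent. The hypergradient λ(A + λI)⁻¹∇L_D(θ_P*) is then obtained by a conjugate gradient solve that touches the lower-level Hessian only through Hessian-vector products.

```python
# Hypothetical toy sketch (not the BiSSL code): implicit-gradient bilevel updates
# with a conjugate-gradient solve. Quadratic losses stand in for the real
# pretext (lower-level) and downstream (upper-level) objectives.

A = [[3.0, 1.0], [1.0, 2.0]]  # SPD Hessian of toy pretext loss L_P(t) = 0.5 t^T A t - B^T t
B = [1.0, -1.0]
TARGET = [0.5, 0.5]           # toy downstream loss L_D(t) = 0.5 ||t - TARGET||^2
LAM = 1.0                     # assumed regularizer weight: (LAM / 2) ||theta_P - theta_D||^2


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(M, v):
    return [dot(row, v) for row in M]


def lower_hvp(v):
    """Hessian-vector product of the lower-level objective w.r.t. theta_P: (A + LAM I) v."""
    return [av + LAM * vi for av, vi in zip(matvec(A, v), v)]


def solve_lower(theta_d, steps=300, lr=0.1):
    """Approximately minimise L_P(theta_P) + (LAM/2)||theta_P - theta_D||^2 by gradient descent."""
    theta_p = list(theta_d)  # warm-start the pretext parameters at the downstream parameters
    for _ in range(steps):
        grad = [ai - bi + LAM * (tp - td)
                for ai, bi, tp, td in zip(matvec(A, theta_p), B, theta_p, theta_d)]
        theta_p = [tp - lr * g for tp, g in zip(theta_p, grad)]
    return theta_p


def conjugate_gradient(hvp, rhs, iters=50, tol=1e-20):
    """Solve H x = rhs for symmetric positive-definite H using only products x -> H x."""
    x = [0.0] * len(rhs)
    r = list(rhs)  # residual rhs - H x for the initial guess x = 0
    p = list(r)
    rs = dot(r, r)
    for _ in range(iters):
        if rs < tol:
            break
        hp = hvp(p)
        alpha = rs / dot(p, hp)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * hi for ri, hi in zip(r, hp)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


def upper_loss(theta_d):
    """Toy upper-level objective L_D(theta_P*(theta_D))."""
    theta_p = solve_lower(theta_d)
    return 0.5 * sum((tp - t) ** 2 for tp, t in zip(theta_p, TARGET))


def hypergradient(theta_d):
    """Implicit gradient: LAM * (A + LAM I)^{-1} grad L_D(theta_P*), via CG."""
    theta_p = solve_lower(theta_d)
    grad_ld = [tp - t for tp, t in zip(theta_p, TARGET)]
    v = conjugate_gradient(lower_hvp, grad_ld)
    return [LAM * vi for vi in v]


# A few upper-level gradient steps on theta_D, each re-solving the lower level.
theta_d = [0.0, 0.0]
losses = [upper_loss(theta_d)]
for _ in range(20):
    g = hypergradient(theta_d)
    theta_d = [td - 1.0 * gi for td, gi in zip(theta_d, g)]
    losses.append(upper_loss(theta_d))
```

Because conjugate gradients need only products with the lower-level Hessian, the same structure applies to neural networks, where such products can be obtained by automatic differentiation without forming the Hessian. In this 2-D toy, CG converges in at most two iterations in exact arithmetic.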
Abstract
Models initialized from self-supervised pretraining may suffer from poor alignment with downstream tasks, which limits how far subsequent fine-tuning can adapt the pretrained features toward downstream objectives. To mitigate this, we introduce BiSSL, a novel bilevel training framework that enhances the alignment of self-supervised pretrained models with downstream tasks prior to fine-tuning. BiSSL is an intermediate training stage conducted after conventional self-supervised pretraining. It solves a bilevel optimization problem whose lower- and upper-level objectives incorporate the pretext and downstream training objectives, respectively. This approach explicitly models the interdependence between the pretraining and fine-tuning stages of the conventional self-supervised learning pipeline, enabling more information sharing between them and ultimately yielding a model initialization that is better aligned with the downstream task. We propose a general training algorithm for BiSSL that is compatible with a broad range of pretext and downstream tasks. Using SimCLR and Bootstrap Your Own Latent (BYOL) to pretrain ResNet-50 backbones on the ImageNet dataset, we demonstrate that our framework significantly improves accuracy on most of the 12 downstream image classification datasets, and also improves object detection performance. Exploratory analyses and additional experiments provide further compelling evidence that BiSSL enhances downstream alignment.
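Schematically, the bilevel problem described above and its implicit gradient can be written as follows. This is a generic reading, not the paper's exact formulation. The notation, the form of the upper-level objective, and the choice r = ½‖θ_D − θ_P‖² in the last line are assumptions. The derivation also assumes that the lower-level Hessian is invertible at θ_P*.

```latex
\begin{aligned}
&\min_{\theta_D}\; F(\theta_D) = \mathcal{L}_D\big(\theta_D, \theta_P^*(\theta_D)\big)
\quad \text{s.t.} \quad
\theta_P^*(\theta_D) \in \arg\min_{\theta_P} G(\theta_P, \theta_D),
\qquad G(\theta_P, \theta_D) = \mathcal{L}_P(\theta_P) + \lambda\, r(\theta_D, \theta_P), \\
&\frac{dF}{d\theta_D} = \nabla_{\theta_D} \mathcal{L}_D
  + \Big(\frac{d\theta_P^*}{d\theta_D}\Big)^{\!\top} \nabla_{\theta_P} \mathcal{L}_D,
\qquad
\frac{d\theta_P^*}{d\theta_D} = -\big[\nabla^2_{\theta_P \theta_P} G\big]^{-1} \nabla^2_{\theta_P \theta_D} G, \\
&r(\theta_D, \theta_P) = \tfrac{1}{2}\|\theta_D - \theta_P\|_2^2
\;\Longrightarrow\;
\frac{d\theta_P^*}{d\theta_D} = \lambda \big(\nabla^2 \mathcal{L}_P(\theta_P^*) + \lambda I\big)^{-1}.
\end{aligned}
```

Under this choice, the second term of the hypergradient is λv, where v solves (∇²L_P(θ_P*) + λI)v = ∇_{θ_P}L_D. A linear system of this kind is what a conjugate gradient solver can approximate using only Hessian-vector products.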
