MIM-Refiner: A Contrastive Learning Boost from Intermediate Pre-Trained Representations
Benedikt Alkin, Lukas Miklautz, Sepp Hochreiter, Johannes Brandstetter
TL;DR
The paper addresses the mismatch between masked image modeling (MIM) pre-training and downstream task readiness by showing that the most meaningful, abstract representations in MIM models sit in intermediate encoder blocks rather than in the final ones. It introduces MIM-Refiner, a sequential refinement method that attaches multiple Instance Discrimination (ID) heads to late encoder blocks and uses Nearest Neighbor Alignment to form semantic clusters without altering the primary MIM objective. Empirically, it delivers state-of-the-art linear probing and strong low-shot, clustering, and transfer performance across multiple MIM backbones, including large ViT models, while remaining compute-efficient. The approach is presented as a scalable way to refine foundation models for diverse downstream tasks with little additional training.
Abstract
We introduce MIM (Masked Image Modeling)-Refiner, a contrastive learning boost for pre-trained MIM models. MIM-Refiner is motivated by the insight that strong representations within MIM models generally reside in intermediate layers. Accordingly, MIM-Refiner leverages multiple contrastive heads that are connected to different intermediate layers. In each head, a modified nearest neighbor objective constructs semantic clusters, which improves performance on downstream tasks in both off-the-shelf and fine-tuning settings. The refinement process is short and simple, yet highly effective. Within a few epochs, we refine the features of MIM models from subpar to state-of-the-art off-the-shelf features. Refining a ViT-H, pre-trained with data2vec 2.0 on ImageNet-1K, sets a new state-of-the-art in linear probing (84.7%) and low-shot classification among models that are pre-trained on ImageNet-1K. MIM-Refiner efficiently combines the advantages of MIM and instance discrimination (ID) objectives and compares favorably against previous state-of-the-art SSL models on a variety of benchmarks such as low-shot classification, long-tailed classification, clustering and semantic segmentation.
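To make the idea of several contrastive heads with a nearest neighbor objective more concrete, the following is a minimal sketch in plain Python. It is not the paper's objective: the abstract only says the nearest neighbor objective is "modified", so this sketch assumes a generic nearest-neighbor InfoNCE loss in which each anchor is swapped for its nearest neighbor from a queue before being contrasted against the second view. The function names (`nn_info_nce`, `multi_head_loss`), the queue, the temperature of 0.2, and the unweighted sum over heads are all illustrative assumptions.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def nearest_neighbor(query, queue):
    """Return the queue entry most similar to the (normalized) query."""
    return max(queue, key=lambda k: dot(query, k))


def nn_info_nce(view_a, view_b, queue, temperature=0.2):
    """Generic nearest-neighbor InfoNCE loss (an assumption, not the paper's loss).

    Each embedding in view_a is replaced by its nearest neighbor in the queue.
    The swapped embedding must then pick out the matching sample in view_b
    among all samples of the batch.
    """
    a = [l2_normalize(v) for v in view_a]
    b = [l2_normalize(v) for v in view_b]
    q = [l2_normalize(v) for v in queue]
    total = 0.0
    for i, anchor in enumerate(a):
        nn = nearest_neighbor(anchor, q)
        logits = [dot(nn, other) / temperature for other in b]
        # Numerically stable log-sum-exp over the batch.
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_sum_exp - logits[i]  # cross-entropy with target index i
    return total / len(a)


def multi_head_loss(block_embeddings, queues, temperature=0.2):
    """Sum the loss of one head per encoder block.

    block_embeddings: dict mapping a block index to a (view_a, view_b) pair
    queues: dict mapping the same block index to that head's queue
    """
    return sum(
        nn_info_nce(va, vb, queues[blk], temperature)
        for blk, (va, vb) in block_embeddings.items()
    )


# Toy usage: two heads on hypothetical blocks 8 and 12, two samples each.
view_a = [[1.0, 0.0], [0.0, 1.0]]
view_b = [[1.0, 0.1], [0.1, 1.0]]
queue = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
loss = multi_head_loss(
    {8: (view_a, view_b), 12: (view_a, view_b)},
    {8: queue, 12: queue},
)
print(f"summed loss over heads: {loss:.4f}")
```

In a real setup the embeddings would come from projection heads on top of the chosen ViT blocks and the queue would hold embeddings from earlier batches; here they are hard-coded so the sketch runs on its own.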
