Unsupervised Representation Learning by Balanced Self Attention Matching

Daniel Shalam, Simon Korman

TL;DR

BAM presents a self-supervised learning framework that learns representations by aligning self-attention distributions across augmented views, rather than directly matching instance features. By constructing a global, entropy-regularized target via optimal transport (Sinkhorn) to balance the self-attention matrix $A$, and by suppressing positive-pair dominance through negation of augmentation blocks, BAM mitigates feature collapse while leveraging rich in-batch relationships. The approach yields competitive results on ImageNet linear probing and fine-tuning, strong semi-supervised performance, and robust transfer to video segmentation and detection tasks, all without memory banks or multiple encoders. Overall, BAM demonstrates that focusing on in-batch self-attention statistics provides a powerful, scalable path for unsupervised representation learning with broad applicability.
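The balancing step can be illustrated with a short NumPy sketch. It is not the authors' implementation. It assumes cosine similarities between L2-normalized embeddings, a masked diagonal (an item does not attend to itself), uniform row and column marginals, and a plain Sinkhorn loop with a fixed number of iterations. The names `sinkhorn_balance`, `eps` and `n_iters` are illustrative.

```python
import numpy as np


def sinkhorn_balance(sim, eps=0.1, n_iters=200):
    """Sketch: entropy-regularized balancing of a square similarity matrix.

    Alternately normalizes the rows and columns of exp(sim / eps) with the
    diagonal masked out, giving an approximately doubly stochastic matrix.
    """
    logits = np.asarray(sim, dtype=float) / eps
    np.fill_diagonal(logits, -np.inf)      # assumption: no self-attention
    logits -= logits.max()                 # shift for numerical stability
    p = np.exp(logits)                     # exp(-inf) == 0 on the diagonal
    for _ in range(n_iters):
        p /= p.sum(axis=1, keepdims=True)  # rows sum to 1
        p /= p.sum(axis=0, keepdims=True)  # columns sum to 1
    return p


rng = np.random.default_rng(0)
z = rng.normal(size=(8, 16))
z /= np.linalg.norm(z, axis=1, keepdims=True)  # unit-norm embeddings
S = z @ z.T                                    # cosine similarities
P = sinkhorn_balance(S, eps=0.5)
print(P.sum(axis=1).round(3))  # row sums, close to 1 after balancing
print(P.sum(axis=0).round(3))  # column sums, close to 1 after balancing
```

Smaller values of `eps` give a sharper (lower-entropy) balanced matrix, while larger values spread the mass more uniformly over the batch.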

Abstract

Many leading self-supervised methods for unsupervised representation learning, in particular those for embedding image features, are built on variants of the instance discrimination task, whose optimization is known to be prone to instabilities that can lead to feature collapse. Different techniques have been devised to circumvent this issue, including the use of negative pairs with different contrastive losses, the use of external memory banks, and breaking of symmetry by using separate encoding networks with possibly different structures. Our method, termed BAM, rather than directly matching features of different views (augmentations) of input images, is based on matching their self-attention vectors, which are the distributions of similarities to the entire set of augmented images of a batch. We obtain rich representations and avoid feature collapse by minimizing a loss that matches these distributions to their globally balanced and entropy-regularized version, which is obtained through a simple self-optimal-transport computation. We ablate and verify our method through a wide set of experiments that show competitive performance with leading methods on both semi-supervised and transfer-learning benchmarks. Our implementation and pre-trained models are available at github.com/DanielShalam/BAM.
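To make the matching idea concrete, the following NumPy sketch computes a loss of this kind for a batch with two views per image. It is a hypothetical reconstruction from the description above, not the code released at github.com/DanielShalam/BAM. The layout (rows `i` and `i + B` are the two views of image `i`), the masking of both the self-pair and the positive pair, the temperatures `tau` and `eps`, and the cross-entropy form are all assumptions. Each view's softmax self-attention over the rest of the batch is matched to the Sinkhorn-balanced attention row of its positive partner.

```python
import numpy as np


def log_softmax_rows(x):
    """Row-wise log-softmax; -inf entries (masked pairs) get probability 0."""
    m = x.max(axis=1, keepdims=True)
    return x - m - np.log(np.exp(x - m).sum(axis=1, keepdims=True))


def balanced_target(logits, n_iters=100):
    """Log-domain Sinkhorn: alternately normalize rows and columns of exp(logits)."""
    log_p = logits.copy()
    for _ in range(n_iters):
        log_p -= np.logaddexp.reduce(log_p, axis=1, keepdims=True)  # rows sum to 1
        log_p -= np.logaddexp.reduce(log_p, axis=0, keepdims=True)  # columns sum to 1
    return np.exp(log_p)


def bam_style_loss(z, tau=0.1, eps=0.05, n_iters=100):
    """Hypothetical self-attention matching loss for k = 2 views.

    z: (2B, d) embeddings; rows i and i + B are two views of the same image.
    """
    n = z.shape[0]
    b = n // 2
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    sim = z @ z.T
    idx = np.arange(n)
    partner = (idx + b) % n                   # index of the other view
    mask = np.zeros((n, n), dtype=bool)
    mask[idx, idx] = True                     # self-pairs
    mask[idx, partner] = True                 # positive pairs (assumption)
    # Self-attention of each view over the rest of the batch.
    log_sa = log_softmax_rows(np.where(mask, -np.inf, sim / tau))
    # Globally balanced, entropy-regularized target.
    target = balanced_target(np.where(mask, -np.inf, sim / eps), n_iters)
    target /= target.sum(axis=1, keepdims=True)  # make every row a distribution
    # Cross-entropy between the partner's balanced row and view i's attention.
    t = target[partner]                       # zeros exactly where mask is True
    ce = -(t * np.where(mask, 0.0, log_sa)).sum(axis=1)
    return ce.mean()


rng = np.random.default_rng(0)
base = rng.normal(size=(4, 32))               # B = 4 synthetic "images"
views = np.concatenate([base + 0.1 * rng.normal(size=base.shape) for _ in range(2)])
loss = bam_style_loss(views)
print(f"loss = {loss:.4f}")
```

In an autodiff framework the balanced target would usually be treated as a constant (no gradient through the Sinkhorn iterations); that choice is also an assumption of this sketch.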

Paper Structure (48 sections, 5 equations, 7 figures, 10 tables)

Figures (7)

  • Figure 1: Distributions of pairwise similarities in an augmented batch of images. In 'instance-discrimination' self-supervised learning, pairs of instances (left) are typically categorized as "positives", which are pulled together in most approaches, or "negatives", which are pushed apart by contrastive approaches and entirely ignored by others (momentum- or distillation-based methods). Our claim is that such simplistic same/not-same interpretations do not fully exploit the rich relative metric information that resides in the statistics of the entire set of pairwise similarities. For example, categorizing pairs as "same-class" or "diff-class" (according to their class labels, which are unknown) shows (right) that while the similarities (soft-max entries) of "pos" pairs (yellow) are well separated from those of "diff" pairs (blue), the similarities of "non-aug" pairs highly overlap the others. In our approach, the latents of a positive pair are pushed together by matching their self-attention distributions (which are the probabilities of the "neg" pairs they each belong to). See Fig. 2 and text for further details.
  • Figure 2: Paradigms of instance-discrimination-based self-supervised learning, demonstrated on a batch of 3 images $x_1,x_2,x_3$ with 2 augmentations. (a): Each image $x_i$ is transformed into a pair of views by a pair of random augmentations. Both views are mapped to the latent space by a joint (learned) embedding. (b): Contrastive-learning methods compare each latent's attention distribution (marked by pink arc) to the 1-hot distribution of positive (green) and negative (red) matches. (c): Distillation methods focus only on positive pairs, matching their latents under the online/offline encoders. (d): Our approach matches the self-attention (SA) distributions of positive latent pairs. The SA distribution of one (pink arc) is matched to the balanced SA distribution of the other (blue arc). See text for details.
  • Figure 3: Evolution of entropy during training.
  • Figure 4: Code snippet for model training on a batch of images, for the case of $k=2$ augmentations.
  • Figure 5: Self-attention visualizations from the last layer of a BAM ViT-B network pre-trained on the unlabeled ImageNet train set. The three example images are from the validation set. Next to each image is its attention map, computed as the average of the attention maps produced by the 12 [cls]-token heads, which are shown on the right.
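The split of soft-max entries into positive pairs and all other pairs, as described for Figure 1, can be illustrated on synthetic data with a small sketch. It assumes cosine-similarity self-attention with a temperature `tau` and a masked diagonal, and it stands in for real augmentations with noisy copies of random vectors (the helpers `self_attention` and `split_pairs` are hypothetical). It illustrates only how positive-pair entries separate from the rest; it does not reproduce the figure or its class-based split.

```python
import numpy as np


def self_attention(z, tau=0.1):
    """Row-wise softmax over cosine similarities, excluding self-pairs."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    logits = (z @ z.T) / tau
    np.fill_diagonal(logits, -np.inf)
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    a = np.exp(logits)
    return a / a.sum(axis=1, keepdims=True)


def split_pairs(a, b):
    """Split off-diagonal entries of a (2B x 2B) attention matrix into
    positive pairs (i, i +/- B) and all remaining pairs."""
    n = a.shape[0]
    idx = np.arange(n)
    pos = np.zeros((n, n), dtype=bool)
    pos[idx, (idx + b) % n] = True
    other = ~pos & ~np.eye(n, dtype=bool)
    return a[pos], a[other]


rng = np.random.default_rng(0)
B, d = 16, 64
base = rng.normal(size=(B, d))
# Synthetic stand-in for augmentations: two noisy copies of each base vector.
views = np.concatenate([base + 0.3 * rng.normal(size=(B, d)) for _ in range(2)])
pos, other = split_pairs(self_attention(views), B)
print(f"pos mean: {pos.mean():.3f}   other mean: {other.mean():.5f}")
```

In this synthetic setting nearly all of each row's attention mass falls on the positive partner, an instance of the positive-pair dominance mentioned in the TL;DR.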