Your contrastive learning problem is secretly a distribution alignment problem

Zihao Chen, Chi-Heng Lin, Ran Liu, Jingyun Xiao, Eva L Dyer

TL;DR

This work reframes contrastive learning as a distribution alignment problem by casting it as a transport problem between augmented views. It introduces Generalized Contrastive Alignment (GCA), which uses target transport plans and proximal-point updates to flexibly encode matching constraints via (unbalanced) optimal transport, improving alignment and uniformity of learned representations. The authors establish connections between GCA and existing CL objectives (INCE, RINCE, BYOL), provide convergence and complexity analysis, and demonstrate empirical gains on standard and corrupted augmentations as well as domain generalization, with GCA-UOT often delivering the best performance. The framework enables incorporating domain knowledge and robust alignment strategies into self-supervised learning, suggesting broad applicability beyond standard image domains. Theoretical results link iterative GCA updates to tighter alignment bounds and higher latent-space uniformity, which translate into improved downstream classification performance. Overall, GCA offers a principled, scalable path to more robust, domain-aware self-supervised representations.

Abstract

Despite the success of contrastive learning (CL) in vision and language, its theoretical foundations and mechanisms for building representations remain poorly understood. In this work, we build connections between noise contrastive estimation losses widely used in CL and distribution alignment with entropic optimal transport (OT). This connection allows us to develop a family of different losses and multistep iterative variants for existing CL methods. Intuitively, by using more information from the distribution of latents, our approach allows a more distribution-aware manipulation of the relationships within augmented sample sets. We provide theoretical insights and experimental evidence demonstrating the benefits of our approach for generalized contrastive alignment. Through this framework, it is possible to leverage tools in OT to build unbalanced losses to handle noisy views and customize the representation space by changing the constraints on alignment. By reframing contrastive learning as an alignment problem and leveraging existing optimization tools for OT, our work provides new insights and connections between different self-supervised learning models in addition to new tools that can be more easily adapted to incorporate domain knowledge into learning.

Paper Structure

This paper contains 83 sections, 22 theorems, 145 equations, 8 figures, 6 tables, 2 algorithms.

Key Result

Theorem 1

Let ${\bf K}_\theta$ denote the augmentation kernel as in Definition 3 (Augmentation Kernel) with cosine similarity, let $d_\Gamma$ and $d_M$ both equal the KL-divergence, and let the constraint set be $C_1^\mu$ as in Equation eq:Birkhoff. Then the INCE objective in Equation eq:INCE can be re-expressed as a GCA problem of the form in Equation eq:main.
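
The intuition behind Theorem 1 can be checked numerically with a minimal sketch. It relies on the standard fact that the KL projection of a positive kernel onto a row-sum constraint is a row rescaling, so a single projection of the Gibbs kernel followed by reading off the diagonal recovers the usual softmax form of INCE. The function names, the temperature `eps`, and the choice of an all-ones row marginal are assumptions made for illustration; they are not taken from the paper.

```python
import numpy as np

def gibbs_kernel(z1, z2, eps=0.5):
    """Gibbs kernel K_ij = exp(cos_sim(z1_i, z2_j) / eps); eps is an assumed temperature."""
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    return np.exp(z1 @ z2.T / eps)

def row_projection(K, mu):
    """KL projection of K onto {P : P 1 = mu}, which rescales row i to sum to mu_i."""
    return K * (mu / K.sum(axis=1))[:, None]

def ince_via_row_projection(z1, z2, eps=0.5):
    """INCE written as one row projection followed by -mean(log diag(P))."""
    K = gibbs_kernel(z1, z2, eps)
    # Assumption: rows are normalized to sum to 1 (all-ones marginal).
    P = row_projection(K, np.ones(K.shape[0]))
    return -np.mean(np.log(np.diag(P)))

def ince_reference(z1, z2, eps=0.5):
    """Usual softmax cross-entropy form of INCE, with positives on the diagonal."""
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    logits = z1 @ z2.T / eps
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_softmax))
```

Calling both functions on the same pair of embedding batches should give matching values up to floating-point error, which is the sense in which the row-constrained projection reproduces INCE in this sketch.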

Figures (8)

  • Figure 1: Incorporating different priors into learning across multiple domains. (A) Example target alignment plan ${\bf P}_{\text{tgt}}$, where the target values for all pairs of samples from the same domain are set to $\alpha$, the diagonal values are set to 1, and the values for across-domain pairs are set to $\beta$. (B) Domain classification accuracy (red) and overall class accuracy (blue) as ($\alpha-\beta$) increases.
  • Figure A1: Illustration of proximal operators. (A) Visualization of proximal operators in $\mathbb{R}^3$ on the surface defined by $h(x, y) = x^2 + y^2$ within the domain constraints $-1.2 < x < 1.2$ and $-1.2 < y < 1.2$. If $v=v_1=(0.76, 0.76, 1.16)$, it lies within the domain of $h$ and is shown on the surface at the location where its third coordinate matches $h(x, y)$. If $v=v_2=(1.5, 1.5, 6)$, which is outside the feasible region defined by $h$, the proximal operator projects it to the closest point within the domain, approximately $(0.85, 0.85, 1.45)$. (B) Visualization of proximal operators in $\mathbb{R}^2$. The blue dashed line represents the function $h(x) = x^2$. The orange dash-dotted line illustrates the penalty term $\frac{1}{2} \|x - v\|^2$ with $v = (2, 0)$, which penalizes the squared distance from any $x$ to $v$. The green solid line is the proximal objective $2x^2 + \frac{1}{2} \|x - v\|^2$, whose minimizer moves from $v$ toward the minimizer of $h(x)$. The red point marks $\text{Prox}_h(v)$ in this space.
  • Figure A2: Time complexity analysis. (A) Time complexity of different contrastive methods (INCE, RINCE) and GCA-based methods (GCA-INCE, GCA-RINCE, and GCA-UOT) on CIFAR-10. (B) Time complexity for INCE (GCA-INCE-1) and GCA-INCE with different numbers of iterations; GCA-INCE-100 denotes GCA-INCE with 100 iterations. We ran the methods on CIFAR-10 as a self-supervised learning task for 50 epochs and compared their run times. (C) Performance of INCE (iteration=1) and GCA-INCE (iterations>1) on CIFAR-10 with different numbers of iterations. The shaded blue region is the standard deviation across 5 seeds.
  • Figure A3: Alignment and uniformity metrics on CIFAR-10. Alignment and uniformity of different methods under different augmentation settings (C0: standard, C1: erase, C2: crop, C3: brightness). The bar above the x axis (zero line) represents the alignment loss, while the bar under the x axis represents the uniformity loss. Shorter bars, i.e., lower alignment loss and higher uniformity loss, correspond to better performance of the SSL models.
  • Figure A4: Comparison of the $-\log(\mathbf{P})$ matrix across different methods. (A) The INCE matrix with row normalization. (B) The $-\log(\mathbf{P})$ matrix of GCA-INCE with five iterations in the forward pass, with both row and column normalization. (C) The $-\log(\mathbf{P})$ matrix of GCA-RINCE with five iterations in the forward pass. (D) The $-\log(\mathbf{P})$ matrix of GCA-UOT with five iterations in the forward pass.
  • ...and 3 more figures
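
The contrast drawn in Figure A4 between row normalization only and alternating row and column normalization can be illustrated with a Sinkhorn-style sketch. This is a generic alternating-scaling routine for entropic OT with all-ones marginals, not the paper's exact algorithm; the function names and the choice of marginals are assumptions for illustration. With `n_iters=1` the routine reduces to plain row normalization, and additional iterations also push the column sums toward 1.

```python
import numpy as np

def alternating_normalization(K, n_iters):
    """Scale a positive square kernel K toward unit row and column sums.

    n_iters=1 performs a single row normalization; each further iteration
    adds one column rescaling followed by one row rescaling.
    """
    v = np.ones(K.shape[1])
    u = 1.0 / (K @ v)              # row scaling: rows of diag(u) K diag(v) sum to 1
    for _ in range(n_iters - 1):
        v = 1.0 / (K.T @ u)        # column scaling
        u = 1.0 / (K @ v)          # row scaling again, so rows always sum to 1
    return u[:, None] * K * v[None, :]

def column_error(P):
    """L1 distance between the column sums of P and the all-ones marginal."""
    return np.abs(P.sum(axis=0) - 1.0).sum()
```

Comparing `column_error` for one iteration and for five iterations on the same kernel shows how far row normalization alone is from a plan that satisfies both marginal constraints.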

Theorems & Definitions (35)

  • Definition 1: Proximal Operator
  • Definition 2: Wasserstein Dependency Measure
  • Definition 3: Augmentation Kernel
  • Theorem 1: INCE Equivalence
  • Theorem 2: RINCE Equivalence
  • Theorem 3: W1 Equivalence
  • Theorem 4: BYOL Equivalence
  • Theorem 5: Improved Alignment with INCE
  • Theorem 6: Improved Alignment with RINCE
  • Theorem 7: Improved Uniformity
  • ...and 25 more
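
As a small worked example of Definition 1, the sketch below uses the standard definition $\text{Prox}_h(v) = \arg\min_x \, h(x) + \frac{1}{2}\|x - v\|^2$, which is assumed here to match the paper's definition. For the one-dimensional case $h(x) = x^2$, setting the derivative $2x + (x - v)$ to zero gives the closed form $v/3$. The function names and the grid-search bounds are assumptions for illustration.

```python
import numpy as np

def prox_objective(x, v, h):
    """Objective minimized by the proximal operator: h(x) + 0.5 * (x - v)^2."""
    return h(x) + 0.5 * (x - v) ** 2

def prox_numeric(v, h, lo=-5.0, hi=5.0, num=200001):
    """Approximate Prox_h(v) for scalar x by a dense grid search on [lo, hi]."""
    xs = np.linspace(lo, hi, num)
    return xs[np.argmin(prox_objective(xs, v, h))]

def prox_square(v):
    """Closed form for h(x) = x^2: solving 2x + (x - v) = 0 gives x = v / 3."""
    return v / 3.0
```

For `v = 2.0`, the grid search and the closed form should agree to within the grid spacing, and the result lies between the minimizer of $h$ (zero) and $v$.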