Contrasting Multiple Representations with the Multi-Marginal Matching Gap

Zoe Piran, Michal Klein, James Thornton, Marco Cuturi

TL;DR

The paper tackles the challenge of learning coherent representations from more than two views by introducing the multi-marginal matching gap (M3G), a holistic, entropy-regularized multi-marginal OT loss that contrasts ground-truth $k$-view polymatchings with the MM-OT optimum. Central to M3G is the construction of a ground-truth polymatching tensor $\mathscr{J}_{n,k}$ and a multiway cost tensor $\mathscr{M}_c(\mathscr{X})$ from $k$-view embeddings. Their optimality gap is then measured with the MM-Sinkhorn (MM-S) algorithm at regularization strength $\varepsilon$. A gradient derivation through Fenchel-Young losses completes the framework and enables end-to-end learning with a single MM-S forward pass. Empirically, M3G improves over pairwise extensions on ImageNet-1k self-supervised learning (SSL), DomainNet domain adaptation, and multichannel EEG data, while incurring tractable overhead at moderate $n$ and $k$. Overall, M3G establishes a first holistic approach to jointly aligning multiple views, one that scales to moderate numbers of views and batch sizes and brings practical benefits for multimodal and multiview tasks.
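The two objects the TL;DR builds on can be sketched in a few lines of NumPy. In this sketch, $\mathscr{J}_{n,k}$ puts mass $1/n$ on each "diagonal" entry $(i,\dots,i)$, so that every marginal is uniform. The paper's scaling convention may differ. The multiway cost is a hypothetical stand-in: the sum of pairwise squared Euclidean distances between the $k$ views of a tuple. It is not necessarily either of the costs ($c_{\rm cv}$, $c_{\rm csd}$) used in the paper. The function names and the `(k, n, d)` layout of the embeddings are illustrative assumptions.

```python
import itertools
import numpy as np

def ground_truth_polymatching(n: int, k: int) -> np.ndarray:
    """Order-k tensor with mass 1/n on each diagonal entry (i, i, ..., i)."""
    J = np.zeros((n,) * k)
    idx = np.arange(n)
    # the i-th ground-truth k-tuple gathers the k views of point i
    J[(idx,) * k] = 1.0 / n
    return J

def pairwise_sq_cost_tensor(X: np.ndarray) -> np.ndarray:
    """Hypothetical multiway cost: C[i_1, ..., i_k] = sum_{a<b} ||X[a, i_a] - X[b, i_b]||^2.

    X has shape (k, n, d): k views, n points, d-dimensional embeddings.
    The dense output has n**k entries, so this is only practical for small n and k.
    """
    k, n, _ = X.shape
    C = np.zeros((n,) * k)
    for a, b in itertools.combinations(range(k), 2):
        # (n, n) matrix of squared distances between the embeddings of view a and view b
        D = ((X[a][:, None, :] - X[b][None, :, :]) ** 2).sum(axis=-1)
        shape = [1] * k
        shape[a], shape[b] = n, n
        # broadcast the pairwise term over the remaining k - 2 axes
        C += D.reshape(shape)
    return C

# toy setting of Figure 1: k = 3 views, n = 4 points, d = 2 dimensions
rng = np.random.default_rng(0)
X = rng.normal(size=(3, 4, 2))
J = ground_truth_polymatching(4, 3)
C = pairwise_sq_cost_tensor(X)
ground_truth_cost = float((J * C).sum())  # <J, C>, the cost of the ground-truth matching
```

Building the dense tensor makes the $O(n^k)$ memory footprint explicit. For this additive cost, $\langle \mathscr{J}, \mathscr{C}\rangle$ could also be computed from the $n\times n$ pairwise matrices alone.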

Abstract

Learning meaningful representations of complex objects that can be seen through multiple ($k\geq 3$) views or modalities is a core task in machine learning. Existing methods use losses originally intended for paired views and extend them to $k$ views, either by instantiating $\tfrac12k(k-1)$ loss-pairs, or by using reduced embeddings, following a \textit{one vs. average-of-rest} strategy. We propose the multi-marginal matching gap (M3G), a loss that borrows tools from multi-marginal optimal transport (MM-OT) theory to simultaneously incorporate all $k$ views. Given a batch of $n$ points, each seen as a $k$-tuple of views subsequently transformed into $k$ embeddings, our loss contrasts the cost of matching these $n$ ground-truth $k$-tuples with the MM-OT polymatching cost, which seeks $n$ optimally arranged $k$-tuples chosen within these $n\times k$ vectors. The $O(n^k)$ complexity of the MM-OT problem, exponential in $k$, may seem daunting. We nevertheless show in experiments that a suitable generalization of the Sinkhorn algorithm for that problem can scale to, e.g., $k=3$ to $6$ views using mini-batches of size $64$ to $128$. Our experiments demonstrate improved performance over multiview extensions of pairwise losses, for both self-supervised and multimodal tasks.
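The matching gap described in the abstract can be sketched as follows under several assumptions. Marginals are uniform ($1/n$), entropy is taken as $H(P)=-\sum P\log P$, and the loss has the Fenchel-Young-style form $\big(\langle \mathscr{J}, \mathscr{C}\rangle - \varepsilon H(\mathscr{J})\big) - \min_{P}\big(\langle P, \mathscr{C}\rangle - \varepsilon H(P)\big)$, with the minimum taken over polystochastic tensors. Definition 3.1 may use a different normalization. The solver is a plain log-domain multi-marginal Sinkhorn that updates one dual potential per view in turn. It materializes the dense $n^k$ tensor and is only meant for small $n$ and $k$. The names `mm_sinkhorn` and `m3g_gap` are hypothetical.

```python
import numpy as np

def _lse_except(A: np.ndarray, axis: int) -> np.ndarray:
    """Log-sum-exp of A over every axis except `axis` (returns a vector)."""
    B = np.moveaxis(A, axis, 0).reshape(A.shape[axis], -1)
    m = B.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(B - m).sum(axis=1, keepdims=True))).ravel()

def _along(v: np.ndarray, axis: int, k: int) -> np.ndarray:
    """Reshape a length-n vector so that it broadcasts along `axis` of an order-k tensor."""
    shape = [1] * k
    shape[axis] = v.shape[0]
    return v.reshape(shape)

def mm_sinkhorn(C: np.ndarray, eps: float, n_iter: int = 300) -> np.ndarray:
    """Entropic MM-OT with uniform marginals 1/n via block-coordinate dual ascent."""
    k, n = C.ndim, C.shape[0]
    f = [np.zeros(n) for _ in range(k)]  # one dual potential per view
    for _ in range(n_iter):
        for l in range(k):
            rest = sum(_along(f[m], m, k) for m in range(k) if m != l)
            # pick f[l] so that the l-th marginal of exp((sum_m f_m - C) / eps) equals 1/n
            f[l] = -eps * np.log(n) - eps * _lse_except((rest - C) / eps, l)
    total = sum(_along(f[m], m, k) for m in range(k))
    return np.exp((total - C) / eps)

def entropic_objective(P: np.ndarray, C: np.ndarray, eps: float) -> float:
    """<P, C> - eps * H(P) with H(P) = -sum P log P and the convention 0 log 0 = 0."""
    safe = np.where(P > 0, P, 1.0)
    return float((P * C).sum() + eps * (P * np.log(safe)).sum())

def m3g_gap(C: np.ndarray, eps: float, n_iter: int = 300):
    """Return (gap, gradient of the gap w.r.t. C, optimal polymatching) for a cost tensor C."""
    n, k = C.shape[0], C.ndim
    J = np.zeros_like(C)
    J[(np.arange(n),) * k] = 1.0 / n  # ground-truth polymatching
    P = mm_sinkhorn(C, eps, n_iter)
    gap = entropic_objective(J, C, eps) - entropic_objective(P, C, eps)
    return gap, J - P, P

rng = np.random.default_rng(0)
C = rng.uniform(size=(5, 5, 5))  # toy cost tensor: k = 3 views, n = 5 points
gap, grad_C, P = m3g_gap(C, eps=0.5)
```

Because $\mathscr{J}$ is itself a feasible polymatching, this gap is non-negative. By the envelope theorem, its gradient with respect to the cost tensor is $\mathscr{J}-P^\star$, so one forward Sinkhorn pass provides both the loss and the tensor that is backpropagated.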

Paper Structure (32 sections, 1 theorem, 23 equations, 6 figures, 4 tables, 2 algorithms)

Key Result

Proposition 3.2

The $\mathrm{M3G}$ loss is non-negative. The gradient of the $\mathrm{M3G}$ loss only requires applying the vector-Jacobian operator [blondel2024elements] of $\mathscr{M}$, $\partial\,\mathscr{M}(\cdot)^*[\cdot]$, evaluated at $\mathscr{X}$, to the difference of two polystochastic tensors, the ground-truth …
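In practice, Proposition 3.2 means that once the optimal polymatching $P^\star$ is available from the MM-Sinkhorn forward pass, backpropagating to the embeddings only needs the vector-Jacobian product of the cost map applied to $G=\mathscr{J}_{n,k}-P^\star$. Below is a minimal NumPy sketch for a hypothetical pairwise squared-distance cost $c(x_1,\dots,x_k)=\sum_{a<b}\lVert x_a-x_b\rVert^2$. This cost is a stand-in and not necessarily the paper's. For this additive cost, $\langle G,\mathscr{C}(\mathscr{X})\rangle$ depends only on the two-axis marginals of $G$, so the VJP works with $n\times n$ matrices once those marginals are formed. All function names are illustrative.

```python
import itertools
import numpy as np

def pair_marginal(G: np.ndarray, a: int, b: int) -> np.ndarray:
    """Sum G over every axis except a and b (a < b); the result is indexed [i_a, i_b]."""
    others = tuple(ax for ax in range(G.ndim) if ax not in (a, b))
    return G.sum(axis=others)

def linearized_cost(G: np.ndarray, X: np.ndarray) -> float:
    """<G, C(X)> for C(X)[i_1..i_k] = sum_{a<b} ||X[a, i_a] - X[b, i_b]||^2; X has shape (k, n, d)."""
    total = 0.0
    for a, b in itertools.combinations(range(X.shape[0]), 2):
        D = ((X[a][:, None, :] - X[b][None, :, :]) ** 2).sum(axis=-1)
        total += float((pair_marginal(G, a, b) * D).sum())
    return total

def cost_vjp(G: np.ndarray, X: np.ndarray) -> np.ndarray:
    """Gradient of X -> <G, C(X)>, i.e. the VJP of the cost map applied to the tensor G."""
    grad = np.zeros_like(X)
    for a, b in itertools.combinations(range(X.shape[0]), 2):
        Gab = pair_marginal(G, a, b)  # (n, n)
        # d/dX[a, i] of sum_ij Gab[i, j] ||X[a, i] - X[b, j]||^2 = 2 sum_j Gab[i, j] (X[a, i] - X[b, j])
        grad[a] += 2.0 * (Gab.sum(axis=1)[:, None] * X[a] - Gab @ X[b])
        # symmetric term for view b, summing over the rows i of Gab
        grad[b] += 2.0 * (Gab.sum(axis=0)[:, None] * X[b] - Gab.T @ X[a])
    return grad
```

With `G = J - P`, an update such as `X -= lr * cost_vjp(J - P, X)` is, under these assumptions, one discrete step of the kind of gradient flow that displaces the points in Figure 1.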

Figures (6)

  • Figure 1: (left) Embeddings for $n=4$ points (identified using 4 colors), each given in $k=3$ views (differentiated using 3 shapes) in $d=2$ dimensions. The ground-truth polymatching of these points is known: to each color its 3 shapes, as illustrated with colored cliques, and described mathematically as a tensor $\mathscr{J}_{4,3}$. For their initial configuration in space, solving a multi-marginal optimal transport problem parameterized by the cost tensor $\mathscr{C}(\mathscr{X})$ yields a different polymatching, $\mathscr{P}_0(\mathscr{C}(\mathscr{X}))$. That difference, quantified as a difference in their matching objectives, defines the $\mathrm{M3G}$ loss (see Definition 3.1 for what $c$ and $\varepsilon$ refer to). A high $\mathrm{M3G}$ value, as on the left, indicates a large discrepancy between the cost of the ground-truth matching and that of the optimal polymatching. Minimizing the loss gradually displaces the points so that, ideally, upon convergence and after consecutive updates (middle and right plots), the ground-truth and optimal polymatchings coincide in their objective. For additional intuition, see the animation at https://mlrapp.github.io/m3g/, which presents the gradient flow of $\mathrm{M3G}$ on a toy problem.
  • Figure 2: Pairwise domain prediction accuracy on the DomainNet dataset, i.e., the accuracy on an unseen domain of a linear classifier trained on a single domain from the pre-training set. Each table presents the performance of a different model. From left to right: the baselines, which apply pairwise losses in a one vs. average-of-rest fashion, $\rm{BYOL}_{\rm{ave}}$ (left) and $\rm{InfoNCE}_{\rm{ave}}$ (center), and $\mathrm{M3G}$ (right). Columns correspond to the unseen domains and rows to the domains used to train the linear classifier. Reported values are means over four independent repetitions.
  • Figure 3: Linear performance on ImageNet-1k as a function of the entropic regularizer $\varepsilon$. We report the linear top-1 accuracy for different values of the MM-OT entropic regularizer $\varepsilon$ and several view multiplicities $k$. All results are given for the same batch size ($n=64$) and training duration ($300$ epochs). Solid lines and bands depict the mean and $95\%$ confidence interval over five independent repetitions.
  • Figure 4: Robustness to cost function. Classification performance of $\mathrm{M3G}$ models pre-trained on ImageNet-1k, using either $c_{\rm{cv}}$ (Eq. \ref{eq:cv}) or $c_{\rm{csd}}$ (Eq. \ref{eq:csd}), with $\varepsilon=0.2$.
  • Figure 5: Compute overhead incurred by $\mathrm{M3G}$ on ImageNet-1k as a function of $k$. All results are given for the same per-GPU batch size ($n=64$), $300$ epochs, and $\varepsilon=0.2$ for $\mathrm{M3G}$, run on 4 nodes of 8 A100 GPUs each.

Theorems & Definitions (2)

  • Definition 3.1
  • Proposition 3.2