Contrasting Multiple Representations with the Multi-Marginal Matching Gap
Zoe Piran, Michal Klein, James Thornton, Marco Cuturi
TL;DR
The paper tackles the challenge of learning coherent representations from more than two views by introducing the multi-marginal matching gap (M3G), a holistic, entropy-regularized multi-marginal optimal transport (MM-OT) loss that contrasts ground-truth $k$-view polymatchings with the MM-OT optimum. Central to M3G is constructing a ground-truth polymatching tensor $\mathscr{J}_{n,k}$ and a multiway cost tensor $\mathscr{M}_c(\mathscr{X})$ from $k$-view embeddings, then measuring their optimality gap via the multi-marginal Sinkhorn (MM-Sinkhorn) algorithm with regularization $\varepsilon$. The framework is complemented by a gradient derivation through Fenchel-Young losses, enabling end-to-end learning with a single MM-Sinkhorn forward pass. Empirically, M3G yields improvements over pairwise extensions in ImageNet-1k SSL, DomainNet domain adaptation, and EEG multichannel data, while incurring tractable overhead at moderate $n$ and $k$. Overall, M3G establishes a first holistic approach to jointly aligning multiple views, one that remains practical at moderate batch sizes and view counts for multimodal and multiview tasks.
Abstract
Learning meaningful representations of complex objects that can be seen through multiple ($k\geq 3$) views or modalities is a core task in machine learning. Existing methods use losses originally intended for paired views, and extend them to $k$ views, either by instantiating $\tfrac12k(k-1)$ loss-pairs, or by using reduced embeddings, following a \textit{one vs. average-of-rest} strategy. We propose the multi-marginal matching gap (M3G), a loss that borrows tools from multi-marginal optimal transport (MM-OT) theory to simultaneously incorporate all $k$ views. Given a batch of $n$ points, each seen as a $k$-tuple of views subsequently transformed into $k$ embeddings, our loss contrasts the cost of matching these $n$ ground-truth $k$-tuples with the MM-OT polymatching cost, which seeks $n$ optimally arranged $k$-tuples chosen within these $n\times k$ vectors. While the exponential complexity $O(n^k)$ of the MM-OT problem may seem daunting, we show in experiments that a suitable generalization of the Sinkhorn algorithm for that problem can scale to, e.g., $k=3\sim 6$ views using mini-batches of size $64\sim 128$. Our experiments demonstrate improved performance over multiview extensions of pairwise losses, for both self-supervised and multimodal tasks.
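To make the loss described above concrete, here is a minimal NumPy sketch, not the authors' implementation. It rests on several assumptions that the abstract does not specify: the cost of a $k$-tuple is taken to be the sum of pairwise squared Euclidean distances between its embeddings, all marginals are uniform, and the same entropy-regularized objective is evaluated at both the ground-truth polymatching and the Sinkhorn plan so that the gap is nonnegative. The function names (`cost_tensor`, `mm_sinkhorn`, `matching_gap`) are hypothetical. The dense $n^k$ tensor limits this sketch to small $n$ and $k$.

```python
import itertools
import numpy as np

def cost_tensor(views):
    """views: array (k, n, d). Returns an n^k tensor whose entry (i_1..i_k)
    sums pairwise squared distances (an assumed cost, not from the paper)."""
    k, n, _ = views.shape
    C = np.zeros((n,) * k)
    for a, b in itertools.combinations(range(k), 2):
        # (n, n) squared Euclidean distances between view a and view b
        D = ((views[a][:, None, :] - views[b][None, :, :]) ** 2).sum(-1)
        shape = [1] * k
        shape[a], shape[b] = n, n
        C = C + D.reshape(shape)  # broadcast over the other k-2 axes
    return C

def logsumexp(x, axis):
    m = x.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(x - m).sum(axis=axis))

def mm_sinkhorn(C, eps, n_iters=500):
    """Log-domain multi-marginal Sinkhorn with uniform marginals.
    Returns the log of the coupling tensor."""
    k, n = C.ndim, C.shape[0]
    f = np.zeros((k, n))  # one dual potential vector per view

    def log_plan(f):
        S = -C / eps
        for j in range(k):
            shape = [1] * k
            shape[j] = n
            S = S + f[j].reshape(shape) / eps
        return S

    for _ in range(n_iters):
        for j in range(k):
            others = tuple(a for a in range(k) if a != j)
            log_marginal = logsumexp(log_plan(f), axis=others)
            # rescale potential j so that marginal j becomes uniform (1/n)
            f[j] += eps * (-np.log(n) - log_marginal)
    return log_plan(f)

def matching_gap(views, eps=1.0, n_iters=500):
    """Regularized cost of the ground-truth tuples minus the MM-OT value."""
    k, n, _ = views.shape
    C = cost_tensor(views)
    log_P = mm_sinkhorn(C, eps, n_iters)
    P = np.exp(log_P)
    # <P, C> - eps * H(P), with H(P) = -sum P log P
    ot_value = (P * C).sum() + eps * (P * log_P).sum()
    # ground-truth polymatching: mass 1/n on each tuple (i, i, ..., i)
    diag = C[(np.arange(n),) * k]
    gt_value = diag.mean() - eps * np.log(n)  # entropy of that plan is log n
    return gt_value - ot_value

# Toy batch: k=3 noisy views of n=5 points, embeddings normalized to unit norm.
rng = np.random.default_rng(0)
base = rng.normal(size=(5, 2))
views = base[None] + 0.1 * rng.normal(size=(3, 5, 2))
views = views / np.linalg.norm(views, axis=-1, keepdims=True)
print("gap, aligned views:", matching_gap(views))
```

In this sketch, permuting the rows of one view leaves the MM-OT value unchanged but raises the ground-truth cost, so the gap grows when views are misaligned. That is the signal a contrastive loss of this kind would exploit.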
