An Empirical Study of Self-supervised Learning with Wasserstein Distance

Makoto Yamada, Yuki Takezawa, Guillaume Houry, Kira Michaela Dusterwald, Deborah Sulem, Han Zhao, Yao-Hung Hubert Tsai

TL;DR

This work investigates self-supervised learning using the 1-Wasserstein distance under tree metrics (Tree-Wasserstein Distance, $W_{\mathcal{T}}$). It systematically evaluates the interaction between two tree structures (Total Variation and ClusterTree) and several probability representations (Softmax, ArcFace, Simplicial Embedding), introducing a Jeffrey-divergence regularizer to stabilize optimization. Key findings show that naive Softmax with $W_{\mathcal{T}}$ performs poorly, while ArcFace with DCT/PE and Simplicial Embedding under appropriate trees yields competitive or superior results, with the Jeffrey divergence regularization providing substantial stability gains. The study demonstrates that properly combined $W_{\mathcal{T}}$-based SSL can outperform cosine-based representations on several benchmarks, though care is needed for large-class settings. Practical implications include guidance on selecting tree structures and probability models to enable stable and effective Wasserstein-based SSL.
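
The three probability models named above each map an encoder output to a probability vector that a tree-Wasserstein distance can compare. The following is a minimal NumPy sketch. The softmax is standard. `cosine_softmax` is an ArcFace-style variant that we assume uses scaled cosine logits against a prototype matrix `W`, without the angular margin of supervised ArcFace; here `W` is just a random matrix, and how the DCT/PE variants construct it is not reproduced. `simplicial_embedding` follows the usual block-wise softmax construction, and dividing by the number of blocks so that the whole vector sums to one is our assumption. All function names and the `scale` and `tau` defaults are hypothetical.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - np.max(z, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)


def cosine_softmax(z, W, scale=16.0):
    """ArcFace-style probabilities (assumption): softmax over scaled cosines
    between L2-normalized features z (n, d) and L2-normalized columns of
    W (d, K). The supervised ArcFace angular margin is omitted."""
    zn = z / np.linalg.norm(z, axis=1, keepdims=True)
    Wn = W / np.linalg.norm(W, axis=0, keepdims=True)
    return softmax(scale * (zn @ Wn), axis=1)


def simplicial_embedding(z, num_groups, tau=1.0):
    """Split each d-dim feature into num_groups equal blocks, apply a
    temperature-scaled softmax per block, then rescale by 1/num_groups so
    the full vector is one distribution over d leaves (rescaling assumed)."""
    n, d = z.shape
    if d % num_groups != 0:
        raise ValueError("feature dimension must be divisible by num_groups")
    blocks = softmax(z.reshape(n, num_groups, d // num_groups) / tau, axis=-1)
    return blocks.reshape(n, d) / num_groups


# Toy usage: 4 feature vectors of dimension 12.
rng = np.random.default_rng(0)
z = rng.normal(size=(4, 12))
p_soft = softmax(z, axis=1)                          # (4, 12)
p_arc = cosine_softmax(z, rng.normal(size=(12, 6)))  # (4, 6), 6 prototypes
p_sem = simplicial_embedding(z, num_groups=3)        # (4, 12), 3 blocks of 4
```

Each block of the simplicial embedding carries mass 1/3 here, so the embedding constrains how probability can spread across the leaves, unlike a single softmax over all 12 coordinates.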

Abstract

In this study, we investigate self-supervised learning (SSL) with the 1-Wasserstein distance on a tree structure (a.k.a. the Tree-Wasserstein distance (TWD)), where TWD is defined as the L1 distance between two tree-embedded vectors. SSL methods often use cosine similarity as the objective function; however, using the Wasserstein distance instead has not been well studied. Training with the Wasserstein distance is numerically challenging. Thus, this study empirically investigates strategies for optimizing SSL with the Wasserstein distance and identifies a stable training procedure. More specifically, we evaluate combinations of two types of TWD (total variation and ClusterTree) with several probability models, including the softmax function, the ArcFace probability model, and simplicial embedding. We propose a simple yet effective Jeffrey divergence-based regularization method to stabilize optimization. Through experiments on STL10, CIFAR10, CIFAR100, and SVHN, we find that a simple combination of the softmax function and TWD yields significantly lower performance than standard SimCLR. Moreover, a simple combination of TWD and SimSiam fails to train the model. We find that model performance depends on the combination of TWD and probability model, and that the Jeffrey divergence regularization helps model training. Finally, we show that an appropriate combination of TWD and probability model outperforms cosine similarity-based representation learning.
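
The abstract defines TWD as the L1 distance between tree-embedded vectors. A minimal NumPy sketch of that idea, assuming the common matrix form $W_{\mathcal{T}}(\boldsymbol{a}, \boldsymbol{a}') = \|\mathrm{diag}(\boldsymbol{w})\boldsymbol{B}(\boldsymbol{a}-\boldsymbol{a}')\|_1$, where $\boldsymbol{B}$ is a 0/1 node-by-leaf matrix ($B_{ij} = 1$ if leaf $j$ lies below node $i$) and $w_i$ is the weight of the edge above node $i$. The loss is a SimCLR-style InfoNCE in which $-W_{\mathcal{T}}/\tau$ replaces the cosine logit, plus a Jeffrey divergence $\mathrm{KL}(p\|q)+\mathrm{KL}(q\|p)$ penalty on positive pairs weighted by `lam`. This way of combining the terms, the use of cross-view negatives only, the toy tree, and all names and defaults are our assumptions, not the authors' code.

```python
import numpy as np


def twd(p, q, B, w):
    """Row-wise TWD ||diag(w) B (p - q)||_1 for p, q of shape (n, n_leaf)."""
    return np.abs((p - q) @ B.T) @ w


def pairwise_twd(p, q, B, w):
    """(n, n) matrix whose entry [i, k] is TWD(p[i], q[k])."""
    ep, eq = p @ B.T, q @ B.T  # tree-embedded vectors, shape (n, n_node)
    return np.abs(ep[:, None, :] - eq[None, :, :]) @ w


def jeffrey(p, q, eps=1e-12):
    """Jeffrey divergence KL(p||q) + KL(q||p) = sum (p - q)(log p - log q)."""
    p, q = np.clip(p, eps, None), np.clip(q, eps, None)
    return np.sum((p - q) * (np.log(p) - np.log(q)), axis=-1)


def twd_infonce_jd(p, q, B, w, tau=0.1, lam=1.0):
    """InfoNCE with -TWD / tau as logits (positive pairs on the diagonal,
    cross-view negatives only) plus lam * mean Jeffrey divergence of the
    positive pairs. Hypothetical combination for illustration."""
    logits = -pairwise_twd(p, q, B, w) / tau
    logits = logits - logits.max(axis=1, keepdims=True)  # stabilize exp
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_prob)) + lam * np.mean(jeffrey(p, q))


# Hypothetical toy tree with 3 leaves: root -> {leaf 0, u}, u -> {leaf 1, leaf 2}.
# Rows of B: root, u, leaf 0, leaf 1, leaf 2.
B = np.array([[1, 1, 1], [0, 1, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float)
w = np.array([0.0, 1.0, 0.5, 0.5, 0.5])  # the root has no parent edge
rng = np.random.default_rng(1)
p = rng.dirichlet(np.ones(3), size=5)    # view 1 probabilities, 5 samples
q = rng.dirichlet(np.ones(3), size=5)    # view 2 probabilities
loss = twd_infonce_jd(p, q, B, w)
```

Because every term of the Jeffrey divergence is non-negative and vanishes only when the two views agree, the penalty pulls positive pairs together directly, complementing the contrastive TWD term.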

Paper Structure (22 sections, 2 theorems, 28 equations, 4 figures, 5 tables)

Key Result

Proposition 1

The robust variant of TWD (RTWD) is equivalent to total variation:

$$\mathrm{RTWD}(\boldsymbol{a}, \boldsymbol{a}') = \|\boldsymbol{a}-\boldsymbol{a}'\|_{\textnormal{TV}},$$

where $\|\boldsymbol{a}-\boldsymbol{a}'\|_{\textnormal{TV}} = \frac{1}{2}\|\boldsymbol{a}-\boldsymbol{a}'\|_{1}$ denotes the total variation.
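
For probability vectors with the same total mass, the total variation $\frac{1}{2}\|\boldsymbol{a}-\boldsymbol{a}'\|_1$ also equals the largest gap $\max_S |\boldsymbol{a}(S) - \boldsymbol{a}'(S)|$ over subsets $S$ of the support, which is why the factor $\frac{1}{2}$ appears. The following brute-force NumPy sketch checks that standard identity numerically; the function names are hypothetical, and the enumeration is exponential, so it is meant only for tiny supports.

```python
import itertools

import numpy as np


def tv_l1(a, b):
    """Total variation as half the L1 distance."""
    return 0.5 * float(np.abs(a - b).sum())


def tv_max_event(a, b):
    """Total variation as max over all subsets S of |a(S) - b(S)|
    (brute-force enumeration of every subset of the support)."""
    n = len(a)
    best = 0.0
    for r in range(n + 1):
        for S in itertools.combinations(range(n), r):
            idx = list(S)
            best = max(best, abs(float(a[idx].sum() - b[idx].sum())))
    return best


a = np.array([0.5, 0.3, 0.2, 0.0])
b = np.array([0.1, 0.3, 0.2, 0.4])
print(tv_l1(a, b), tv_max_event(a, b))
```

The maximizing subset is the set of coordinates where $\boldsymbol{a}$ exceeds $\boldsymbol{a}'$; since both vectors sum to one, its gap equals half of the L1 distance.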

Figures (4)

  • Figure 1: The left tree corresponds to the total variation if all weights are set to $w_i = \frac{1}{2}, \forall i$. The right tree is a ClusterTree (2 classes).
  • Figure 2: Tree for the sliced Wasserstein distance with $N_{\text{leaf}} = 3$. The left figure is a chain, and the right figure is the tree representation of the chain with internal nodes ($w_4 = w_5 = w_6 = 0$).
  • Figure 3: InfoNCE loss and Top-1 (Train) comparisons on the STL10 dataset.
  • Figure 4: TWD loss for SimSiam models.
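
Both trees in Figure 1 can be written in the tree-embedding form $W_{\mathcal{T}}(\boldsymbol{a}, \boldsymbol{a}') = \|\mathrm{diag}(\boldsymbol{w})\boldsymbol{B}(\boldsymbol{a}-\boldsymbol{a}')\|_1$, where row $i$ of the 0/1 matrix $\boldsymbol{B}$ marks the leaves below node $i$ and $w_i$ is the weight of the edge above node $i$. This matrix form is our assumption about how "L1 distance between tree-embedded vectors" is realized. A minimal NumPy sketch builds a star tree with $w_i = \frac{1}{2}$ (the total-variation tree) and a 2-class ClusterTree; the 4-leaf size, the leaf-to-cluster assignment, the ClusterTree weights of $\frac{1}{2}$, and the helper names are hypothetical.

```python
import numpy as np


def twd(a, b, B, w):
    """W_T(a, b) = || diag(w) B (a - b) ||_1 for 1-D probability vectors."""
    return float(np.sum(w * np.abs(B @ (a - b))))


def star_tree(n_leaf, weight=0.5):
    """Root plus one edge per leaf (Figure 1, left). Rows: root, leaves."""
    B = np.vstack([np.ones((1, n_leaf)), np.eye(n_leaf)])
    w = np.full(n_leaf + 1, weight)
    return B, w


def cluster_tree(clusters, weight=0.5):
    """Root -> one node per cluster -> leaves (Figure 1, right).
    `clusters` is a hypothetical list of leaf-index lists."""
    n_leaf = sum(len(c) for c in clusters)
    cluster_rows = np.zeros((len(clusters), n_leaf))
    for k, leaves in enumerate(clusters):
        cluster_rows[k, leaves] = 1.0
    B = np.vstack([np.ones((1, n_leaf)), cluster_rows, np.eye(n_leaf)])
    w = np.full(B.shape[0], weight)
    return B, w


e = np.eye(4)  # point masses on each of the 4 leaves
B_tv, w_tv = star_tree(4)
B_ct, w_ct = cluster_tree([[0, 1], [2, 3]])
print(twd(e[0], e[1], B_tv, w_tv), twd(e[0], e[2], B_tv, w_tv))  # 1.0 1.0
print(twd(e[0], e[1], B_ct, w_ct), twd(e[0], e[2], B_ct, w_ct))  # 1.0 2.0
```

Under these toy weights, moving all mass between two leaves of the same cluster costs the same as under the total-variation tree, while moving it across clusters costs twice as much; that extra cost is the class structure a ClusterTree encodes.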

Theorems & Definitions (4)

  • Proposition 1
  • Proposition 2
  • Proof
  • Proof