Improving OOD Generalization of Pre-trained Encoders via Aligned Embedding-Space Ensembles

Shuman Peng, Arash Khoeini, Sharan Vaswani, Martin Ester

TL;DR

This work presents a theoretical analysis of the relationship between the individual hyperspherical embedding spaces in an ensemble, then designs a principled, unsupervised method to align these spaces. The aligned ensemble improves pre-trained embedding quality on both in-distribution and OOD data compared to single encoders.

Abstract

The quality of self-supervised pre-trained embeddings on out-of-distribution (OOD) data is poor without fine-tuning. A straightforward approach to improving the generalization of pre-trained representations to OOD data is the use of deep ensembles. However, obtaining an effective ensemble in the embedding space with only unlabeled data remains an unsolved problem. We first perform a theoretical analysis that reveals the relationship between individual hyperspherical embedding spaces in an ensemble. We then design a principled method to align these embedding spaces in an unsupervised manner. Experimental results on the MNIST dataset show that our embedding-space ensemble method improves pre-trained embedding quality on in-distribution and OOD data compared to single encoders.

Paper Structure

This paper contains 39 sections, 2 theorems, 6 equations, 8 figures, 2 tables.

Key Result

Proposition 1

Under the above assumption, $f_1$ and $f_2$ learn the same latents up to an orthogonal transformation $R$, that is, $f_1(x) = R f_2(x)$.
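
To make the relationship in Proposition 1 concrete, here is a minimal sketch that builds two synthetic hyperspherical embedding sets related by an orthogonal matrix $R$, recovers the rotation from paired embeddings of the same unlabeled inputs, and averages the aligned embeddings. It assumes orthogonal Procrustes (solved by SVD) as the alignment step, which is a standard technique and not necessarily the authors' procedure. The helper names `normalize` and `procrustes_rotation`, the random data, and the sizes `n` and `d` are all hypothetical.

```python
import numpy as np

def normalize(z):
    """Project each row onto the unit hypersphere."""
    return z / np.linalg.norm(z, axis=1, keepdims=True)

def procrustes_rotation(source, target):
    """Orthogonal Q minimizing ||source @ Q - target||_F (orthogonal Procrustes)."""
    u, _, vt = np.linalg.svd(source.T @ target)
    return u @ vt

rng = np.random.default_rng(0)
n, d = 200, 8  # hypothetical number of inputs and embedding dimension

# Synthetic stand-in for encoder 2: rows are f_2(x_i) on the unit sphere.
z2 = normalize(rng.normal(size=(n, d)))

# Random orthogonal R (assumption for illustration), so that f_1(x) = R f_2(x).
R, _ = np.linalg.qr(rng.normal(size=(d, d)))
z1 = z2 @ R.T  # row-wise version of f_1(x) = R f_2(x)

# Recover the rotation from paired embeddings only; no labels are used.
Q = procrustes_rotation(z2, z1)  # z2 @ Q should match z1, so Q should equal R.T

# Ensemble in the aligned space: average, then project back to the sphere.
ensemble = normalize((z1 + z2 @ Q) / 2.0)

print("max |Q - R^T|:", np.abs(Q - R.T).max())
print("max |ensemble - z1|:", np.abs(ensemble - z1).max())
```

In this noiseless setting the recovered `Q` matches $R^\top$ up to floating-point error, so the aligned average coincides with the reference embeddings. With real encoders the relationship would hold only approximately, and the average would differ from either individual embedding.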

Figures (8)

  • Figure 1: Need for embedding alignment: The ensemble mean of two different embeddings (yellow, blue) in misaligned embedding spaces $Z_1, Z_2$ collapses to the same vector in $\bar{Z}$, although they have different semantic meanings.
  • Figure 2: Comparing embedding qualities of single models (blue), an ensemble of unaligned embedding spaces (orange), and an ensemble of aligned embedding spaces (purple) in the ID and OOD settings. Recall@1 and MAP@R are presented. Higher values indicate better performance. The mean and standard deviation (error bars) of the performance metrics are reported for the 5 single models. The ensembles do not have standard deviation since all 5 models are combined into one.
  • Figure 3: For in-distribution (ID) evaluation, images like those in (a) were used. For out-of-distribution (OOD) evaluation, images like those in (b) were used.
  • Figure 4: Supervised contrastive pre-training with Colored MNIST as OOD evaluation data. Comparing embedding qualities of single models (blue), an ensemble of unaligned embedding spaces (orange), and an ensemble of aligned embedding spaces (purple) in the ID and OOD settings. Recall@1 and MAP@R are presented. Higher values indicate better performance. The mean and standard deviation (error bars) of the performance metrics are reported for the 5 single models. The ensembles do not have standard deviation since all 5 models are combined into one.
  • Figure 5: Supervised contrastive pre-training with Cropped MNIST as OOD evaluation data. Comparing embedding qualities of single models (blue), an ensemble of unaligned embedding spaces (orange), and an ensemble of aligned embedding spaces (purple) in the ID and OOD settings. Recall@1 and MAP@R are presented. Higher values indicate better performance. The mean and standard deviation (error bars) of the performance metrics are reported for the 5 single models. The ensembles do not have standard deviation since all 5 models are combined into one.
  • ...and 3 more figures

Theorems & Definitions (4)

  • Proposition 1: Orthogonal transformation relationship
  • Proposition 2: Ensemble recovers correct latents
  • proof
  • proof