Foundation Model's Embedded Representations May Detect Distribution Shift
Max Vargas, Adam Tsou, Andrew Engel, Tony Chiang
TL;DR
This work addresses distribution shift between Sentiment140's automatically labeled training set $P$ and its manually labeled test set $M$ in the context of transfer learning (TL) with foundation-model representations. It introduces a PCA-based, data-centric method for detecting shifts in final-layer embeddings and compares two training regimes, full fine-tuning on $P$ and linear probing on $M$, to assess generalization. The key findings are that many foundation models' embeddings separate $P$ from $M$, that fine-tuning on $P$ can degrade performance on $M$, and that linear probes on pre-trained features, trained on $M$, offer robust and data-efficient generalization. The study underscores the need to match train and test populations, advocates cautious pre-processing before TL, and suggests avenues for quantitative, architecture-aware analysis of distribution shift.
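To illustrate the general idea behind PCA-based shift detection (not the authors' exact procedure, which this summary does not specify), the sketch below projects two sets of embeddings onto the first principal component of their union and reports a standardized gap between their means. It assumes the final-layer embeddings for $P$ and $M$ have already been extracted as NumPy arrays. The helper `pc1_separation` and the synthetic Gaussian data are hypothetical stand-ins, not real Sentiment140 embeddings.

```python
import numpy as np

def top_principal_axes(X, k):
    """Top-k principal axes (as rows) of X, via SVD of the mean-centered data."""
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Vt[:k]

def pc1_separation(emb_p, emb_m):
    """Hypothetical diagnostic: standardized gap between the two sets' means
    along the first principal component of their pooled embeddings."""
    X = np.vstack([emb_p, emb_m])
    axis = top_principal_axes(X, k=1)[0]
    center = X.mean(axis=0)
    zp = (emb_p - center) @ axis  # 1-D PC1 scores for the P embeddings
    zm = (emb_m - center) @ axis  # 1-D PC1 scores for the M embeddings
    pooled_std = np.sqrt(0.5 * (zp.var(ddof=1) + zm.var(ddof=1)))
    # abs() because the sign of a principal axis is arbitrary
    return abs(zp.mean() - zm.mean()) / pooled_std

# Synthetic stand-ins for final-layer embeddings (NOT real Sentiment140 data).
rng = np.random.default_rng(0)
dim = 16
emb_p = rng.normal(size=(500, dim))
emb_m_shifted = rng.normal(size=(500, dim))
emb_m_shifted[:, 0] += 3.0                 # "M" drawn from a shifted distribution
emb_m_same = rng.normal(size=(500, dim))   # "M" drawn from the same distribution as P

score_shifted = pc1_separation(emb_p, emb_m_shifted)
score_same = pc1_separation(emb_p, emb_m_same)
print(f"shifted: {score_shifted:.2f}, unshifted: {score_same:.2f}")
```

A large score only indicates that the two sets differ along the dominant direction of variation in embedding space; on its own it does not say whether that difference matters for the downstream sentiment task.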
Abstract
Sampling biases can cause distribution shifts between train and test datasets in supervised learning tasks, obscuring our ability to understand a model's capacity to generalize. This is especially important given the wide adoption of pre-trained foundation models, whose behavior remains poorly understood, for transfer learning (TL) tasks. We present a case study of TL on the Sentiment140 dataset and show that many pre-trained foundation models encode Sentiment140's manually curated test set $M$ differently from its automatically labeled training set $P$, confirming that a distribution shift has occurred. We argue that training on $P$ and measuring performance on $M$ yields a biased measure of generalization. Experiments on pre-trained GPT-2 show that the features learnable from $P$ do not improve (and in fact hamper) performance on $M$. Linear probes on pre-trained GPT-2's representations are robust and may even outperform full fine-tuning, implying that detecting distribution shift in train/test splits is fundamentally important for model interpretation.
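To make the linear-probing regime concrete, here is a minimal sketch of a linear probe: a logistic-regression classifier trained on frozen features while the underlying model stays fixed. It assumes the features (for example, pre-trained GPT-2 final-layer embeddings) have already been computed; here they are replaced by synthetic Gaussian vectors with labels from a hypothetical linear rule, and `train_linear_probe` is an illustrative helper, not the authors' code.

```python
import numpy as np

def train_linear_probe(features, labels, lr=0.1, epochs=500, l2=1e-3):
    """Logistic regression on frozen features, fit by full-batch gradient descent.
    Only w and b are learned; the features (the backbone's output) never change."""
    n, d = features.shape
    w = np.zeros(d)
    b = 0.0
    for _ in range(epochs):
        probs = 1.0 / (1.0 + np.exp(-(features @ w + b)))
        err = probs - labels                       # dLoss/dlogit for binary cross-entropy
        w -= lr * (features.T @ err / n + l2 * w)  # mean gradient plus L2 penalty
        b -= lr * err.mean()
    return w, b

def probe_accuracy(w, b, features, labels):
    """Fraction of examples whose thresholded logit matches the 0/1 label."""
    preds = (features @ w + b > 0).astype(float)
    return float((preds == labels).mean())

# Synthetic stand-ins for frozen embeddings and sentiment labels (hypothetical).
rng = np.random.default_rng(1)
dim = 32
true_direction = rng.normal(size=dim)
X_train = rng.normal(size=(800, dim))
y_train = (X_train @ true_direction > 0).astype(float)
X_test = rng.normal(size=(400, dim))
y_test = (X_test @ true_direction > 0).astype(float)

w, b = train_linear_probe(X_train, y_train)
test_acc = probe_accuracy(w, b, X_test, y_test)
print(f"held-out probe accuracy: {test_acc:.3f}")
```

Because only `w` and `b` are trained, a probe like this measures what the frozen representation already encodes, whereas full fine-tuning also updates the backbone's weights.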
