
Theoretical Analysis of Contrastive Learning under Imbalanced Data: From Training Dynamics to a Pruning Solution

Haixu Liao, Yating Zhou, Songyang Zhang, Meng Wang, Shuai Zhang

TL;DR

This work provides a theoretical framework for contrastive learning with Transformer-MLP encoders under imbalanced data, revealing a three-stage training dynamic where feature directions are learned in a staged manner and minority features are learned with smaller magnitudes. It then shows that magnitude-based pruning can mitigate imbalance by amplifying updates along minority-feature directions, increasing the number of neurons that specialize in minority features and improving representation quality. The authors validate their theory with numerical experiments on long-tailed CIFAR and ImageNet benchmarks as well as synthetic data, demonstrating improved downstream linear-probe performance and reduced head-tail gaps when pruning is employed. While the results are derived under a simplified architectural setting, they offer conceptual and practical guidance for improving contrastive SSL under real-world imbalance, and point to future work on extending the analysis to more complex models and alternative imbalance strategies.
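The pruning step mentioned above can be illustrated with a minimal sketch. It assumes that each row of a weight matrix `W` is one neuron's weight vector w_i and that pruning removes whole neurons ranked by their L2 norm. The function name `magnitude_prune_neurons` and the `keep_ratio` parameter are hypothetical, and the paper's exact criterion (for example, per-weight rather than per-neuron pruning) may differ.

```python
import numpy as np

def magnitude_prune_neurons(W, keep_ratio):
    """Hypothetical helper: keep the neurons (rows of W) with the largest
    L2 norms and zero out the rest.

    Assumption: row W[i] is neuron i's weight vector, and pruning is done
    per neuron by magnitude. Returns the pruned matrix and a boolean mask.
    """
    if not 0.0 < keep_ratio <= 1.0:
        raise ValueError("keep_ratio must be in (0, 1]")
    norms = np.linalg.norm(W, axis=1)                 # ||w_i||_2 per neuron
    n_keep = max(1, int(round(keep_ratio * W.shape[0])))
    keep_idx = np.argsort(norms)[-n_keep:]            # largest-norm neurons
    mask = np.zeros(W.shape[0], dtype=bool)
    mask[keep_idx] = True
    return W * mask[:, None], mask                    # zero pruned rows

# Usage: keep the 2 largest-norm neurons out of 4.
W = np.array([[3.0, 0.0], [0.1, 0.1], [0.0, 2.0], [0.2, 0.0]])
pruned, mask = magnitude_prune_neurons(W, keep_ratio=0.5)
```

In a prune-then-retrain setting one would presumably keep the mask fixed while training continues, so that only the surviving neurons receive updates; that schedule is an assumption, not something this summary specifies.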

Abstract

Contrastive learning has emerged as a powerful framework for learning generalizable representations, yet its theoretical understanding remains limited, particularly under imbalanced data distributions that are prevalent in real-world applications. Such an imbalance can degrade representation quality and induce biased model behavior, yet a rigorous characterization of these effects is lacking. In this work, we develop a theoretical framework to analyze the training dynamics of contrastive learning with Transformer-based encoders under imbalanced data. Our results reveal that neuron weights evolve through three distinct stages of training, with different dynamics for majority features, minority features, and noise. We further show that minority features reduce representational capacity, increase the need for more complex architectures, and hinder the separation of ground-truth features from noise. Inspired by these neuron-level behaviors, we show that pruning restores performance degraded by imbalance and enhances feature separation, offering both conceptual insights and practical guidance. Major theoretical findings are validated through numerical experiments.
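The abstract does not spell out the contrastive objective. A common choice is the InfoNCE loss, sketched below under that assumption: each anchor z_a[n] is paired with its positive z_b[n], and every other z_b[m] in the batch serves as a negative. The function name and the default `temperature` are illustrative, not taken from the paper.

```python
import numpy as np

def info_nce_loss(z_a, z_b, temperature=0.5):
    """Standard InfoNCE loss over a batch of positive pairs (z_a[n], z_b[n]).

    Assumption: this is a generic contrastive objective, not necessarily
    the exact loss analysed in the paper.
    """
    # L2-normalise so the logits are cosine similarities.
    z_a = z_a / np.linalg.norm(z_a, axis=1, keepdims=True)
    z_b = z_b / np.linalg.norm(z_b, axis=1, keepdims=True)
    logits = (z_a @ z_b.T) / temperature              # shape (N, N)
    # Numerically stable row-wise log-softmax.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Positive pairs lie on the diagonal; average their negative log-probability.
    return -np.mean(np.diag(log_probs))

# Usage: matched views give a lower loss than mismatched ones.
z = np.eye(4)
aligned = info_nce_loss(z, z)
shuffled = info_nce_loss(z, z[[1, 0, 3, 2]])
```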

Paper Structure (82 sections, 30 theorems, 438 equations, 6 figures, 2 tables, 1 algorithm)

Key Result

Lemma 3.1

During the first training stage, the update of neuron weights $\bm{w}_{i}^{(t)}$ can be bounded for all $t \in [0, T_1]$ as follows, where $C_z$ denotes positive constants and $T_1 = \Theta\!\left( \frac{d_1 \log d}{\eta \log \log d} \right)$.
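Because T_1 is stated only up to a Θ(·) constant, the expression is meaningful for comparing scalings rather than for absolute iteration counts. The sketch below evaluates d_1 log d / (η log log d) with the hidden constant set to 1 and natural logarithms assumed; η is presumably the learning rate, and d_1 and d are the quantities from the lemma. The function name is hypothetical.

```python
import math

def stage1_length_order(d1, d, eta):
    """Hypothetical helper: d1 * log(d) / (eta * log(log(d))), i.e. T_1 with
    the constant hidden by Theta set to 1. Only ratios are meaningful."""
    if d <= math.e:
        raise ValueError("need d > e so that log(log(d)) > 0")
    return d1 * math.log(d) / (eta * math.log(math.log(d)))

# Halving the step size eta doubles the order of T_1; so does doubling d1.
ratio_eta = stage1_length_order(100, 1000, 0.05) / stage1_length_order(100, 1000, 0.1)
ratio_d1 = stage1_length_order(200, 1000, 0.1) / stage1_length_order(100, 1000, 0.1)
```

The dependence on d is mild: it enters only through log d / log log d, whereas T_1 scales linearly in d_1 and inversely in η.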

Figures (6)

  • Figure 1: Neuron projection dynamics over training epochs. The blue curve shows the growth of a neuron’s projection onto its dominant feature, the orange curve shows the projection onto a non-dominant feature, and the green curve shows the projection onto the noise-space direction (which remains larger than the projections onto other features). In the first stage, the neuron grows mainly along feature directions while suppressing noise. In the second stage, the projection onto the dominant feature grows faster than the projections onto all other features, creating a clear separation. In the third stage, as training approaches $T_3$, the neuron converges, and its final representation is dominated by the learned feature, with negligible components in other directions.
  • Figure 2: Number of neurons with $\frac{|\langle w_i, M_j \rangle|}{\|w_i\|\|M_j\|} \geq 0.3$ vs $\varepsilon_{\min}$ for different NSR values.
  • Figure 3: Maximum $\frac{|\langle w_i, M_j \rangle|}{\|w_i\|\|M_j\|}$ vs $\varepsilon_{\min}$ for different NSR values.
  • Figure 4: $\frac{1}{N}\sum_{n=1}^{N}\frac{\langle f(X_n), f(Y_n)\rangle}{\|f(X_n)\|\|f(Y_n)\|}$ vs $\varepsilon_{\min}$ for different NSR values.
  • Figure 5: Downstream regression task: Test MSE vs $\varepsilon_{\min}$ for different NSR values.
  • ...and 1 more figure
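The quantities plotted in Figures 2 to 4 can be computed as in the sketch below. It assumes that neurons w_i are the rows of a matrix `W`, that the features M_j are the rows of a matrix `M`, and that the Figure 2 count is taken per feature; `FX` and `FY` stand for stacked encoder outputs f(X_n) and f(Y_n). All function names are hypothetical.

```python
import numpy as np

def alignment_matrix(W, M):
    """|<w_i, M_j>| / (||w_i|| ||M_j||) for neuron rows W[i] and feature rows M[j]."""
    W_unit = W / np.linalg.norm(W, axis=1, keepdims=True)
    M_unit = M / np.linalg.norm(M, axis=1, keepdims=True)
    return np.abs(W_unit @ M_unit.T)          # shape (num_neurons, num_features)

def neurons_aligned_count(W, M, threshold=0.3):
    """Figure 2 quantity: per feature j, number of neurons with alignment >= threshold."""
    return (alignment_matrix(W, M) >= threshold).sum(axis=0)

def max_alignment(W, M):
    """Figure 3 quantity: per feature j, the largest alignment over all neurons."""
    return alignment_matrix(W, M).max(axis=0)

def mean_positive_cosine(FX, FY):
    """Figure 4 quantity: mean cosine similarity between f(X_n) and f(Y_n)."""
    num = np.sum(FX * FY, axis=1)
    den = np.linalg.norm(FX, axis=1) * np.linalg.norm(FY, axis=1)
    return float(np.mean(num / den))

# Usage: three neurons, two orthogonal features.
W = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
M = np.eye(2)
counts = neurons_aligned_count(W, M)          # neuron 1 (cosine ~0.71) counts for both features
```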

Theorems & Definitions (77)

  • Definition 3.1: Majority and minority features
  • Lemma 3.1: Stage 1
  • Lemma 3.2: Stage 2
  • Theorem 3.1: Stage 3: Convergence
  • Theorem 3.2: Pruning: Reinforcing Minority Feature Learning
  • Lemma B.1: Approximation of empirical gradients by population gradients
  • Definition B.1: Characterization of Neurons
  • Lemma B.2
  • Theorem C.1: Initial feature decoupling
  • Lemma C.1
  • ...and 67 more