An Analytical Model for Overparameterized Learning Under Class Imbalance

Eliav Mor, Yair Carmon

TL;DR

A tight, closed-form approximation is developed for the test error of several practical learning methods, including logit adjustment and class-dependent temperature, in a high-dimensional Gaussian mixture model.

Abstract

We study class-imbalanced linear classification in a high-dimensional Gaussian mixture model. We develop a tight, closed-form approximation for the test error of several practical learning methods, including logit adjustment and class-dependent temperature. Our approximation allows us to analytically tune and compare these methods, highlighting how and when they overcome the pitfalls of standard cross-entropy minimization. We test our theoretical findings on simulated data and on imbalanced versions of the CIFAR10, MNIST and FashionMNIST datasets.
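To make the two named methods concrete, here is a minimal NumPy sketch of the losses as they are commonly defined in the class-imbalance literature. The parameterization is an assumption and may differ from the one analyzed in the paper: logit adjustment adds `tau * log(prior)` to the logits, and class-dependent temperature (CDT) divides each logit by `(N_max / N_j) ** gamma`. The function names and the hyperparameter names `tau` and `gamma` are illustrative.

```python
import numpy as np

def log_softmax(z):
    # Numerically stable log-softmax over the class axis.
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def cross_entropy(logits, y):
    # Mean negative log-likelihood of the true labels.
    return -log_softmax(logits)[np.arange(len(y)), y].mean()

def logit_adjusted_loss(logits, y, class_counts, tau=1.0):
    # Assumed form: shift each logit by tau * log(empirical class prior).
    counts = np.asarray(class_counts, dtype=float)
    prior = counts / counts.sum()
    return cross_entropy(logits + tau * np.log(prior), y)

def cdt_loss(logits, y, class_counts, gamma=0.5):
    # Assumed form: divide each logit by a per-class temperature
    # (N_max / N_j) ** gamma, so rarer classes get larger temperatures.
    counts = np.asarray(class_counts, dtype=float)
    temps = (counts.max() / counts) ** gamma
    return cross_entropy(logits / temps, y)
```

Under this parameterization, setting `tau=0` or `gamma=0` recovers standard cross-entropy, which gives a convenient baseline when sweeping the loss tuning hyperparameter.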

Paper Structure

This paper contains 91 sections, 34 theorems, 297 equations, 18 figures.

Key Result

Theorem 4.1

Let $\tilde{W}^{\rm{ma}}, \tilde{b}^{\rm{ma}}$ be the expected kernel approximation for the MA predictor (opt:ma-constraint) with any $\delta$ that satisfies $\delta_i > 0$ for all $i \in [c]$. Then, for all $y\in[c]$ we have $\tilde{w}^{\rm{ma}}_y = \sum_{i=1}^c \tilde{\alpha}_{y[i]}^{\rm{ma}} \bar{x}\,\ldots$ using $\xi_i \coloneqq \|\mu_{i}\|^2 + \frac{\sigma^2 d}{N_{i}}$ and $M\coloneqq \sum_{i=1}^{c} \ldots$
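The quantity $\xi_i = \|\mu_i\|^2 + \sigma^2 d / N_i$ has a simple reading under one assumption that the excerpt above does not state explicitly: if class $i$ is drawn from $\mathcal{N}(\mu_i, \sigma^2 I_d)$, then $\xi_i$ equals the expected squared norm of the mean of $N_i$ samples from that class. The sketch below computes $\xi_i$ and checks this reading by simulation; the function names and the isotropic-covariance assumption are illustrative and not taken from the paper.

```python
import numpy as np

def xi(mus, sigma, class_counts):
    # xi_i = ||mu_i||^2 + sigma^2 * d / N_i, one value per class.
    mus = np.asarray(mus, dtype=float)
    counts = np.asarray(class_counts, dtype=float)
    d = mus.shape[1]
    return (mus ** 2).sum(axis=1) + sigma ** 2 * d / counts

def empirical_mean_sq_norm(mu, sigma, n, trials, rng):
    # Monte Carlo estimate of E||x_bar||^2, where x_bar is the mean of
    # n samples from N(mu, sigma^2 I) (assumed isotropic covariance).
    mu = np.asarray(mu, dtype=float)
    samples = mu + sigma * rng.normal(size=(trials, n, mu.size))
    x_bar = samples.mean(axis=1)
    return (x_bar ** 2).sum(axis=1).mean()
```

For small $N_i$ or large $d$, the noise term $\sigma^2 d / N_i$ can dominate $\|\mu_i\|^2$, which is one way to see why minority classes behave differently in the high-dimensional regime.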

Figures (18)

  • Figure 1: Illustration of our main findings. We plot the test error of the worst-performing class as a function of model dimension, for the different learning methods we consider. The shaded areas indicate empirical measurements and the solid lines show the predictions of our analytical approximation. Each panel shows a different set of model parameters; see the appendix for a detailed description.
  • Figure 2: Worst class error, balanced error, and macro $F_1$ score vs. loss tuning hyperparameter for the different methods we consider and linear classification on synthetic data from our model. The shaded region shows empirical results (two standard deviations over 5 random seeds), and the solid lines show our analytical approximation.
  • Figure 3: Empirical worst class test error vs. loss tuning hyperparameter for the different methods we consider and kernel classification on two imbalanced versions of CIFAR10. The markers show the minimum error for each instance.
  • Figure 4: Empirical worst class test error vs. loss tuning hyperparameter for the different methods we consider and neural network fine-tuning on class-imbalanced subsets of CIFAR10.
  • Figure 5: Per-class error vs. imbalance ratio for MM predictors. The $x$-axis shows the imbalance ratio, i.e., the ratio between the largest and smallest per-class sample sizes. The $y$-axis represents the test error, and we plot the test error of each class. The left side shows the test error of the maximum margin (MM) predictor without bias, while the right side shows the test error of the MM predictor with bias. The shaded regions show empirical results (two standard deviations over 5 random seeds) and the solid lines show our analytical approximation.
  • ...and 13 more figures
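The metrics named in the captions above (worst class error, balanced error, macro $F_1$, and imbalance ratio) can be computed from a confusion matrix and the per-class sample sizes. The sketch below uses the standard definitions, which is an assumption since the captions do not spell them out; the function names are illustrative, and rows of the confusion matrix are taken to be true classes and columns predicted classes.

```python
import numpy as np

def imbalance_ratio(class_counts):
    # Ratio between the largest and smallest per-class sample sizes.
    counts = np.asarray(class_counts, dtype=float)
    return counts.max() / counts.min()

def imbalance_metrics(confusion):
    # confusion[i, j] = number of class-i test points predicted as class j.
    # Assumes every class has at least one test point.
    C = np.asarray(confusion, dtype=float)
    tp = np.diag(C)
    fn = C.sum(axis=1) - tp          # class-i points predicted as another class
    fp = C.sum(axis=0) - tp          # other-class points predicted as class i
    per_class_error = fn / (tp + fn) # 1 - recall for each class
    f1 = 2 * tp / (2 * tp + fp + fn)
    return {
        "per_class_error": per_class_error,
        "worst_class_error": per_class_error.max(),
        "balanced_error": per_class_error.mean(),
        "macro_f1": f1.mean(),
    }
```

Worst class error and balanced error weight every class equally regardless of its size, which is why they are more informative than plain accuracy when the imbalance ratio is large.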

Theorems & Definitions (63)

  • Remark 2.1: Limitations of CDT
  • Theorem 4.1
  • Proposition 4.1
  • Remark 4.2
  • Proposition 4.2
  • Theorem 4.3
  • Proposition 4.3
  • Lemma A.1
  • proof
  • Lemma A.1
  • ...and 53 more