Bias in Motion: Theoretical Insights into the Dynamics of Bias in SGD Training

Anchit Jain, Rozhin Nobahari, Aristide Baratin, Stefano Sarao Mannelli

TL;DR

The paper tackles the problem of understanding how bias emerges and evolves during SGD training, particularly during transient, non-asymptotic phases. It introduces a teacher-mixture with Gaussian sub-populations and analyzes online SGD for a linear classifier in the high-dimensional limit, deriving a solvable system of ordinary differential equations that track a small set of order parameters. In the two-cluster case, the authors obtain explicit closed-form solutions for the order parameters, revealing three distinct learning phases and multiple timescales that drive spurious correlations and fairness-related bias, supported by extensive simulations on synthetic data and real datasets (e.g., CIFAR10, MNIST, CelebA). The work provides a unifying, theory-grounded view of bias generation that connects fairness and spurious-correlation phenomena and yields practical insights into how representation, variance, and learning rate shape transient bias. Overall, the results offer a principled framework to anticipate and mitigate bias during training and motivate dynamical approaches to fairness-aware learning in realistic, resource-constrained settings.

Abstract

Machine learning systems often acquire biases by leveraging undesired features in the data, impacting accuracy variably across different sub-populations. Current understanding of bias formation mostly focuses on the initial and final stages of learning, leaving a gap in knowledge regarding the transient dynamics. To address this gap, this paper explores the evolution of bias in a teacher-student setup modeling different data sub-populations with a Gaussian-mixture model. We provide an analytical description of the stochastic gradient descent dynamics of a linear classifier in this setting, which we prove to be exact in high dimension. Notably, our analysis reveals how different properties of sub-populations influence bias at different timescales, showing a shifting preference of the classifier during training. Applying our findings to fairness and robustness, we delineate how and when heterogeneous data and spurious features can generate and amplify bias. We empirically validate our results in more complex scenarios by training deeper networks on synthetic and real datasets, including CIFAR10, MNIST, and CelebA.
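The setup described above can be illustrated with a minimal simulation sketch. It is not the authors' code: the data scaling, the squared loss, the zero shift vector, and every name below (`run_online_sgd`, `overlap`, `delta_plus`, and so on) are assumptions chosen for illustration. Each step draws one fresh sample from one of two Gaussian sub-populations, labels it with that sub-population's teacher, and applies a single SGD update to a linear student, while recording overlaps of the student with the teachers.

```python
import numpy as np

def run_online_sgd(d=200, steps=4000, eta=0.2, rho=0.7,
                   delta_plus=1.0, delta_minus=0.5, overlap=1.0, seed=0):
    """Online SGD for a linear student on a two-cluster teacher-mixture.

    Assumed model (illustrative): cluster "+" is drawn with probability rho,
    inputs are x ~ N(0, delta * I) with no shift, labels are sign(teacher . x),
    and the loss is 0.5 * (y - w . x / sqrt(d))**2.
    """
    rng = np.random.default_rng(seed)
    # Two teachers of norm sqrt(d); `overlap` is their cosine similarity.
    t_plus = rng.standard_normal(d)
    t_plus *= np.sqrt(d) / np.linalg.norm(t_plus)
    ortho = rng.standard_normal(d)
    ortho -= (ortho @ t_plus) / d * t_plus        # remove component along t_plus
    ortho *= np.sqrt(d) / np.linalg.norm(ortho)
    t_minus = overlap * t_plus + np.sqrt(1.0 - overlap**2) * ortho

    w = np.zeros(d)
    history = []                                   # rows: (time, R_plus, R_minus, Q)
    for step in range(steps):
        is_plus = rng.random() < rho               # pick the sub-population
        teacher = t_plus if is_plus else t_minus
        delta = delta_plus if is_plus else delta_minus
        x = np.sqrt(delta) * rng.standard_normal(d)
        y = np.sign(teacher @ x)
        residual = y - w @ x / np.sqrt(d)
        w += eta * residual * x / np.sqrt(d)       # one-sample SGD step
        if step % d == 0:                          # record once per unit of time
            history.append((step / d, w @ t_plus / d, w @ t_minus / d, w @ w / d))
    return w, t_plus, t_minus, np.array(history)

w, t_plus, t_minus, history = run_online_sgd()
cosine = w @ t_plus / (np.linalg.norm(w) * np.linalg.norm(t_plus))
```

Varying `rho`, `delta_minus`, `overlap`, or `eta` in this sketch is one way to explore how representation, variance, and learning rate affect the transient behaviour; time is measured in units of `steps / d`, matching the high-dimensional scaling assumed here.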

Paper Structure (45 sections, 4 theorems, 72 equations, 13 figures)

Key Result

Lemma 3.1

The generalisation error can be written as an average $\epsilon = \sum_{j=1}^m \rho_j \epsilon_j$ over the clusters, where each $\epsilon_j$ is a degree-2 polynomial in $R_j$, $M_j$ and $Q$ (its explicit form is omitted here) involving constants $\alpha_j, \beta_j$ that are independent of the parameter $\pmb w$.
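As a small illustration of the averaging in Lemma 3.1, the sketch below computes order parameters from a student vector and combines per-cluster errors into the overall error. The normalisations $R = \pmb w \cdot \pmb w_T / d$, $M = \pmb w \cdot \pmb v / d$, $Q = \pmb w \cdot \pmb w / d$ and the function names are assumptions made for this example; the per-cluster errors are passed in as plain numbers because the lemma's explicit polynomial is not given in this summary.

```python
import numpy as np

def order_parameters(w, teacher, shift):
    """Overlaps of the student w with a teacher and a shift vector.

    Assumed normalisation: every overlap is divided by the dimension d.
    """
    d = w.shape[0]
    return {
        "R": float(w @ teacher) / d,   # student-teacher overlap
        "M": float(w @ shift) / d,     # student-shift overlap
        "Q": float(w @ w) / d,         # student self-overlap
    }

def generalisation_error(rho, cluster_errors):
    """Weighted average epsilon = sum_j rho_j * epsilon_j over clusters."""
    rho = np.asarray(rho, dtype=float)
    cluster_errors = np.asarray(cluster_errors, dtype=float)
    if not np.isclose(rho.sum(), 1.0):
        raise ValueError("cluster proportions must sum to 1")
    return float(rho @ cluster_errors)

# Hypothetical numbers: a 75% / 25% split with per-cluster errors 0.1 and 0.3.
total = generalisation_error([0.75, 0.25], [0.1, 0.3])
params = order_parameters(np.ones(4), np.ones(4), np.array([1.0, -1.0, 1.0, -1.0]))
```

The example makes the role of $\rho_j$ explicit: a small sub-population can carry a large error while contributing little to the overall figure.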

Figures (13)

  • Figure 1: Teacher-Mixture in fairness and robustness. Panel (a) shows the generalisation errors for the subpopulations $+$ (blue) and $-$ (red), obtained through simulation (crosses) and predicted by the theory (solid lines), for a network with linear activation. The inset shows the same comparison for the order parameters: $R_+$ (blue), $R_-$ (red), $M$ (green), and $Q$ (orange). Panels (b-d) exemplify the different scenarios achievable in the TM model investigated in Sec. \ref{sec:insights}. Panel (b) represents a model for robustness where a spurious feature, given by the shift vector, can mislead the classifier, see Sec. \ref{sec:spurious}. Panels (c,d) are instead discussed in Sec. \ref{sec:fairness} and represent two models of fairness. First, Panel (c) has no shift, $v=0$, allowing us to remove the confounding effects. Finally, Panel (d) shows the general fairness problem.
  • Figure 2: Spurious correlations transient alignment. Time-evolution of loss (purple), student-teacher (red) and student-shift (green) cosine similarities. In the initial phase of learning (green background), the classifier aligns with the shift vector before aligning with the teacher (red background), see Sec. \ref{sec:spurious}. Parameters: $v=16, \rho=0.5, \Delta_-=\Delta_+=0.1, T_\pm=1, \eta=0.5$. For these parameters, spurious features allow the correct classification of 90% of the samples.
  • Figure 3: The crossing phenomenon. Panel (a) (left side) shows the loss curves of sub-population $-$ (in red) and sub-population $+$ (in blue) along with the overall loss (in purple). We observe a crossing caused by a higher variance but lower representation in sub-population $-$. The background colours represent the different phases of bias that are characterised by the evolution of the order parameters shown in Panel (a) (right side). Panel (b) shows the presence of the crossing phenomenon in a large portion of the parameter space using a phase diagram. Blue indicates an asymptotic preference for sub-population $+$ and red the opposite. Dark colours indicate regions where bias is consistent across training, while regions in light colours undergo a crossing phenomenon. White indicates that the learning rate was too high and training diverged. Parameters: $v=0, \Delta_+=1, T_\pm=0.9, \eta=0.1$.
  • Figure 4: Double crossing phenomenon. The left panel shows the loss for the two sub-populations (blue and red lines) and the global one (in purple). The right panel shows the value of the order parameters across time. The behaviour of the order parameters across time provides a precise characterisation and understanding of the different phases. Parameters: $v = 100, \rho = 0.75, \Delta_+=0.1, \Delta_- = 0.5, \eta = 0.03, T_\pm = 0.9, \alpha_+ = 0.343, \alpha_- = 0.12$.
  • Figure 5: Numerical simulations on MNIST. The figure shows the average (solid lines) and standard deviation (shaded area) of 100 simulations run in this framework. The upper plots show the test loss and the lower plots the test accuracy for subpopulation $+$ (blue) and $-$ (red). Panel (a) shows an example of the crossing phenomenon obtained by imposing $\sqrt{\Delta_+}=1$, $\sqrt{\Delta_-}=0.2$, and $\rho=0.1$. Panel (b) shows the double crossing, obtained by introducing an additional timescale to the previous case by tuning label imbalance. Panel (c) explores the effect of changing $\Delta_-$ while keeping a constant $\Delta_+=1$.
  • ...and 8 more figures

Theorems & Definitions (4)

  • Lemma 3.1
  • Lemma 3.2
  • Theorem 3.3
  • Theorem 3.4