The Implicit Bias of Heterogeneity towards Invariance: A Study of Multi-Environment Matrix Sensing
Yang Xu, Yihong Gu, Cong Fang
TL;DR
This work investigates how standard SGD can implicitly induce invariance learning when training on multi-environment data. By analyzing a multi-environment matrix sensing model with an invariant low-rank signal $\mathbf{A}^*$ and environment-specific spurious components $\mathbf{A}^{(e)}$, the authors show that SGD with environment-homogeneous batches (HeteroSGD) and large step sizes recovers $\mathbf{A}^*$, whereas pooling data across environments biases the solution toward $\mathbf{A}^*+\mathbb{E}_e[\mathbf{A}^{(e)}]$. The key mechanism is an oscillation in the spurious subspace, induced by heterogeneity together with large updates, which suppresses learning of the spurious components. Convergence proceeds in two phases: rapid growth of the invariant component, then alignment with the invariant solution. These results provide a first theoretical demonstration that implicit biases arising from data heterogeneity can promote invariance and causality in predictions under realistic training regimes. The findings offer a potential explanation for the robust, and even causal, generalization observed in practice when training on heterogeneous data.
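The pooled bias toward $\mathbf{A}^*+\mathbb{E}_e[\mathbf{A}^{(e)}]$ can be seen from a population-level calculation. This is a sketch under simplifying assumptions of ours, not necessarily the paper's exact setup: environment $e$ is drawn with weight $p_e$, observations are $y = \langle \mathbf{X}, \mathbf{A}^* + \mathbf{A}^{(e)}\rangle + \varepsilon$, the sensing matrix $\mathbf{X}$ has i.i.d. standard Gaussian entries, and the noise $\varepsilon$ is independent of $\mathbf{X}$ with mean zero and variance $\sigma^2$. Since $\mathbb{E}[\langle \mathbf{X}, \mathbf{N}\rangle^2] = \|\mathbf{N}\|_F^2$ for any fixed $\mathbf{N}$, the squared loss of a model output $\mathbf{M}$ satisfies

```latex
\begin{aligned}
\mathcal{L}_e(\mathbf{M}) &= \tfrac{1}{2}\,\mathbb{E}\big[(\langle \mathbf{X}, \mathbf{M}\rangle - y)^2\big]
  = \tfrac{1}{2}\,\big\|\mathbf{M} - \mathbf{A}^* - \mathbf{A}^{(e)}\big\|_F^2 + \tfrac{1}{2}\sigma^2, \\
\mathcal{L}_{\mathrm{pool}}(\mathbf{M}) &= \sum_e p_e\,\mathcal{L}_e(\mathbf{M})
  = \tfrac{1}{2}\,\big\|\mathbf{M} - \mathbf{A}^* - \mathbb{E}_e[\mathbf{A}^{(e)}]\big\|_F^2
  + \tfrac{1}{2}\sum_e p_e\,\big\|\mathbf{A}^{(e)} - \mathbb{E}_e[\mathbf{A}^{(e)}]\big\|_F^2 + \tfrac{1}{2}\sigma^2, \\
\nabla \mathcal{L}_e(\mathbf{M}) &= \mathbf{M} - \mathbf{A}^* - \mathbf{A}^{(e)},
\qquad
\nabla \mathcal{L}_{\mathrm{pool}}(\mathbf{M}) = \mathbf{M} - \mathbf{A}^* - \mathbb{E}_e[\mathbf{A}^{(e)}],
\end{aligned}
```

where $\mathbb{E}_e[\mathbf{A}^{(e)}] = \sum_e p_e \mathbf{A}^{(e)}$. The pooled minimizer is therefore $\mathbf{A}^*+\mathbb{E}_e[\mathbf{A}^{(e)}]$, while the minimizer in a single environment is $\mathbf{A}^*+\mathbf{A}^{(e)}$. The per-environment gradients differ only by spurious terms, $\nabla\mathcal{L}_e - \nabla\mathcal{L}_{e'} = \mathbf{A}^{(e')} - \mathbf{A}^{(e)}$, so environment-homogeneous batches pull the iterate back and forth along the spurious directions. Whether this suppresses them depends on the step size and the parameterization, which is what the paper analyzes.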
Abstract
Models are expected to perform invariance learning, that is, to identify the core relations that remain consistent across varying environments, so that their predictions are safe, robust and fair. While existing works consider specific algorithms to realize invariance learning, we show that a model can learn invariance through standard training procedures. Specifically, this paper studies the implicit bias of Stochastic Gradient Descent (SGD) over heterogeneous data and shows that this implicit bias drives the model toward an invariant solution. We call this phenomenon implicit invariance learning. Concretely, we theoretically investigate the multi-environment low-rank matrix sensing problem, where in each environment the signal comprises (i) a lower-rank invariant part shared across all environments and (ii) a significantly varying environment-dependent spurious component. The key insight is that, by simply running large-step-size, large-batch SGD sequentially in each environment without any explicit regularization, the oscillation caused by heterogeneity provably prevents the model from learning the spurious signals. The model reaches the invariant solution after a certain number of iterations. In contrast, a model trained with pooled SGD over all data simultaneously learns both the invariant and the spurious signals. Overall, we unveil a new implicit bias that results from the symbiosis between data heterogeneity and modern algorithms; to the best of our knowledge, this is the first such result in the literature.
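The contrast between sequential, environment-homogeneous updates and pooled updates can be explored with a small simulation. This is a minimal sketch under assumptions that are ours, not the paper's: a factorized model $\mathbf{M} = \mathbf{U}\mathbf{U}^\top$; the infinite-batch (population) limit, in which a batch from environment $e$ yields the gradient of $\tfrac12\|\mathbf{U}\mathbf{U}^\top - \mathbf{A}^* - \mathbf{A}^{(e)}\|_F^2$ (as for Gaussian sensing); HeteroSGD taken as cycling through environments one batch per step; and hypothetical choices of $\mathbf{A}^*$, $\mathbf{A}^{(e)}$, step size and initialization. The pooled update uses the environment-averaged gradient, which targets $\mathbf{A}^*+\mathbb{E}_e[\mathbf{A}^{(e)}]$.

```python
import random
from functools import reduce

# Minimal dense-matrix helpers on lists of rows (stdlib only).
def matmul(P, Q):
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
            for i in range(len(P))]

def transpose(P):
    return [list(col) for col in zip(*P)]

def add(P, Q):
    return [[p + q for p, q in zip(rp, rq)] for rp, rq in zip(P, Q)]

def sub(P, Q):
    return [[p - q for p, q in zip(rp, rq)] for rp, rq in zip(P, Q)]

def scale(c, P):
    return [[c * p for p in row] for row in P]

def diag(values):
    return [[v if i == j else 0.0 for j in range(len(values))] for i, v in enumerate(values)]

def fro(P):
    """Frobenius norm."""
    return sum(p * p for row in P for p in row) ** 0.5

def grad(U, target):
    """Gradient of 0.5 * ||U U^T - target||_F^2 with respect to U (target symmetric)."""
    residual = sub(matmul(U, transpose(U)), target)
    return scale(2.0, matmul(residual, U))

def train(target_at, eta, steps, d, init_scale=1e-3, seed=0):
    """Gradient descent on U from a small random init; target_at(t) gives step t's target."""
    rng = random.Random(seed)
    U = [[init_scale * rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(d)]
    for t in range(steps):
        U = sub(U, scale(eta, grad(U, target_at(t))))
    return matmul(U, transpose(U))

# Hypothetical toy instance: rank-1 invariant part and two environments whose
# spurious parts occupy different coordinates.
d = 4
A_star = diag([1.0, 0.0, 0.0, 0.0])
A_env = [diag([0.0, 0.5, 0.0, 0.0]), diag([0.0, 0.0, 0.5, 0.0])]
env_targets = [add(A_star, A_e) for A_e in A_env]          # per-environment minimizers A* + A^(e)
mean_spurious = scale(1.0 / len(A_env), reduce(add, A_env))
pooled_target = add(A_star, mean_spurious)                  # A* + E_e[A^(e)]

eta = 0.05  # small default for stability; the paper's regime uses large step sizes
# HeteroSGD (assumed cyclic order): each step uses one environment's gradient.
M_hetero = train(lambda t: env_targets[t % len(env_targets)], eta, steps=2000, d=d)
# Pooled: averaging the per-environment gradients equals the gradient toward the averaged target.
M_pooled = train(lambda t: pooled_target, eta, steps=2000, d=d)

for name, M in [("HeteroSGD (cyclic)", M_hetero), ("pooled", M_pooled)]:
    print(f"{name}: ||M - A*||_F = {fro(sub(M, A_star)):.4f}, "
          f"||M - (A* + E[A^(e)])||_F = {fro(sub(M, pooled_target)):.4f}")
```

Compare the two printed distances, and try larger values of `eta` for the HeteroSGD run: the paper's result concerns large step sizes, while in this toy a step size that is too large can make the iterates diverge.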
