
Simplicity Bias of Two-Layer Networks beyond Linearly Separable Data

Nikita Tsoy, Nikola Konstantinov

TL;DR

The paper proves that, in the early phases of training, network features cluster around a few directions that do not depend on the size of the hidden layer, which indicates that features learned in the middle stages of training may be more useful for out-of-distribution (OOD) transfer.

Abstract

Simplicity bias, the propensity of deep models to over-rely on simple features, has been identified as a potential reason for the limited out-of-distribution (OOD) generalization of neural networks (Shah et al., 2020). Despite its important implications, this phenomenon has been theoretically confirmed and characterized only under strong dataset assumptions, such as linear separability (Lyu et al., 2021). In this work, we characterize simplicity bias for general datasets in the context of two-layer neural networks initialized with small weights and trained with gradient flow. Specifically, we prove that in the early training phases, network features cluster around a few directions that do not depend on the size of the hidden layer. Furthermore, for datasets with an XOR-like pattern, we precisely identify the learned features and demonstrate that simplicity bias intensifies during later training stages. These results indicate that features learned in the middle stages of training may be more useful for OOD transfer. We support this hypothesis with experiments on image data.
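The training setup described in the abstract (a two-layer network with small initialization, trained by gradient flow) can be sketched as below. This is a minimal illustration under assumptions that do not come from the paper: the leaky ReLU activation with slope 0.1, the logistic loss, the hypothetical XOR-like dataset `xor_like_data`, and all hyperparameters are illustrative choices. Gradient flow is approximated by forward-Euler steps (plain gradient descent with a small step size), and inputs are rescaled so that $\lVert\bm{x}_i\rVert \le 1$.

```python
import math
import random


def leaky_relu(z, slope=0.1):
    """Leaky ReLU activation (an assumed choice, not necessarily the paper's)."""
    return z if z > 0 else slope * z


def leaky_relu_grad(z, slope=0.1):
    """Derivative of leaky ReLU, taking the slope-side value at z == 0."""
    return 1.0 if z > 0 else slope


def xor_like_data(n_per_cluster=10, noise=0.05, seed=0):
    """Hypothetical XOR-like 2-D data: label +1 on one diagonal, -1 on the other."""
    rng = random.Random(seed)
    a = 0.5
    centers = [((a, a), 1.0), ((-a, -a), 1.0), ((a, -a), -1.0), ((-a, a), -1.0)]
    data = []
    for (cx, cy), y in centers:
        for _ in range(n_per_cluster):
            x = [cx + noise * rng.gauss(0.0, 1.0), cy + noise * rng.gauss(0.0, 1.0)]
            norm = math.hypot(x[0], x[1])
            if norm > 1.0:  # enforce ||x_i|| <= 1
                x = [x[0] / norm, x[1] / norm]
            data.append((x, y))
    return data


def logistic_loss(q):
    """Numerically stable log(1 + exp(-q))."""
    return math.log1p(math.exp(-q)) if q > 0 else -q + math.log1p(math.exp(q))


def neg_loss_derivative(q):
    """Stable evaluation of -d/dq log(1 + exp(-q)) = 1 / (1 + exp(q))."""
    if q >= 0:
        e = math.exp(-q)
        return e / (1.0 + e)
    return 1.0 / (1.0 + math.exp(q))


def train(data, m=6, init_scale=1e-3, lr=0.2, steps=1500, seed=1):
    """Forward-Euler discretization of gradient flow for f(x) = sum_j u_j * phi(<v_j, x>)."""
    rng = random.Random(seed)
    d = len(data[0][0])
    # Small initialization: every weight is init_scale times a standard normal draw.
    u = [init_scale * rng.gauss(0.0, 1.0) for _ in range(m)]
    v = [[init_scale * rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
    n = len(data)
    history = []
    for _ in range(steps):
        grad_u = [0.0] * m
        grad_v = [[0.0] * d for _ in range(m)]
        total = 0.0
        for x, y in data:
            pre = [sum(vj[k] * x[k] for k in range(d)) for vj in v]
            f = sum(u[j] * leaky_relu(pre[j]) for j in range(m))
            total += logistic_loss(y * f)
            # dL/df for this sample: -y / (1 + exp(y f)), averaged over n samples.
            c = -y * neg_loss_derivative(y * f) / n
            for j in range(m):
                grad_u[j] += c * leaky_relu(pre[j])
                scale = c * u[j] * leaky_relu_grad(pre[j])
                for k in range(d):
                    grad_v[j][k] += scale * x[k]
        history.append(total / n)
        # One Euler step of d(theta)/dt = -grad L(theta).
        for j in range(m):
            u[j] -= lr * grad_u[j]
            for k in range(d):
                v[j][k] -= lr * grad_v[j][k]
    return u, v, history


if __name__ == "__main__":
    data = xor_like_data()
    u, v, history = train(data)
    print(f"loss: {history[0]:.4f} -> {history[-1]:.4f}")
```

Normalizing each row of `v` after (or during) training gives the feature directions $\hat{\bm{v}}_j$, whose clustering can then be inspected.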

Paper Structure (81 sections, 14 theorems, 160 equations, 10 figures, 2 tables)

Key Result

Theorem 4.1

Assume that $\bm{\theta}$ follows the gradient flow in eq:set-gf, $\lVert\bm{x}_i\rVert \le 1$ for all $i$, $d \ge 2$, and $d$ is odd. Then there exist $\kappa^* > 0$, $P \subseteq [m]$, and $(\kappa_j > 0, u^*_j \in \mathbb{R}, \hat{\bm{v}}^*_j \in \mathrm{S}^{d-1})_{j=1}^m$ such that for $\sigma = r^{1 + \kappa^*}$, $T_1 \coloneqq$ … where $R \coloneqq [m] \setminus P$ and $\lambda \coloneqq \max_{\hat{\bm{v}} \in \mathrm{S}^{d-1}}$ …
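The statement of Theorem 4.1 is cut off above, so the sketch below only illustrates its qualitative claim that, early in training, the normalized features $\hat{\bm{v}}_j$ concentrate around a few directions. The greedy angular grouping, the helper name `cluster_directions`, and its tolerance `max_angle` are hypothetical diagnostics chosen for illustration, not a procedure from the paper.

```python
import math


def normalize(v):
    """Return v / ||v||; raises for the zero vector."""
    norm = math.sqrt(sum(c * c for c in v))
    if norm == 0.0:
        raise ValueError("cannot normalize the zero vector")
    return [c / norm for c in v]


def cluster_directions(features, max_angle=0.1):
    """Greedily group features whose directions lie within max_angle radians
    of an existing cluster representative (hypothetical diagnostic).

    Returns a list of (representative_direction, member_indices) pairs.
    Directions v and -v are treated as distinct points on the sphere.
    """
    clusters = []
    for idx, v in enumerate(features):
        v_hat = normalize(v)
        for rep, members in clusters:
            # Clamp the cosine to [-1, 1] to guard against rounding before acos.
            cos = max(-1.0, min(1.0, sum(a * b for a, b in zip(rep, v_hat))))
            if math.acos(cos) <= max_angle:
                members.append(idx)
                break
        else:
            # No existing cluster is close enough: start a new one.
            clusters.append((v_hat, [idx]))
    return clusters
```

If the clustering claim holds, the number of clusters returned for the early-phase features should stay small as the hidden width $m$ grows.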

Figures (10)

  • Figure 1: Evolution of a 4-neuron network initialized at $(u^e_1(0), u^e_2(0), u^e_3(0), u^e_4(0)) = (10^{-4}, -10^{-5}, 10^{-7}, -10^{-6})$. The first column depicts the whole training process; for visual convenience, the remaining columns zoom in on individual stages: the second column depicts the first $3584$ training epochs, the third column depicts epochs $3584$ to $15872$, and the last column depicts training after epoch $15872$. Notice that $\alpha_1 \approx \alpha_3$ and $\alpha_2 \approx \alpha_4$, where $\alpha_i$ are the angles between the network features and the cluster directions.
  • Figure 2: Empirical evidence of simplicity bias on XOR-like data. The relative scale in the second row is defined as $\lVert\bm{v}_i\rVert^2/\sum_{j=1}^m \lVert\bm{v}_j\rVert^2$.
  • Figure 3: Examples of domino images with a car (class $1$ in CIFAR-10) in the train (left) and test (right) datasets. Notice that the top MNIST image is an image of $1$ only in the training data.
  • Figure 4: Accuracy and scale of the logistic regression on the validation part of the OOD test set ($y$-axis) vs. the training epoch at which the ResNet features are extracted ($x$-axis).
  • Figure 5: Evolution of a 4-neuron network initialized at $(u^e_1(0), u^e_2(0), u^e_3(0), u^e_4(0)) = (10^{-4}, -10^{-5}, 10^{-7}, -10^{-6})$. The first row depicts the whole training process; the second row depicts the first $3584$ training epochs; the third row depicts epochs $3584$ to $15872$; the last row depicts training after epoch $15872$. Notice that $\alpha_1 \approx \alpha_3$ and $\alpha_2 \approx \alpha_4$.
  • ...and 5 more figures
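The quantities plotted in Figures 1 and 2 can be computed roughly as follows. This is a sketch assuming features are given as plain Python lists; the helper names `angle` and `relative_scale` are hypothetical, and the paper may measure the angles $\alpha_i$ slightly differently (for instance, up to sign). The relative scale follows the formula in the Figure 2 caption.

```python
import math


def angle(v, w):
    """Angle in radians between vectors v and w, with the cosine clamped to [-1, 1]."""
    dot = sum(a * b for a, b in zip(v, w))
    nv = math.sqrt(sum(a * a for a in v))
    nw = math.sqrt(sum(b * b for b in w))
    cos = max(-1.0, min(1.0, dot / (nv * nw)))
    return math.acos(cos)


def relative_scale(features):
    """Relative scale ||v_i||^2 / sum_j ||v_j||^2 of each neuron (Figure 2 caption)."""
    sq = [sum(c * c for c in v) for v in features]
    total = sum(sq)
    return [s / total for s in sq]
```

For example, $\alpha_i$ would be `angle(v_i, cluster_direction)` for the feature $\bm{v}_i$ of neuron $i$ and the cluster direction it approaches.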

Theorems & Definitions (26)

  • Theorem 4.1: Proof in sec:proof-gen-phase1
  • Remark 4.2
  • Definition 4.3
  • Theorem 4.4: Proof in sec:proof-gen-phase2
  • Remark 4.5
  • Lemma 5.2: Proof in sec:proof-spec-g-extr
  • Lemma 5.3: Proof in sec:proof-spec-init
  • Lemma 5.5: Proof in sec:proof-spec-extreme
  • Theorem 5.6: Theorem 5.6, l21g
  • Proposition 5.7: Proof in sec:proof-spec-conv
  • ...and 16 more