
Inversion dynamics of class manifolds in deep learning reveals tradeoffs underlying generalisation

Simone Ciceri, Lorenzo Cassani, Matteo Osella, Pietro Rotondo, Filippo Valle, Marco Gherardi

TL;DR

The inversion is the manifestation of tradeoffs elicited by well-defined and maximally stable elements of the training set called ‘stragglers’, which are particularly influential for generalization.

Abstract

To achieve near-zero training error in a classification problem, the layers of a feed-forward network have to disentangle the manifolds of data points with different labels to facilitate discrimination. However, excessive class separation can lead to overfitting, since good generalisation requires learning invariant features, which involve some level of entanglement. We report on numerical experiments showing how the optimisation dynamics finds representations that balance these opposing tendencies with a non-monotonic trend. After a fast segregation phase, a slower rearrangement (conserved across data sets and architectures) increases the class entanglement. The training error at the inversion is stable under subsampling and across network initialisations and optimisers, which characterises it as a property of the data structure and, only very weakly, of the architecture. The inversion is the manifestation of tradeoffs elicited by well-defined and maximally stable elements of the training set, coined ‘stragglers’, which are particularly influential for generalisation.
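
The segregation-then-entanglement picture can be made concrete with a short sketch of the metric quantities plotted in Figure 1 (class radii $R_\pm$ and inter-class distance $D$). This is an illustration under stated assumptions, not the authors' code: it assumes binary labels in {+1, -1}, takes each radius to be the root-mean-square distance of a class's representations from their centroid, and takes $D$ to be the distance between the two centroids. It also locates the inversion as the epoch maximising $D/(R_+ + R_-)$; collapsing the curves into one ratio is our own simplification, and the paper's exact criterion may differ. All function names are hypothetical.

```python
import math
from typing import Sequence

Point = Sequence[float]


def centroid(points: Sequence[Point]) -> list[float]:
    """Component-wise mean of a non-empty list of representation vectors."""
    dim = len(points[0])
    return [sum(p[k] for p in points) / len(points) for k in range(dim)]


def class_radius(points: Sequence[Point], center: Point) -> float:
    """Assumed radius: root-mean-square distance of a class's points from its centroid."""
    return math.sqrt(sum(math.dist(p, center) ** 2 for p in points) / len(points))


def manifold_metrics(reps: Sequence[Point], labels: Sequence[int]) -> tuple[float, float, float]:
    """Return (R_plus, R_minus, D) for binary labels in {+1, -1}.

    D is assumed here to be the Euclidean distance between the two class centroids.
    """
    plus = [r for r, y in zip(reps, labels) if y == +1]
    minus = [r for r, y in zip(reps, labels) if y == -1]
    c_plus, c_minus = centroid(plus), centroid(minus)
    return class_radius(plus, c_plus), class_radius(minus, c_minus), math.dist(c_plus, c_minus)


def inversion_point(history: Sequence[tuple[float, float, float, float]]) -> tuple[int, float]:
    """Find the epoch where the normalised separation D / (R_plus + R_minus) peaks.

    `history` holds one (R_plus, R_minus, D, train_error) tuple per recorded epoch.
    Under this assumed criterion, classes segregate before the peak and
    re-entangle after it. Returns (epoch index, training error at that epoch).
    """
    separation = [d / (rp + rm) for rp, rm, d, _ in history]
    t_star = max(range(len(separation)), key=separation.__getitem__)
    return t_star, history[t_star][3]


if __name__ == "__main__":
    # Two tight, well-separated classes on a line.
    reps = [(1.0, 0.0), (3.0, 0.0), (-1.0, 0.0), (-3.0, 0.0)]
    labels = [+1, +1, -1, -1]
    print(manifold_metrics(reps, labels))  # (1.0, 1.0, 4.0)

    # Toy trajectory: separation rises, peaks at epoch 1, then falls.
    history = [(2.0, 2.0, 1.0, 0.5), (1.0, 1.0, 3.0, 0.1),
               (1.5, 1.5, 3.0, 0.02), (2.0, 2.0, 2.0, 0.0)]
    print(inversion_point(history))  # (1, 0.1)
```

In practice `reps` would be a hidden layer's activations on the training set, recorded once per epoch; the abstract's claim that the training error at the inversion is stable under subsampling could then be probed by repeating this on disjoint subsets.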

Paper Structure (6 sections, 12 equations, 3 figures)

Figures (3)

  • Figure 1: Non-monotonic learning dynamics. (a) Training disentangles the class manifolds. Scatter plot of the two radii $R_\pm$ (top) and histograms of the distance $D$ (bottom) from 1000 independent runs, at initialisation (yellow) and after training (pink). (b) Class manifold dynamics is non-monotonic. Radii and distance (top) and train and test errors (bottom) as functions of training epoch (x axis, log scale); the dashed horizontal lines are the mean values at initialisation; inversion happens in the grey shaded regions; curve widths span 2 standard deviations. (c) Dynamics is robust to subsampling. The three metric quantities as functions of training error (only means shown, computed over 20 runs); different curves are obtained by training on non-overlapping subsets of MNIST. (d) Dynamics is similar across optimisers and hyperparameters. Solid lines: Adam (learning rates $0.001$ and $0.005$); dashed lines: GD with weight decay ($\lambda=0.01$ and $0.05$); dotted lines: GD with momentum ($(\mu,\eta)=(0.5,0.5)$ and $(0.9,0.2)$). Curves are averages over 20 runs. (e, f) Randomised labels (pink curves) remove the non-monotonicity.
  • Figure 2: Stragglers shape the dynamics and influence generalisation. (a) Training without stragglers removes the inversion. The blue curve is obtained by training with the full dataset (shaded region corresponds to 2 sigmas); pink curves (indicated by the arrow) are 20 runs with the pruned training set $\mathcal{T}\setminus \mathcal{S}(t_*)$; the variability is due to the different initialisations, which affect both the dynamics and the elements of $\mathcal{S}(t)$; grey curves above and below the pink ones are obtained with pruned training sets $\mathcal{T}\setminus \mathcal{S}(t)$, with $t=100>t_*$ and $t=10<t_*$ respectively. (b) Metric quantities at convergence (y axis) using training sets $\mathcal{T}\setminus\mathcal{S}\left(t\left(\epsilon_\mathrm{tr}\right)\right)$, as functions of $\epsilon_\mathrm{tr}$ (x axis). (c) Removal of stragglers affects the test error at convergence (y axis). The green curves, from bottom to top, are computed on noisy test sets, generated by adding white noise independently to each pixel with standard deviation $\sigma = 0, 0.5, 0.75, 1, 1.2, 1.5$ respectively (inputs are standardised, see Methods); shaded regions correspond to 2 sigmas. Grey curves are obtained by removing, for each $\epsilon_\mathrm{tr}$, a random set of points of the same cardinality as $\mathcal{S}\left(t\left(\epsilon_\mathrm{tr}\right)\right)$ (only the two smallest values of $\sigma$ are shown). (d) The inversion point marks a maximally stable set of misclassified points. Pink crosses are z-scores of the stability of the set $\mathcal{S}\left(t\left(\epsilon_\mathrm{tr}\right)\right)$ (y axis; see Methods) under fluctuations in the initialisations, as a function of $\epsilon_\mathrm{tr}$. In all plots, $\mathcal{T}$ contains $P=8192$ elements from MNIST, and the architecture is a two-layer network with 20 hidden units.
  • Figure 3: Stragglers across data sets and architectures. (a) The three metric quantities (y axes) as functions of the training error (x axis) for MNIST, KMNIST, and Fashion-MNIST. (b) Fraction of stragglers has a well-defined large-dataset limit. Dashed lines are fits of Eq. (eq:fss) to these data. (c), (d) Non-monotonic dynamics of $R_+$ (y axes) in CIFAR-10, as a function of training error in (c) and of epochs in (d). (e) The asymptotic (large-dataset) fraction of stragglers (y axis) depends only weakly on the depth, and negligibly on the width, of the architecture. The four groups of boxes correspond to increasing widths from left to right; darker shades of blue correspond to deeper architectures. The curves in (a), (c), and (d) are 20 runs for each data set. Box heights in (b) and (e) correspond to 2 standard deviations. Architectures and parameters: 2 layers with 20 hidden units each in (a) and (b); 8 fully connected layers with 20 hidden units each, learning rate $\eta=0.02$, in (c) and (d); 2, 4, and 8 layers, each with 10, 20, 40, and 80 hidden units, $\eta=0.1$, in (e).
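
The straggler constructions used in Figures 2 and 3 can be sketched in a few functions. This is a minimal illustration, not the authors' implementation: $\mathcal{S}(t)$ is taken to be the set of training indices misclassified at epoch $t$, the random-removal control mirrors the grey curves of Figure 2(c), and the noise model follows Figure 2(c) (independent Gaussian noise per standardised input). The stability measure behind the z-scores in Figure 2(d) is defined in the paper's Methods, which are not reproduced here, so the Jaccard-overlap score against a random-set null below is an assumption. All function names are hypothetical.

```python
import math
import random
from itertools import combinations
from statistics import mean, pstdev
from typing import Sequence


def stragglers(predictions: Sequence[int], labels: Sequence[int]) -> frozenset[int]:
    """S(t): indices of training points still misclassified at epoch t."""
    return frozenset(i for i, (p, y) in enumerate(zip(predictions, labels)) if p != y)


def prune(train_indices: Sequence[int], removed: frozenset[int]) -> list[int]:
    """Training set with the given points removed, e.g. T minus S(t_*)."""
    return [i for i in train_indices if i not in removed]


def random_prune(train_indices: Sequence[int], k: int, rng: random.Random) -> list[int]:
    """Control: drop k points chosen uniformly at random (same cardinality as S)."""
    removed = set(rng.sample(list(train_indices), k))
    return [i for i in train_indices if i not in removed]


def jaccard(a: frozenset[int], b: frozenset[int]) -> float:
    """Overlap |a & b| / |a | b|; two empty sets count as identical."""
    union = a | b
    return len(a & b) / len(union) if union else 1.0


def mean_pairwise_overlap(sets: Sequence[frozenset[int]]) -> float:
    """Average Jaccard overlap over all pairs of runs (requires at least two runs)."""
    return mean(jaccard(a, b) for a, b in combinations(sets, 2))


def stability_zscore(sets: Sequence[frozenset[int]], n_points: int,
                     n_null: int = 200, seed: int = 0) -> float:
    """Assumed stability score: how much more S(t) overlaps across initialisations
    than random index sets of the same sizes drawn from range(n_points)."""
    rng = random.Random(seed)
    observed = mean_pairwise_overlap(sets)
    null = [
        mean_pairwise_overlap([frozenset(rng.sample(range(n_points), len(s))) for s in sets])
        for _ in range(n_null)
    ]
    mu, sd = mean(null), pstdev(null)
    if sd == 0:
        # Degenerate null distribution: report only the sign of the deviation.
        return 0.0 if observed == mu else math.copysign(math.inf, observed - mu)
    return (observed - mu) / sd


def add_white_noise(x: Sequence[float], sigma: float, rng: random.Random) -> list[float]:
    """Add independent Gaussian noise of standard deviation sigma to every input value."""
    return [v + rng.gauss(0.0, sigma) for v in x]
```

For example, collecting `stragglers(...)` from 20 initialisations at each epoch and scanning `stability_zscore` over epochs would give a curve analogous in spirit to Figure 2(d), under the assumed overlap measure; `len(s) / P` gives the straggler fraction of Figure 3(b).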