Feature-Learning Networks Are Consistent Across Widths At Realistic Scales

Nikhil Vyas, Alexander Atanasov, Blake Bordelon, Depen Morwani, Sabarish Sainathan, Cengiz Pehlevan

TL;DR

The paper investigates how network width affects dynamics under the $\mu$P parameterization, testing vision and language models to determine whether realistic widths can be described by an infinite-width feature-learning limit. It demonstrates strong width-consistency in online training across losses, predictions, representations, and dynamical phenomena, with convergence occurring at widths within practical ranges and differing by task complexity. When training becomes offline or tasks are harder, finite-width deviations arise due to initialization-induced variance and a bias of narrower widths; ensembling reduces variance but does not fully recover the infinite-width behavior. A spectral analysis suggests the finite-width bias stems from deformation of eigenfunctions of the ensemble NTK, offering a mechanistic explanation that aligns with after-kernel observations in CIFAR-5m. Overall, the work argues that infinite-width feature-learning models provide a robust framework for understanding realistic networks, while highlighting task-dependent finite-width corrections and the need to consider spectral dynamics in their analysis; the authors also plan to release code to enable reproducibility.

Abstract

We study the effect of width on the dynamics of feature-learning neural networks across a variety of architectures and datasets. Early in training, wide neural networks trained on online data have not only identical loss curves but also agree in their point-wise test predictions throughout training. For simple tasks such as CIFAR-5m this holds throughout training for networks of realistic widths. We also show that structural properties of the models, including internal representations, preactivation distributions, edge of stability phenomena, and large learning rate effects are consistent across large widths. This motivates the hypothesis that phenomena seen in realistic models can be captured by infinite-width, feature-learning limits. For harder tasks (such as ImageNet and language modeling), and later training times, finite-width deviations grow systematically. Two distinct effects cause these deviations across widths. First, the network output has initialization-dependent variance scaling inversely with width, which can be removed by ensembling networks. We observe, however, that ensembles of narrower networks perform worse than a single wide network. We call this the bias of narrower width. We conclude with a spectral perspective on the origin of this finite-width bias.
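
The distinction the abstract draws between initialization-dependent variance and the bias of narrower width can be illustrated with a toy simulation. The sketch below is not the authors' code: the function names, the synthetic output model, and the coefficients are all assumptions made for illustration. In particular, the variance scaling as 1/width follows the abstract, but the 1/width form chosen for the bias is only a placeholder.

```python
import random
import statistics


def simulate_outputs(width, n_seeds, f_wide=1.0, bias_coeff=2.0,
                     noise_coeff=3.0, seed=0):
    """Toy stand-in for the outputs of `n_seeds` trained networks on one test point.

    ASSUMPTION (illustrative only): each output equals the wide-network value
    `f_wide`, plus a seed-independent offset bias_coeff / width, plus Gaussian
    noise whose variance is noise_coeff / width.
    """
    rng = random.Random(seed)
    std = (noise_coeff / width) ** 0.5
    return [f_wide + bias_coeff / width + rng.gauss(0.0, std)
            for _ in range(n_seeds)]


def decompose(outputs, reference):
    """Split the mean squared deviation from `reference` into bias^2 + variance."""
    mean = statistics.fmean(outputs)
    bias_sq = (mean - reference) ** 2          # what an ensemble average keeps
    variance = statistics.pvariance(outputs, mu=mean)  # what ensembling removes
    mse = statistics.fmean((y - reference) ** 2 for y in outputs)
    return mse, bias_sq, variance


if __name__ == "__main__":
    for width in (64, 256, 1024):
        outs = simulate_outputs(width, n_seeds=2000)
        mse, bias_sq, variance = decompose(outs, reference=1.0)
        print(f"width={width:5d}  single-net MSE={mse:.5f}  "
              f"ensemble error (bias^2)={bias_sq:.5f}  variance={variance:.5f}")
```

Averaging over seeds plays the role of ensembling: the error of the ensemble mean is the squared-bias term alone, so in this toy model an ensemble of narrow networks still differs from the wide reference even after the variance has been averaged away.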

Paper Structure (41 sections, 17 equations, 19 figures)

This paper contains 41 sections, 17 equations, 19 figures.

Figures (19)

  • Figure 1: Consistency of large-width behavior across tasks, architectures, and observables. a) Loss curves for ResNets on CIFAR-5m in $\mu$P are nearly identical at large widths (see also Figure 2). b) For GPT-2 on the C4 dataset (Raffel et al., 2020), the loss curves agree at early times and deviate at late times, but wider networks agree for longer (see also Figure 2 and the appendices for Wikitext-103). c) The values that ResNets put on the correct logit for ImageNet appear to converge as the width grows (see also Figure 3). d) The attention matrices for transformers on Wikitext-103 become nearly identical as width increases (for quantitative metrics see Figure 5).
  • Figure 2: In the online learning setting, train loss improves as width grows. For sufficiently wide networks, the training loss is consistent across widths. For CIFAR-5m this consistency is observed over all of training. For harder tasks like ImageNet and Wikitext-103, networks of different widths agree up until a width-dependent time step at which narrower networks begin performing worse.
  • Figure 3: The output logits on a fixed test point display stable behavior at large enough widths. a) Value of the network on the correct class logit over time as width is varied for CIFAR-5m. Colored error bars represent one standard deviation. b) Same plot for ImageNet for a fixed image in the test set. c) Same plot for Wikitext-103 for a fixed masked token. Across the board, the widest networks behave similarly. Next, we use the widest network as a proxy for the infinite-width limit and compare the logit predictions of narrower networks against it. d) For CIFAR-5m, the relative root-mean-squared error over the test set of the distance to the value that the widest network puts on the correct logit. e) The same for ImageNet. f) The same for Wikitext-103. We see a striking regularity of networks converging to the widest one as the width grows. In the appendix, we also compare networks of successive widths and show that the difference shrinks.
  • Figure 4: Analog of the last row of Figure 3, but comparing networks of successive widths rather than comparing all networks to the widest. Again, we see that as the network width grows, the difference between successive networks shrinks.
  • Figure 5: Learned features are consistent across a large range of widths in realistic tasks. (a) The distribution (over neurons) of preactivation values $h$ in the final block of $E=8$ ResNet18 networks trained on CIFAR-5m. At initialization, the densities are all well approximated by the Gaussian with matching mean and variance (dashed black). After feature learning, the density has shifted and become non-Gaussian (poor match with dashed black), yet is still strikingly consistent across widths. (b) Average (over random init) feature kernels are also consistent across widths. (c) The centered kernel alignment (CKA; Cortes et al., 2012; Kornblith et al., 2019) of the width $N$ and width $512$ kernels increases towards $1.0$ as $N$ increases. The $1/\sqrt{N}$ and $1/N$ trends are plotted for reference. (d) The preactivation histogram for a transformer on Wikitext-103. At initialization the Gaussian of best fit is the standard normal. After training the histograms are still quite Gaussian, with different moments. (e) A variant of Figure 1 (d) at a smaller sequence length. Attention matrices are consistent at large widths. (f) Both FFN kernels and attention matrices converge as width grows. The $1/N$ and $1/\sqrt{N}$ trends are plotted for reference.
  • ...and 14 more figures
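
The two width-comparison metrics named in the captions of Figures 3 and 5, relative root-mean-squared error against the widest network and centered kernel alignment between feature kernels, can be sketched as follows. This is a minimal NumPy sketch, not the authors' implementation: the function names are hypothetical, the normalization used for the relative error (dividing by the mean squared reference value) is an assumption since the captions do not spell it out, and the random features stand in for real network activations.

```python
import numpy as np


def feature_kernel(features):
    """Gram matrix over samples: K[i, j] = <phi(x_i), phi(x_j)> / width.

    `features` has shape (n_samples, width). Dividing by width keeps kernels
    from networks of different widths on a comparable scale.
    """
    features = np.asarray(features, dtype=float)
    return features @ features.T / features.shape[1]


def centered_kernel_alignment(K, L):
    """CKA between two (n, n) kernels: <Kc, Lc>_F / (||Kc||_F ||Lc||_F)."""
    n = K.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n   # centering matrix
    Kc = H @ K @ H
    Lc = H @ L @ H
    return float(np.sum(Kc * Lc) / (np.linalg.norm(Kc) * np.linalg.norm(Lc)))


def relative_rmse(pred, ref):
    """ASSUMED normalization: sqrt(mean((pred - ref)^2) / mean(ref^2))."""
    pred = np.asarray(pred, dtype=float)
    ref = np.asarray(ref, dtype=float)
    return float(np.sqrt(np.mean((pred - ref) ** 2) / np.mean(ref ** 2)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Synthetic stand-ins for features of a narrow and a wide network
    # evaluated on the same 50 inputs.
    narrow = rng.standard_normal((50, 32))
    wide = rng.standard_normal((50, 512))
    cka = centered_kernel_alignment(feature_kernel(narrow), feature_kernel(wide))
    print("CKA(narrow, wide):", cka)

    # Synthetic correct-class logits over a test set.
    logits_wide = rng.standard_normal(1000)
    logits_narrow = logits_wide + 0.1 * rng.standard_normal(1000)
    print("relative RMSE:", relative_rmse(logits_narrow, logits_wide))
```

CKA is invariant to rescaling either kernel, which makes it suitable for comparing representations across widths; the relative error, by contrast, is sensitive to the overall scale of the logits.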