On the infinite width limit of neural networks with a standard parameterization

Jascha Sohl-Dickstein, Roman Novak, Samuel S. Schoenholz, Jaehoon Lee

TL;DR

The paper critiques the two common infinite-width parameterizations: the naive standard parameterization yields a divergent NTK, while the NTK parameterization fails to reflect finite-width training dynamics. It introduces an improved standard parameterization that uses a width-scaling scheme to obtain a well-defined NTK while preserving the learning-rate scale and the influence of relative layer widths. Empirically, the resulting kernels often match the accuracy of NTK-parameterization kernels, correspond more closely to finite-width networks, and can outperform them when layer widths are tuned. The authors release code in the Neural Tangents library to enable broader adoption of this approach.

Abstract

There are currently two parameterizations used to derive fixed kernels corresponding to infinite-width neural networks: the NTK (Neural Tangent Kernel) parameterization and the naive standard parameterization. However, the extrapolation of both of these parameterizations to infinite width is problematic. The standard parameterization leads to a divergent neural tangent kernel, while the NTK parameterization fails to capture crucial aspects of finite-width networks, such as the dependence of training dynamics on relative layer widths, the relative training dynamics of weights and biases, and the overall learning-rate scale. Here we propose an improved extrapolation of the standard parameterization that preserves all of these properties as width is taken to infinity and yields a well-defined neural tangent kernel. We show experimentally that the resulting kernels typically achieve accuracy similar to that of kernels resulting from an NTK parameterization, but with better correspondence to the parameterization of typical finite-width networks. Additionally, with careful tuning of width parameters, the improved standard parameterization kernels can outperform those stemming from an NTK parameterization. We release code implementing this improved standard parameterization as part of the Neural Tangents library at https://github.com/google/neural-tangents.
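The divergence of the standard-parameterization NTK can be checked numerically on a toy model. The sketch below is a minimal illustration, not code from the paper or from Neural Tangents, and all function names in it are hypothetical. It assumes a bias-free two-layer ReLU network, draws one set of unit-Gaussian weights, and reads them either as NTK-parameterization parameters (with $\sigma_w/\sqrt{\text{fan-in}}$ multipliers in the forward pass) or as standard-parameterization parameters (with that scaling moved into the initialization). Both readings compute the same function, but each layer's empirical NTK contribution differs by a factor of fan-in$/\sigma_w^2$, so the readout term in the standard parameterization grows linearly with the hidden width $n$.

```python
import math
import random


def relu(z):
    return z if z > 0.0 else 0.0


def relu_grad(z):
    return 1.0 if z > 0.0 else 0.0


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def sample_base_weights(d, n, seed=0):
    """Unit-variance draws shared by both parameterizations (hypothetical helper)."""
    rng = random.Random(seed)
    W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    v = [rng.gauss(0.0, 1.0) for _ in range(n)]
    return W, v


def empirical_ntk(x1, x2, W, v, sigma_w, parameterization):
    """Per-layer empirical NTK (first_layer, readout) of a bias-free
    two-layer ReLU network, for one pair of inputs.

    'ntk':      trainable params are W, v ~ N(0, 1); the forward pass applies
                sigma_w / sqrt(fan_in) multipliers.
    'standard': trainable params are W' = sigma_w / sqrt(d) * W and
                v' = sigma_w / sqrt(n) * v; the forward pass has no multipliers.
    """
    d, n = len(x1), len(v)
    if parameterization == "ntk":
        c1, c2 = sigma_w / math.sqrt(d), sigma_w / math.sqrt(n)
        W_eff, v_eff = W, v
    elif parameterization == "standard":
        c1, c2 = 1.0, 1.0
        W_eff = [[sigma_w / math.sqrt(d) * w for w in row] for row in W]
        v_eff = [sigma_w / math.sqrt(n) * vi for vi in v]
    else:
        raise ValueError(f"unknown parameterization: {parameterization!r}")

    xx = dot(x1, x2)
    first, readout = 0.0, 0.0
    for row, vi in zip(W_eff, v_eff):
        z1, z2 = c1 * dot(row, x1), c1 * dot(row, x2)
        # df/dv_i = c2 * relu(z_i): readout-layer contribution.
        readout += c2 ** 2 * relu(z1) * relu(z2)
        # df/dW_ij = c2 * v_i * relu'(z_i) * c1 * x_j: first-layer contribution.
        first += (c2 * vi * c1) ** 2 * relu_grad(z1) * relu_grad(z2) * xx
    return first, readout


d, n, sigma_w = 3, 256, math.sqrt(2.0)
x1, x2 = [1.0, 0.5, -0.3], [0.8, 0.1, 0.4]
W, v = sample_base_weights(d, n)
ntk = empirical_ntk(x1, x2, W, v, sigma_w, "ntk")
std = empirical_ntk(x1, x2, W, v, sigma_w, "standard")

# Standard / NTK ratio per layer equals fan_in / sigma_w**2:
# d / sigma_w**2 for the first layer, n / sigma_w**2 for the readout layer.
print("first layer:", std[0] / ntk[0], "vs", d / sigma_w ** 2)
print("readout:    ", std[1] / ntk[1], "vs", n / sigma_w ** 2)
```

Because gradient-descent dynamics in the kernel regime depend on the product of the learning rate and the NTK, dividing the standard-parameterization learning rate by the width offsets this linear growth, which is consistent with the learning-rate rescaling by $\max(N^l)$ described for the standard parameterization in Figure 3.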

Paper Structure

This paper contains 4 sections, 3 equations, 3 figures, 2 tables.

Figures (3)

  • Figure 1: Infinite-width networks with various architectures achieve similar error under the improved standard parameterization and the NTK parameterization, while the improved standard parameterization better matches properties of typical finite-width networks. Each point compares the neural tangent kernel prediction error on CIFAR-10 for the same architecture under the NTK (x-axis) and improved standard (y-axis) parameterizations. (Upper) Points vary the training set size ($\{80, 160, 400, 800, 2000, 4000, 8000\}$), depth ($\{1, 2, 4, 8, 16\}$ for FC / Conv; fixed at 4 blocks for WRN), and width ($\{2^k \mid k=0,\dots,13\}$ for FC / Conv; widening factor $\{2^k \mid k=-4,\dots,2\}\cup \{10, 16, 64, 256\}$ for WRN). FC is a fully connected network with constant hidden width, and Conv-Vec / Conv-GAP are constant-channel convolutional neural networks without / with global average pooling. WRN-LN is a Wide Residual Network with four residual blocks in which the Batch Normalization layers are replaced with Layer Normalization. (Lower) Each layer width of the fully connected architecture is randomly sampled from $2^k$ with $k\in\{3,\dots,13\}$.
  • Figure 2: For fully connected networks, the neural tangent kernel prediction under the improved standard parameterization can outperform the NTK parameterization, especially when the layer widths $N^l$ used in the standard parameterization are tuned. Experiments are performed on the CIFAR-10 dataset with networks of 5 hidden layers.
  • Figure 3: SGD-trained finite-width neural networks perform similarly under the standard parameterization and the NTK parameterization. For all experiments, the network was trained with an MSE loss on the full CIFAR-10 dataset (45k/5k/10k split). Each point in FC corresponds to a width in $\{2^k \mid k=0,\dots,12\}$, and each point in Conv-Vec and Conv-GAP corresponds to a number of channels in {8, 11, 16, 23, 32, 45, 64, 90, 128, 181, 256, 362, 512}. All networks are ReLU networks with $\sigma_w^2 = 2.0$, $\sigma_b^2=0.0$, trained with vanilla SGD without L2 regularization or data augmentation. A constant learning rate was grid-searched over 20 log-spaced values in [0.01, 100]; for the standard parameterization, the learning rate is divided by $\max(N^l)$. FC networks were trained with batch size 1024 for 3,000 epochs, whereas Conv networks were trained with batch size 256 for 10,000 epochs.
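The learning-rate protocol in the Figure 3 caption can be sketched as follows. This is a minimal illustration under stated assumptions: the 20 log-spaced values are assumed to include both endpoints of [0.01, 100], and $\max(N^l)$ is taken over the list of layer widths passed in. The helper names `learning_rate_grid` and `effective_learning_rate` are hypothetical and are not part of Neural Tangents.

```python
import math


def learning_rate_grid(num=20, low=0.01, high=100.0):
    """Log-spaced learning rates from low to high (endpoints included: an assumption)."""
    log_lo, log_hi = math.log10(low), math.log10(high)
    step = (log_hi - log_lo) / (num - 1)
    return [10 ** (log_lo + i * step) for i in range(num)]


def effective_learning_rate(lr, layer_widths, parameterization):
    """Rescale a base learning rate as described for Figure 3 (hypothetical helper).

    Standard parameterization: divide by the largest layer width max(N^l).
    NTK parameterization: use the base rate unchanged.
    """
    if parameterization == "standard":
        return lr / max(layer_widths)
    if parameterization == "ntk":
        return lr
    raise ValueError(f"unknown parameterization: {parameterization!r}")


grid = learning_rate_grid()
widths = [512, 512, 512]  # hypothetical hidden widths N^l of an FC network
for lr in grid[:3]:
    print(f"base {lr:.4g} -> standard {effective_learning_rate(lr, widths, 'standard'):.4g}")
```

Keeping the grid and the width rescaling separate means the same base grid can be searched for both parameterizations, with only the standard-parameterization runs divided by the largest width.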