Rapid training of deep neural networks without skip connections or normalization layers using Deep Kernel Shaping
James Martens, Andy Ballard, Guillaume Desjardins, Grzegorz Swirszcz, Valentin Dalibard, Jascha Sohl-Dickstein, Samuel S. Schoenholz
TL;DR
This work introduces Deep Kernel Shaping (DKS), a theoretically grounded method for training very deep neural networks without skip connections or normalization layers. DKS works by shaping the network's initialization-time kernel, which it analyzes through Q/C maps and the Neural Tangent Kernel (NTK). By transforming activation functions and carefully initializing weights (Delta and Orthogonal Delta initialization, SUO), DKS enforces a small set of conditions on the maps of each subnetwork (e.g., Q_f(1)=1, Q'_f(1)=1, C_f(0)=0, C'_f(1)≤ζ) and uses Per-Location Normalization to achieve uniform q-values, thereby preserving trainability. The authors derive both local and extended Q/C maps, provide formulas for their derivatives, and connect these to network-wide PKFs and NTK behavior, showing that well-behaved C maps lead to favorable NTK spectra and faster optimization. Empirically, DKS enables networks with neither skip connections nor batch normalization to train on ImageNet and CIFAR-10 at speeds comparable to standard ResNets when using strong optimizers (K-FAC/Shampoo), and its activation function transformations extend it to a wide range of nonlinearities. Some generalization gaps remain, but ablations suggest that its core components are crucial for the speed gains. The framework also helps explain why normalization layers, skip connections, and particular activation functions work, and positions DKS as a potential general benchmark for optimizer design and training stability in deep networks.
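The Q and C maps mentioned above can be made concrete with a small numerical sketch. It assumes the standard single-layer definitions in the style of Poole et al. (2016), with inputs of unit q-value: Q(q) = E[φ(√q z)²] and C(c) = E[φ(u)φ(v)] / Q(1) for standard normal u, v with correlation c, both estimated here by Gauss-Hermite quadrature. It also uses the identity C'(1) = E[φ'(z)²] / Q(1) (Gaussian integration by parts) and the chain-rule fact that a stack of D identical layers with C(1) = 1 has C_f'(1) = C'(1)^D. The helper names (`q_map`, `c_map`, `c_slope_at_one`, `standardize`) are hypothetical, and `standardize` only enforces Q(1) = 1 and C(0) = 0. It is not the DKS transformation, which also controls Q'(1) and C'_f(1).

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

# Probabilists' Gauss-Hermite rule; weights rescaled so that
# expect(g) approximates E[g(z)] for z ~ N(0, 1).
_NODES, _WEIGHTS = hermegauss(80)
_WEIGHTS = _WEIGHTS / np.sqrt(2.0 * np.pi)


def expect(g):
    """E[g(z)] for z ~ N(0, 1), by quadrature."""
    return float(np.sum(_WEIGHTS * g(_NODES)))


def q_map(phi, q):
    """Local Q map: q-value of phi's output for an input with q-value q."""
    return expect(lambda z: phi(np.sqrt(q) * z) ** 2)


def c_map(phi, c):
    """Local C map for unit q-value inputs whose correlation is c."""
    z1, z2 = np.meshgrid(_NODES, _NODES, indexing="ij")
    w = np.outer(_WEIGHTS, _WEIGHTS)
    u, v = z1, c * z1 + np.sqrt(1.0 - c ** 2) * z2  # corr(u, v) = c
    return float(np.sum(w * phi(u) * phi(v))) / q_map(phi, 1.0)


def c_slope_at_one(phi, dphi):
    """C'(1) = E[phi'(z)^2] / Q(1), via Gaussian integration by parts."""
    return expect(lambda z: dphi(z) ** 2) / q_map(phi, 1.0)


def standardize(phi, dphi):
    """Hypothetical helper: shift and scale phi so that E[phi(z)] = 0 and
    E[phi(z)^2] = 1, i.e. C(0) = 0 and Q(1) = 1. This covers only two of
    the DKS conditions; it does not control Q'(1) or C'(1)."""
    mean = expect(phi)
    scale = 1.0 / np.sqrt(expect(lambda z: (phi(z) - mean) ** 2))
    return (lambda x: scale * (phi(x) - mean)), (lambda x: scale * dphi(x))


phi, dphi = standardize(np.tanh, lambda x: 1.0 - np.tanh(x) ** 2)
slope = c_slope_at_one(phi, dphi)
# For a depth-D stack with C(1) = 1, the chain rule gives C_f'(1) = C'(1) ** D.
for depth in (1, 10, 100):
    print(depth, slope ** depth)
```

With the standardized `tanh`, C'(1) is strictly greater than 1 (as it is for any standardized nonlinear activation), so the printed C_f'(1) grows exponentially with depth. A bound like C'_f(1) ≤ ζ is meant to rule out this kind of behavior.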
Abstract
Using an extended and formalized version of the Q/C map analysis of Poole et al. (2016), along with Neural Tangent Kernel theory, we identify the main pathologies present in deep networks that prevent them from training fast and generalizing to unseen data, and show how these can be avoided by carefully controlling the "shape" of the network's initialization-time kernel function. We then develop a method called Deep Kernel Shaping (DKS), which accomplishes this using a combination of precise parameter initialization, activation function transformations, and small architectural tweaks, all of which preserve the model class. In our experiments we show that DKS enables SGD training of residual networks without normalization layers on ImageNet and CIFAR-10 classification tasks at speeds comparable to standard ResNetV2 and Wide-ResNet models, with only a small decrease in generalization performance. And when using K-FAC as the optimizer, we achieve similar results for networks without skip connections. Our results apply for a large variety of activation functions, including those which traditionally perform very badly, such as the logistic sigmoid. In addition to DKS, we contribute a detailed analysis of skip connections, normalization layers, special activation functions like ReLU and SELU, and various initialization schemes, explaining their effectiveness as alternative (and ultimately incomplete) ways of "shaping" the network's initialization-time kernel.
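To illustrate why an activation like the logistic sigmoid can behave badly in the Q/C map picture, the sketch below computes the one-layer C map value C(0) = E[φ(z)]² / E[φ(z)²]. It assumes unit q-value inputs and the uncentered (cosine-similarity) C map; a value near 1 means that orthogonal inputs come out nearly parallel. The additive shift `delta` is a hypothetical, partial stand-in for DKS's activation function transformations (it only enforces C(0) = 0), and the quadrature setup is likewise an assumption of this sketch.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

# Gauss-Hermite rule rescaled so that expect(g) approximates E[g(z)], z ~ N(0, 1).
nodes, weights = hermegauss(80)
weights = weights / np.sqrt(2.0 * np.pi)


def expect(g):
    return float(np.sum(weights * g(nodes)))


def c_at_zero(phi):
    """One-layer C(0) for unit q-value inputs: E[phi(z)]^2 / E[phi(z)^2]."""
    return expect(phi) ** 2 / expect(lambda z: phi(z) ** 2)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


# Hypothetical shift that centres the sigmoid's output under z ~ N(0, 1).
delta = -expect(sigmoid)


def centred(x):
    return sigmoid(x) + delta


print(c_at_zero(sigmoid))  # well above 0: orthogonal inputs become highly correlated
print(c_at_zero(centred))  # approximately 0 after centring
```

For the raw sigmoid, E[σ(z)] = 1/2 by symmetry, and since 0 < σ(x) < 1 we have E[σ(z)²] < E[σ(z)] = 1/2, so C(0) > 1/2 already after a single layer.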
