Rapid training of deep neural networks without skip connections or normalization layers using Deep Kernel Shaping

James Martens, Andy Ballard, Guillaume Desjardins, Grzegorz Swirszcz, Valentin Dalibard, Jascha Sohl-Dickstein, Samuel S. Schoenholz

TL;DR

This work introduces Deep Kernel Shaping (DKS), a theoretically grounded method for training very deep neural networks without skip connections or normalization layers. DKS shapes the network's initialization-time kernel, which it analyzes through Q/C maps and the Neural Tangent Kernel (NTK). By transforming activation functions and carefully initializing weights (Delta and Orthogonal Delta initializations, SUO), DKS enforces a small set of conditions on the maps of its subnetworks (e.g., Q_f(1)=1, Q'_f(1)=1, C_f(0)=0, C'_f(1)≤ζ), and it uses Per-Location Normalization to achieve uniform q-values, thereby preserving trainability. The authors derive both local and extended Q/C maps, provide formulas for their derivatives, and connect these to network-wide PKFs and NTK behavior, showing that well-behaved C maps lead to favorable NTK spectra and faster optimization. Empirically, DKS enables networks without skip connections or batch normalization to train on ImageNet and CIFAR-10 at speeds comparable to standard ResNets when using strong optimizers (K-FAC/Shampoo), and the activation function transformations make the method applicable to a wide range of nonlinearities. Some generalization gaps remain, but ablations suggest that the core components are crucial for the speed gains. The framework also helps explain why normalization layers, skip connections, and particular activation functions work, and it positions DKS as a potential general benchmarking approach for optimizer design and training stability in deep networks.
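The local Q/C-map quantities named in the TL;DR can be computed directly for the simplest subnetwork: a single fully-connected layer with Gaussian fan-in initialization in the wide limit, whose maps depend only on the activation function φ. The pure-Python sketch below is illustrative and not the authors' code. It assumes unit-q inputs and uses the standard identities Q(q) = E[φ(√q z)²] for z ~ N(0, 1), C(c) = E[φ(u)φ(v)]/Q(1) for unit-variance u, v with correlation c, and C'(1) = E[φ'(z)²]/Q(1). The helper names are hypothetical, and the Gaussian expectations are approximated with Simpson's rule rather than exact quadrature.

```python
import math


def gauss_expect(f, n=400, lim=8.0):
    """Approximate E[f(z)] for z ~ N(0, 1) with composite Simpson's rule on [-lim, lim]."""
    h = 2.0 * lim / n  # n must be even for Simpson's rule
    total = 0.0
    for i in range(n + 1):
        z = -lim + i * h
        w = 1.0 if i in (0, n) else (4.0 if i % 2 else 2.0)
        total += w * f(z) * math.exp(-0.5 * z * z)
    return total * h / (3.0 * math.sqrt(2.0 * math.pi))


def q_map(phi, q):
    """Local Q map of one layer: Q(q) = E[phi(sqrt(q) z)^2]."""
    return gauss_expect(lambda z: phi(math.sqrt(q) * z) ** 2)


def c_map(phi, c):
    """Local C map for unit-q inputs: E[phi(u) phi(v)] / Q(1), where corr(u, v) = c."""
    s = math.sqrt(max(0.0, 1.0 - c * c))
    # u = z1 and v = c * z1 + s * z2 both have unit variance and correlation c.
    joint = gauss_expect(
        lambda z1: phi(z1) * gauss_expect(lambda z2: phi(c * z1 + s * z2)))
    return joint / q_map(phi, 1.0)


def local_conditions(phi, dphi):
    """Q(1), Q'(1), C(0) and C'(1) for a single layer with activation phi (derivative dphi)."""
    q1 = q_map(phi, 1.0)
    return {
        "Q(1)": q1,
        # d/dq E[phi(sqrt(q) z)^2] evaluated at q = 1
        "Q'(1)": gauss_expect(lambda z: phi(z) * dphi(z) * z),
        "C(0)": c_map(phi, 0.0),
        # dC/dc = E[phi'(u) phi'(v)] / Q(1), which at c = 1 becomes E[phi'(z)^2] / Q(1)
        "C'(1)": gauss_expect(lambda z: dphi(z) ** 2) / q1,
    }


relu = lambda x: max(x, 0.0)
identity = lambda x: x

print(q_map(relu, 1.0))   # close to the analytic value 1/2
print(c_map(relu, 0.0))   # close to the analytic value 1/pi, so ReLU violates C(0) = 0
print(local_conditions(identity, lambda x: 1.0))  # analytically Q(1) = Q'(1) = C'(1) = 1, C(0) = 0
```

Under the unscaled fan-in initialization assumed here, an unmodified ReLU layer fails both Q(1) = 1 (it gives 1/2) and C(0) = 0 (it gives 1/π ≈ 0.318). The identity map satisfies all four conditions, but only because it is linear.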

Abstract

Using an extended and formalized version of the Q/C map analysis of Poole et al. (2016), along with Neural Tangent Kernel theory, we identify the main pathologies present in deep networks that prevent them from training fast and generalizing to unseen data, and show how these can be avoided by carefully controlling the "shape" of the network's initialization-time kernel function. We then develop a method called Deep Kernel Shaping (DKS), which accomplishes this using a combination of precise parameter initialization, activation function transformations, and small architectural tweaks, all of which preserve the model class. In our experiments we show that DKS enables SGD training of residual networks without normalization layers on ImageNet and CIFAR-10 classification tasks at speeds comparable to standard ResNetV2 and Wide-ResNet models, with only a small decrease in generalization performance. When using K-FAC as the optimizer, we achieve similar results for networks without skip connections. Our results apply to a large variety of activation functions, including those that traditionally perform very badly, such as the logistic sigmoid. In addition to DKS, we contribute a detailed analysis of skip connections, normalization layers, special activation functions like ReLU and SELU, and various initialization schemes, explaining their effectiveness as alternative (and ultimately incomplete) ways of "shaping" the network's initialization-time kernel.
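To illustrate how activation function transformations can enforce kernel conditions, here is a minimal sketch. It assumes a transformation of the form φ̂(x) = γ(φ(αx + β) + δ) with β fixed at 0, chosen for illustration and not necessarily the paper's exact procedure. Here δ centers the activation so that C(0) = E[φ̂(z)]² = 0, γ rescales it so that Q(1) = E[φ̂(z)²] = 1, and α is found by bisection. The bisection target assumes a plain chain of D identical layers, for which C_f'(1) = C'(1)^D, set equal to a global value ζ. The names transform, solve_alpha and c_slope_at_1, and the values D = 10 and ζ = 1.5, are hypothetical.

```python
import math


def gauss_expect(f, n=2000, lim=10.0):
    """Approximate E[f(z)] for z ~ N(0, 1) with composite Simpson's rule on [-lim, lim]."""
    h = 2.0 * lim / n  # n must be even for Simpson's rule
    total = 0.0
    for i in range(n + 1):
        z = -lim + i * h
        w = 1.0 if i in (0, n) else (4.0 if i % 2 else 2.0)
        total += w * f(z) * math.exp(-0.5 * z * z)
    return total * h / (3.0 * math.sqrt(2.0 * math.pi))


def transform(phi, dphi, alpha, beta=0.0):
    """Hypothetical helper: build phi_hat(x) = gamma * (phi(alpha * x + beta) + delta).

    delta centers phi_hat, so C_hat(0) = E[phi_hat(z)]^2 = 0; gamma gives it unit
    second moment, so Q_hat(1) = E[phi_hat(z)^2] = 1.
    """
    delta = -gauss_expect(lambda z: phi(alpha * z + beta))
    gamma = 1.0 / math.sqrt(gauss_expect(lambda z: (phi(alpha * z + beta) + delta) ** 2))
    phi_hat = lambda x: gamma * (phi(alpha * x + beta) + delta)
    dphi_hat = lambda x: gamma * alpha * dphi(alpha * x + beta)
    return phi_hat, dphi_hat


def c_slope_at_1(dphi_hat):
    """C_hat'(1) = E[phi_hat'(z)^2], valid once Q_hat(1) = 1."""
    return gauss_expect(lambda z: dphi_hat(z) ** 2)


def solve_alpha(phi, dphi, target, lo=1e-2, hi=10.0, iters=40):
    """Bisect on alpha (beta fixed at 0) until the local C_hat'(1) matches target."""
    slope = lambda a: c_slope_at_1(transform(phi, dphi, a)[1])
    if not slope(lo) < target < slope(hi):
        raise ValueError("target slope is not bracketed by [lo, hi]")
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if slope(mid) < target:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


sigmoid = lambda x: 1.0 / (1.0 + math.exp(-x))
dsigmoid = lambda x: sigmoid(x) * (1.0 - sigmoid(x))

depth, zeta = 10, 1.5               # hypothetical network depth D and global target zeta
target = zeta ** (1.0 / depth)      # C_f'(1) = C'(1)^D for a chain of D identical layers
alpha = solve_alpha(sigmoid, dsigmoid, target)
phi_hat, dphi_hat = transform(sigmoid, dsigmoid, alpha)
print(alpha, c_slope_at_1(dphi_hat))
```

Bisection needs the target to lie between the slopes at the ends of the bracket. The lower end is safe for any target above 1. For a centered activation with unit second moment, the Gaussian Poincaré inequality gives C'(1) = E[φ̂'(z)²] ≥ Var(φ̂(z)) = 1, with equality only for linear functions. A nearly linear transformation (small α) therefore sits just above 1.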

Paper Structure

This paper contains 212 sections, 28 theorems, and 177 equations.

Key Result

Theorem 1

Suppose that $f$ is a network containing only fully-connected combined layers and concatenation operations, where the combined layers are initialized independently of each other with a standard Gaussian fan-in initialization and all use the same activation function $\phi$. Suppose further that $\phi$ is twice … where $D$ is the maximum number of nonlinear layers in any input-output path through the network (i.e. …
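The statement above is truncated. Kernel-approximation results of this type say that, for sufficiently wide networks, the empirical kernel at initialization concentrates around the prediction obtained by composing the layers' Q/C maps. The sketch below illustrates that idea for a single random draw of a network: it compares the empirical correlation between two inputs' representations against the composed C map. It makes several assumptions not taken from the theorem: a plain chain of square fully-connected layers without biases, a tanh activation rescaled so that Q(1) = 1, inputs with q-values of exactly 1, and width 512 with depth 3. The helper names c_map, make_inputs and forward_corrs are hypothetical.

```python
import math
import random


def gauss_expect(f, n=200, lim=8.0):
    """Approximate E[f(z)] for z ~ N(0, 1) with composite Simpson's rule on [-lim, lim]."""
    h = 2.0 * lim / n  # n must be even for Simpson's rule
    total = 0.0
    for i in range(n + 1):
        z = -lim + i * h
        w = 1.0 if i in (0, n) else (4.0 if i % 2 else 2.0)
        total += w * f(z) * math.exp(-0.5 * z * z)
    return total * h / (3.0 * math.sqrt(2.0 * math.pi))


# Illustrative activation: tanh rescaled so that Q(1) = E[phi(z)^2] = 1, keeping q-values near 1.
GAMMA = 1.0 / math.sqrt(gauss_expect(lambda z: math.tanh(z) ** 2))


def phi(x):
    return GAMMA * math.tanh(x)


def c_map(c):
    """Predicted C map at q = 1: E[phi(u) phi(v)] with unit variances and corr(u, v) = c."""
    s = math.sqrt(max(0.0, 1.0 - c * c))
    return gauss_expect(lambda z1: phi(z1) * gauss_expect(lambda z2: phi(c * z1 + s * z2)))


def rescale(v):
    """Scale v so that its q-value |v|^2 / len(v) equals 1."""
    k = math.sqrt(len(v) / sum(a * a for a in v))
    return [a * k for a in v]


def empirical_corr(x1, x2):
    """Empirical kernel normalized by the q-values: <x1, x2> / sqrt(|x1|^2 |x2|^2)."""
    k12 = sum(a * b for a, b in zip(x1, x2))
    return k12 / math.sqrt(sum(a * a for a in x1) * sum(b * b for b in x2))


def make_inputs(n, c0, rng):
    """Two inputs with q-values 1 and correlation exactly c0."""
    x1 = [rng.gauss(0.0, 1.0) for _ in range(n)]
    y = [rng.gauss(0.0, 1.0) for _ in range(n)]
    proj = sum(a * b for a, b in zip(x1, y)) / sum(a * a for a in x1)
    # Gram-Schmidt makes y orthogonal to x1; rescaling preserves orthogonality.
    x1, y = rescale(x1), rescale([b - proj * a for a, b in zip(x1, y)])
    s = math.sqrt(1.0 - c0 * c0)
    return x1, [c0 * a + s * b for a, b in zip(x1, y)]


def forward_corrs(x1, x2, depth, rng):
    """Propagate both inputs through one shared random network and record the empirical
    correlation after each layer. Weights are W_ij ~ N(0, 1/fan_in) (Gaussian fan-in
    initialization); biases are omitted and every layer is square."""
    n = len(x1)
    std = 1.0 / math.sqrt(n)
    corrs = []
    for _layer in range(depth):
        h1, h2 = [], []
        for _row in range(n):
            row = [rng.gauss(0.0, std) for _ in range(n)]  # same weights for both inputs
            h1.append(sum(w * a for w, a in zip(row, x1)))
            h2.append(sum(w * b for w, b in zip(row, x2)))
        x1, x2 = [phi(a) for a in h1], [phi(b) for b in h2]
        corrs.append(empirical_corr(x1, x2))
    return corrs


rng = random.Random(0)
width, depth, c0 = 512, 3, 0.5   # hypothetical settings for the illustration
x1, x2 = make_inputs(width, c0, rng)
empirical = forward_corrs(x1, x2, depth, rng)

predicted, c = [], c0
for _ in range(depth):
    c = c_map(c)                 # compose the local C map once per layer
    predicted.append(c)

for e, p in zip(empirical, predicted):
    print(f"empirical {e:.3f}   predicted {p:.3f}")
```

Increasing the width should typically shrink the gap between the two columns, roughly like one over the square root of the width. Depth matters because the error introduced at each layer is passed through, and can be amplified by, the C maps of the layers that follow.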

Theorems & Definitions (43)

  • Theorem 1: Adapted from Theorem 2 of Daniely et al. (2016)
  • Remark 2
  • Remark 3
  • Remark 4
  • Remark 5
  • Theorem 6: Adapted from Theorem 2 of Martens (2021)
  • Remark 7
  • Remark 8
  • Remark 9
  • Proposition 10
  • ...and 33 more