Scalable Neural Network Kernels

Arijit Sehanobish, Krzysztof Choromanski, Yunfan Zhao, Avinava Dubey, Valerii Likhosherstov

Abstract

We introduce the concept of scalable neural network kernels (SNNKs), replacements for regular feedforward layers (FFLs) that are capable of approximating the latter but have favorable computational properties. SNNKs effectively disentangle the inputs from the parameters of the neural network in the FFL, connecting them only in the final computation via the dot-product kernel. They are also strictly more expressive, since they can model complicated relationships beyond functions of the dot-products of parameter-input vectors. We also introduce the neural network bundling process, which applies SNNKs to compactify deep neural network architectures, resulting in additional compression gains. In its extreme version, it leads to the fully bundled network, whose optimal parameters can be expressed via explicit formulae for several loss functions (e.g. mean squared error), opening the possibility of bypassing backpropagation. As a by-product of our analysis, we introduce the mechanism of universal random features (URFs), which we apply to instantiate several SNNK variants and which is interesting in its own right in the context of scalable kernel methods. We provide a rigorous theoretical analysis of all these concepts as well as an extensive empirical evaluation, ranging from point-wise kernel estimation to Transformer fine-tuning with novel adapter layers inspired by SNNKs. Our mechanism provides up to a 5x reduction in the number of trainable parameters while maintaining competitive accuracy.
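
To make the disentangling idea concrete, here is a minimal sketch for one special case: an FFL with activation $f = \exp$, approximated with the well-known positive random features for the exponential kernel, $\exp(\mathbf{w}^{\top}\mathbf{x}) = \mathbb{E}_{\boldsymbol{\omega}\sim\mathcal{N}(\mathbf{0},\mathbf{I}_d)}[\exp(\boldsymbol{\omega}^{\top}\mathbf{w} - \|\mathbf{w}\|_2^2/2)\exp(\boldsymbol{\omega}^{\top}\mathbf{x} - \|\mathbf{x}\|_2^2/2)]$. This construction is swapped in for illustration and is not the paper's URF mechanism; the helper `positive_features`, the choice of activation and all sizes are hypothetical.

```python
import numpy as np

def positive_features(A, omega):
    """Map each row a of A to exp(Omega a - |a|^2 / 2) / sqrt(m).

    For vectors w, x and omega ~ N(0, I): E[phi(w) . phi(x)] = exp(w . x),
    so the same map serves as both the parameter side and the input side.
    """
    m = omega.shape[0]
    half_sq_norms = 0.5 * np.sum(A * A, axis=1, keepdims=True)
    return np.exp(A @ omega.T - half_sq_norms) / np.sqrt(m)

rng = np.random.default_rng(0)
d, out_dim, n, m = 8, 3, 4, 20_000           # hypothetical sizes
W = 0.2 * rng.standard_normal((out_dim, d))  # FFL weights, one row per output unit
X = 0.2 * rng.standard_normal((n, d))        # a batch of inputs
omega = rng.standard_normal((m, d))          # shared Gaussian projection matrix

# Regular FFL with f = exp: weights and inputs meet inside the nonlinearity.
exact = np.exp(X @ W.T)                      # shape (n, out_dim)

# Disentangled version: parameter features do not depend on X, and the two
# sides are combined only through a dot product.
phi_W = positive_features(W, omega)          # shape (out_dim, m)
psi_X = positive_features(X, omega)          # shape (n, m)
approx = psi_X @ phi_W.T                     # shape (n, out_dim)

max_rel_err = np.max(np.abs(approx - exact) / exact)
print(f"max relative error: {max_rel_err:.4f}")
```

Because `phi_W` does not depend on the inputs, it can be computed once and cached; the weights then enter the forward pass only through a single matrix product with the input features.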

Paper Structure (55 sections, 1 theorem, 27 equations, 13 figures, 8 tables)

Key Result

Theorem 3.2

The $n$th-order arc-cosine kernel $\mathrm{K}_{n}:\mathbb{R}^{d} \times \mathbb{R}^{d} \rightarrow \mathbb{R}$ is defined as $\mathrm{K}_{n}(\mathbf{x},\mathbf{y})=\frac{1}{\pi}\|\mathbf{x}\|_{2}^{n}\|\mathbf{y}\|_{2}^{n}J_{n}(\alpha_{\mathbf{x},\mathbf{y}})$, where $\alpha_{\mathbf{x},\mathbf{y}}$ …
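
As background for this definition, Cho and Saul (2009) show that $\mathrm{K}_{n}(\mathbf{x},\mathbf{y}) = 2\,\mathbb{E}_{\boldsymbol{\omega}\sim\mathcal{N}(\mathbf{0},\mathbf{I}_d)}[\Theta(\boldsymbol{\omega}^{\top}\mathbf{x})\Theta(\boldsymbol{\omega}^{\top}\mathbf{y})(\boldsymbol{\omega}^{\top}\mathbf{x})^{n}(\boldsymbol{\omega}^{\top}\mathbf{y})^{n}]$, with $\Theta$ the Heaviside step function, $\alpha_{\mathbf{x},\mathbf{y}}$ the angle between $\mathbf{x}$ and $\mathbf{y}$, and $J_1(\alpha) = \sin\alpha + (\pi-\alpha)\cos\alpha$. For $n=1$ the expectation is over a product of ReLU outputs, so $p$ Gaussian projections give a random-feature estimate. The sketch below compares that estimate with the closed form and averages the relative error over $s$ draws, similar in spirit to Figure 2 (h). It is standard background, not the paper's Theorem 3.2; the function names, the dimension and the values of $p$ are hypothetical.

```python
import numpy as np

def arccos1_exact(x, y):
    """Closed-form K_1(x, y) = |x| |y| J_1(alpha) / pi, alpha = angle(x, y)."""
    nx, ny = np.linalg.norm(x), np.linalg.norm(y)
    cos_a = np.clip(x @ y / (nx * ny), -1.0, 1.0)
    alpha = np.arccos(cos_a)                      # angle in [0, pi]
    j1 = np.sin(alpha) + (np.pi - alpha) * cos_a
    return nx * ny * j1 / np.pi

def arccos1_rf(x, y, p, rng):
    """Estimate K_1 as Phi(x) . Phi(y) with Phi(v) = sqrt(2 / p) * ReLU(Omega v)."""
    omega = rng.standard_normal((p, x.shape[0]))  # p Gaussian projections
    phi_x = np.sqrt(2.0 / p) * np.maximum(omega @ x, 0.0)
    phi_y = np.sqrt(2.0 / p) * np.maximum(omega @ y, 0.0)
    return float(phi_x @ phi_y)

def mean_relative_error(x, y, p, s=500, seed=0):
    """Average |estimate - exact| / exact over s independent draws of Omega."""
    rng = np.random.default_rng(seed)
    exact = arccos1_exact(x, y)
    errs = [abs(arccos1_rf(x, y, p, rng) - exact) / exact for _ in range(s)]
    return float(np.mean(errs))

rng = np.random.default_rng(42)
w, x = rng.standard_normal(16), rng.standard_normal(16)  # one weight row, one input
ps = [10, 100, 1000, 10_000]
errors = {p: mean_relative_error(w, x, p) for p in ps}
for p, err in errors.items():
    print(f"p = {p:>6}: mean relative error {err:.4f}")
```

The error should shrink roughly like $1/\sqrt{p}$, the usual Monte Carlo rate.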

Figures (13)

  • Figure 1: Pictorial representation of the different NN layers discussed in the paper. Pink arrays represent NN weight matrices and grey ones represent the Gaussian projection matrices applied in SNNKs. Nonlinear transformations applied in mappings $\Phi$ and $\Psi$ are symbolically represented as functions $g$ and $h$, respectively. Upper left: regular FFL with activation $f$. Upper right: SNNK applied to a single FFL. Bottom: bundling process using SNNKs, applied to a deep neural network module.
  • Figure 2: Architecture for (a) the SNNK layer (see Section \ref{sec:motivation}), (b) the SNNK-Adpt layer, (c) the image fitting (SIREN), MNIST and UCI experiments, (d) the SNNK-QPNN model, (e) the SNNK-inspired Adapter-ViT layer, (f) the SNNK-inspired Adapter-BERT layer. (g, h): The relative error (averaged over $s=500$ instantiations of the RF mechanism) of the RF-based estimator on a particular entry of the output of (g) the SIREN-FFL and (h) the arc-cosine-FFL, as a function of the number of random projections $p$ (see Sec. \ref{sec:mse}). The maximum $p$ for (g) is larger than for (h), since (g) in theory has a larger variance per random projection. The corresponding standard deviations are negligible: (g): $5\cdot10^{-8}$, $10^{-12}$, $5\cdot10^{-8}$, $10^{-8}$, $10^{-12}$, $2.5\cdot10^{-9}$, $10^{-12}$, $5\cdot10^{-9}$, $10^{-12}$, $10^{-12}$, $10^{-10}$, $10^{-12}$; (h): $10^{-12}$, $3\cdot10^{-8}$, $3\cdot10^{-8}$, $2\cdot10^{-8}$, $10^{-12}$, $5\cdot10^{-9}$.
  • Figure 3: (1) Left column: injecting SNNK into a PINN to approximate the potential energy of the 2-body system. Top to bottom: ground-truth potential, potential learned by QPNN sehanobish2021learning, and potential learned by QPNN-SNNK. QPNN-SNNK can learn the potential function perfectly even with fewer trainable parameters than the baseline QPNN. (2) Rightmost three columns: a SIREN network (first row) fitting not only the image but also its gradients; SNNK (bottom row) produces an accurate approximation of the above.
  • Figure 4: Comparison of trainable parameters between various layers/modules and their drop-in replacement NNK layers. Results for CIFAR-10, CIFAR-100 and ImageNet are for SNNK-Adapter models.
  • Figure 5: Ablation over the number of random features for the ReLU-SNNK-adapter experiments on the GLUE dev set. $AA$ denotes the adaptable-adapter numbers reported in moosavi-etal-2022-adaptable.
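
Figures 2 (g) and 3 involve SIREN layers, which use the sine activation. One unbiased random-feature estimator for it (derived here from the Gaussian moment generating function, and not necessarily the construction used in the paper) follows from $\mathbb{E}_{\boldsymbol{\omega}\sim\mathcal{N}(\mathbf{0},\mathbf{I}_d)}[\exp(\boldsymbol{\omega}^{\top}(\mathbf{w} + i\mathbf{x}))] = \exp((\|\mathbf{w}\|_2^2 - \|\mathbf{x}\|_2^2)/2 + i\,\mathbf{w}^{\top}\mathbf{x})$; taking imaginary parts gives $\sin(\mathbf{w}^{\top}\mathbf{x}) = \mathbb{E}_{\boldsymbol{\omega}}[e^{\boldsymbol{\omega}^{\top}\mathbf{w} - \|\mathbf{w}\|_2^2/2}\, e^{\|\mathbf{x}\|_2^2/2}\sin(\boldsymbol{\omega}^{\top}\mathbf{x})]$. The sketch below checks this numerically; the helper names and sizes are hypothetical and the bias term is omitted.

```python
import numpy as np

def weight_features(W, omega):
    """Parameter side: exp(Omega w - |w|^2 / 2) / sqrt(m) for each row w of W."""
    m = omega.shape[0]
    half_sq = 0.5 * np.sum(W * W, axis=1, keepdims=True)
    return np.exp(W @ omega.T - half_sq) / np.sqrt(m)

def sine_input_features(X, omega):
    """Input side: exp(|x|^2 / 2) * sin(Omega x) / sqrt(m) for each row x of X."""
    m = omega.shape[0]
    half_sq = 0.5 * np.sum(X * X, axis=1, keepdims=True)
    return np.exp(half_sq) * np.sin(X @ omega.T) / np.sqrt(m)

rng = np.random.default_rng(1)
d, out_dim, n, m = 8, 3, 4, 50_000           # hypothetical sizes
W = 0.3 * rng.standard_normal((out_dim, d))  # sine-layer weights, one row per unit
X = 0.3 * rng.standard_normal((n, d))        # a batch of inputs
omega = rng.standard_normal((m, d))          # shared Gaussian projection matrix

exact = np.sin(X @ W.T)                      # sine FFL without bias, shape (n, out_dim)
approx = sine_input_features(X, omega) @ weight_features(W, omega).T
max_abs_err = np.max(np.abs(approx - exact))
print(f"max absolute error: {max_abs_err:.4f}")
```

A bias $b$ can be folded in by appending a constant $1$ to $\mathbf{x}$ and $b$ to $\mathbf{w}$. The input-side factor $e^{\|\mathbf{x}\|_2^2/2}$ makes the variance of this particular estimator grow quickly with the input norm.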

Theorems & Definitions (3)

  • Remark 3.1: boundedness
  • Theorem 3.2: arc-cosine kernels
  • Remark 3.3