From Kernels to Features: A Multi-Scale Adaptive Theory of Feature Learning

Noa Rubin, Kirsten Fischer, Javed Lindner, David Dahmen, Inbar Seroussi, Zohar Ringel, Michael Krämer, Moritz Helias

TL;DR

This work presents a theoretical framework of multi-scale adaptive feature learning that bridges two views of feature learning, kernel rescaling and kernel adaptation. It finds that, across scaling regimes, kernel adaptation reduces to an effective kernel rescaling when predicting the mean network output in the special case of a linear network.

Abstract

Feature learning in neural networks is crucial for their expressive power and inductive biases, motivating various theoretical approaches. Some approaches describe network behavior after training through a change in kernel scale from initialization, resulting in a generalization power comparable to a Gaussian process. In other approaches, by contrast, training adapts the kernel to the data, which involves directional changes to the kernel. The relationship and respective strengths of these two views have so far remained unresolved. This work presents a theoretical framework of multi-scale adaptive feature learning bridging these two views. Using methods from statistical mechanics, we derive analytical expressions for network output statistics that are valid across scaling regimes and in the continuum between them. A systematic expansion of the network's probability distribution reveals that mean-field scaling requires only a saddle-point approximation, while standard scaling necessitates additional correction terms. Remarkably, we find across regimes that kernel adaptation can be reduced to an effective kernel rescaling when predicting the mean network output in the special case of a linear network. However, for both linear and non-linear networks, the multi-scale adaptive approach captures directional feature learning effects, providing richer insights than a rescaling of the kernel alone could provide.
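
The rescaling statement can be illustrated with a small numerical sketch. It assumes the NNGP kernel of a two-layer linear network has the form K = g_v g_w X Xᵀ / D (an assumption; the paper's normalization, including the scaling-dependent readout variance, may differ), treats κ0 as the ridge (output-noise) parameter of the Gaussian-process posterior mean, and uses a hypothetical linear teacher with sign labels as a stand-in for a linearly separable task. Sizes and g_v, g_w, κ0 are borrowed from the Figure 1 caption. The scalar Q is a hypothetical value, not the self-consistently determined rescaling factor of any particular theory.

```python
import numpy as np

def linear_nngp_kernel(XA, XB, g_v, g_w):
    """Assumed NNGP kernel of a two-layer linear network: g_v * g_w * XA XB^T / D."""
    D = XA.shape[1]
    return g_v * g_w * XA @ XB.T / D

def gp_mean(K_train, K_test_train, y, kappa):
    """GP posterior mean on training and test points with ridge (noise) kappa."""
    P = K_train.shape[0]
    alpha = np.linalg.solve(K_train + kappa * np.eye(P), y)
    return K_train @ alpha, K_test_train @ alpha

rng = np.random.default_rng(0)
P, P_test, D = 80, 1000, 200          # sizes borrowed from the Figure 1 caption
g_v = g_w = 0.5
kappa0 = 1.0
X = rng.standard_normal((P, D))
X_test = rng.standard_normal((P_test, D))
w_star = rng.standard_normal(D)       # hypothetical linear teacher direction
y = np.sign(X @ w_star)               # hypothetical linearly separable labels
y_test = np.sign(X_test @ w_star)

K = linear_nngp_kernel(X, X, g_v, g_w)
K_star = linear_nngp_kernel(X_test, X, g_v, g_w)

# NNGP baseline discrepancies: Delta = y - <f_D> = kappa0 (K + kappa0 I)^{-1} y
f_train, f_test = gp_mean(K, K_star, y, kappa0)
delta = y - f_train
delta_test = y_test - f_test
assert np.allclose(delta, kappa0 * np.linalg.solve(K + kappa0 * np.eye(P), y))

# Scalar rescaling K -> Q K (Q is hypothetical, not solved self-consistently)
Q = 3.0
f_train_Q, f_test_Q = gp_mean(Q * K, Q * K_star, y, kappa0)

# Identical to the unscaled kernel with an effective ridge kappa0 / Q
f_train_eff, f_test_eff = gp_mean(K, K_star, y, kappa0 / Q)
assert np.allclose(f_test_Q, f_test_eff)
```

The equivalence follows from QK + κ0 I = Q (K + (κ0/Q) I): a scalar rescaling changes the mean prediction only through the effective ridge and leaves the eigenvectors of the kernel unchanged, so by construction it cannot express directional changes to the kernel.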

Paper Structure

This paper contains 36 sections, 123 equations, 12 figures, 1 algorithm.

Figures (12)

  • Figure 1: (a) The multi-scale adaptive theory bridges rescaling and adaptive theories of feature learning. Starting from the distribution of network outputs for trained networks, the choice of order parameter determines whether a rescaling (red) or an adaptive (blue) theory is obtained, recasting feature learning as either (i) a low-dimensional or (ii) a high-dimensional minimization problem. In certain limits, an approximation of the multi-scale adaptive theory recovers the result of the rescaling approach, while additionally describing (iii) directional aspects of feature learning. (b) Training errors (solid lines) and test errors (dashed lines) across scaling regimes for different approaches. While standard scaling (green shaded area) requires a one-loop approximation with fluctuation corrections (Fluct. Corr.), a saddle-point or tree-level approximation (Saddle-Point) suffices in mean-field scaling (orange shaded area). Results of the kernel rescaling theory by Li21_031059 are shown as a reference (Rescaling). Results are shown for a linearly separable task; for results on MNIST, see fig:mnist_transition in app:add_figures. Parameters: $\gamma=1$, $P_{\text{train}}=80$, $N=100$, $D=200$, $\kappa_{0}=1$, $P_{\text{test}}=10^{3}$, $g_{v}=g_{w}=0.5$, $\Delta p=0.1$.
  • Figure 2: (a) Training discrepancies $\langle\Delta\rangle=y-\langle f_{\mathcal{D}}\rangle$ and (b) test discrepancies $\langle\Delta_{\ast}\rangle=y_{\ast}-\langle f_{\ast}\rangle$ on an Ising task in mean-field scaling. We show theoretical values for both NNGP and tree-level against empirical results, where the gray line marks the identity. In contrast to the NNGP, the tree-level approximation accurately matches the empirical values. While we use $\phi=\text{id}$ here, the non-linear case $\phi=\text{erf}$ yields similar results (see fig:nonlinear_scatter in app:add_figures). Parameters: $\gamma=2$, $P_{\text{train}}=80$, $N=100$, $D=200$, $\kappa_{0}=1$, $P_{\text{test}}=10^{3}$, $g_{v}=g_{w}=0.5$, $\Delta p=0.1$.
  • Figure 3: (a) Training discrepancies $\langle\Delta\rangle=y-\langle f_{\mathcal{D}}\rangle$ and (b) test discrepancies $\langle\Delta_{\ast}\rangle=y_{\ast}-\langle f_{\ast}\rangle$ on an Ising task in standard scaling. Upper row: theoretical values for different theories against empirical results; the gray line marks the identity. Lower row: difference between each theory's prediction and the NNGP baseline, plotted against the NNGP prediction; the differences between the approaches are small. Results of the kernel rescaling approach by Li21_031059 are shown as a reference (LS). Parameters: $\gamma=1$, $P_{\text{train}}=80$, $N=100$, $D=200$, $\kappa_{0}=0.4$, $P_{\text{test}}=10^{3}$, $g_{v}=0.5$, $g_{w}=0.2$, $\Delta p=0.1$.
  • Figure 4: Relative directional feature learning in a teacher-student setting as a function of the fluctuation scale $1/\chi$. Both the NNGP and the rescaling theory fail to capture directional feature learning, whereas the multi-scale adaptive theory accurately predicts network behavior. Insets show the output distribution in different directions; a detailed version is shown in fig:teacher_student_appendix in the Appendix. Parameters: $P_{\text{train}}=80$, $N=200$, $D=50$, $\kappa_{0}=2$, $g_{v}=0.01$, $g_{w}=2$.
  • Figure 5: $g$-learnability of target components in a two-layer erf network trained on the target given in eq:teacher, where $H_{i}$-learnability is defined according to eq:learnability with $g(X)=H_{i}(Xw_{*})$. Panel (b) shows that the adaptive approach, as derived using VGA, predicts that the network begins to learn higher-order components of $y$ at $P\sim\mathcal{O}(D)$, owing to directionally dependent feature learning. Kernel methods such as the NNGP and the rescaling approach, on the other hand, predict a linear network output, so that learning the cubic component of the target would require $P\sim\mathcal{O}(D^{3})$ training samples. Parameters: $P_{\text{test}}=4000$, $N=1000$, $D=30$, $\kappa_{0}=2$, $g_{v}=g_{w}=1$, $\gamma=1$, $\epsilon=-0.1$.
  • ...and 7 more figures
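
Figure 5 decomposes the target into components along Hermite polynomials of the projection $Xw_{*}$. The sketch below shows such a decomposition under stated assumptions: $H_{i}$ is taken to be the probabilists' Hermite polynomial He_i (NumPy's `hermite_e` module; the paper's normalization may differ). The target He_1(z) + ε He_3(z) is hypothetical and only borrows ε = −0.1 and D = 30 from the caption; neither the paper's target (eq:teacher) nor its learnability measure (eq:learnability) is reproduced here. Population coefficients are computed by Gauss-Hermite quadrature, and the cubic coefficient is then checked empirically on Gaussian inputs.

```python
import math
import numpy as np
from numpy.polynomial import hermite_e as He

def hermite_coefficients(f, max_degree, n_nodes=20):
    """Coefficients c_k of f(z) = sum_k c_k He_k(z) for z ~ N(0, 1).

    Uses Gauss-Hermite quadrature for the weight exp(-z^2 / 2) and the
    orthogonality relation E[He_j(z) He_k(z)] = k! * delta_jk.
    """
    nodes, weights = He.hermegauss(n_nodes)
    weights = weights / np.sqrt(2 * np.pi)  # normalize to the standard Gaussian measure
    coeffs = []
    for k in range(max_degree + 1):
        basis = He.hermeval(nodes, [0] * k + [1])  # He_k evaluated at the nodes
        coeffs.append(np.sum(weights * f(nodes) * basis) / math.factorial(k))
    return np.array(coeffs)

# Hypothetical single-index target: linear part plus a small cubic component.
eps = -0.1

def target(z):
    return He.hermeval(z, [0, 1, 0, eps])  # He_1(z) + eps * He_3(z)

c = hermite_coefficients(target, max_degree=4)
print(np.round(c, 6))

# Empirical check: for x ~ N(0, I_D) and a unit-norm w_star, z = x @ w_star ~ N(0, 1).
rng = np.random.default_rng(0)
D, n_samples = 30, 100_000
w_star = rng.standard_normal(D)
w_star /= np.linalg.norm(w_star)
X = rng.standard_normal((n_samples, D))
z = X @ w_star
c3_empirical = np.mean(target(z) * He.hermeval(z, [0, 0, 0, 1])) / math.factorial(3)
print(round(c3_empirical, 3))
```

Here the cubic part of the target is a single degree-3 polynomial feature along one direction $w_{*}$, which illustrates why a theory that captures directional feature learning can pick it up with far fewer samples than a kernel that treats all input directions alike.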