A unified theory of feature learning in RNNs and DNNs

Jan P. Bauer, Kirsten Fischer, Moritz Helias, Agostina Palmigiano

TL;DR

This work asks why RNNs and DNNs exhibit different functional properties despite their structural similarity. It develops a unified mean-field kernel theory in the feature-learning (μP) regime that frames training as Bayesian inference over sequences and patterns, yielding a kernel-based description that treats RNNs and DNNs on equal footing. A key finding is a phase transition in endpoint-supervised tasks. Below a critical learning signal, the kernels of RNNs and DNNs coincide. Above it, RNNs develop temporal coherence across timesteps due to weight sharing, accompanied by an outlier in the weight spectrum. In sequential tasks, weight sharing provides an inductive bias that enables sample-efficient generalization by interpolating across unsupervised time steps. The framework thus connects architectural structure to functional biases, offering a principled lens for understanding and leveraging temporal feature learning in networks.
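The critical exponent $\frac{1}{2}$ reported for linear RNNs in Figure 4 fixes the generic shape of this transition near threshold. As a sketch, assuming the off-diagonal kernel order parameter vanishes below threshold (as an order parameter does in a second-order transition), its dependence on the signal-strength control variable $\lambda$ takes the form below. Here $\lambda_c$ and the prefactor $c$ are placeholder symbols, not values given in this summary.

```latex
\mathbb{H}^{T-1,T-2}(\lambda) \simeq
\begin{cases}
0, & \lambda < \lambda_c \quad \text{(RNN and DNN kernels coincide)},\\[4pt]
c\,(\lambda-\lambda_c)^{1/2}, & \lambda \gtrsim \lambda_c \quad \text{(RNN develops temporal coherence)}.
\end{cases}
```

Equivalently, just above threshold, a plot of $\log \mathbb{H}^{T-1,T-2}$ against $\log(\lambda-\lambda_c)$ would have slope $\tfrac{1}{2}$.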

Abstract

Recurrent and deep neural networks (RNNs/DNNs) are cornerstone architectures in machine learning. Remarkably, RNNs differ from DNNs only by weight sharing, as can be shown through unrolling in time. How does this structural similarity fit with the distinct functional properties these networks exhibit? To address this question, we here develop a unified mean-field theory for RNNs and DNNs in terms of representational kernels, describing fully trained networks in the feature learning ($\mu$P) regime. This theory casts training as Bayesian inference over sequences and patterns, directly revealing the functional implications induced by the RNNs' weight sharing. In DNN-typical tasks, we identify a phase transition when the learning signal overcomes the noise due to randomness in the weights: below this threshold, RNNs and DNNs behave identically; above it, only RNNs develop correlated representations across timesteps. For sequential tasks, the RNNs' weight sharing furthermore induces an inductive bias that aids generalization by interpolating unsupervised time steps. Overall, our theory offers a way to connect architectural structure to functional biases.
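The unrolling argument can be made concrete with a short NumPy sketch. It assumes a hypothetical update $h^{t}=W^{(t)}\phi(h^{t-1})+Ux^{t}$ with $h^{0}=Ux^{0}$, the erf activation and impulse input $x^{t}=\delta^{t0}$ from Figure 2, and a conventional $1/\sqrt{N}$ weight scale rather than the paper's μP parameterization. The function and variable names (`forward`, `H_rnn`, `H_dnn`) are illustrative, not taken from the authors' code. Passing one shared matrix at every step gives the RNN, passing independent matrices gives the DNN, and tying the DNN's layers reproduces the RNN exactly.

```python
import math
import numpy as np

# Elementwise activation from Figure 2: phi(x) = erf(sqrt(pi)/2 * x).
_erf = np.vectorize(math.erf)


def phi(h):
    return _erf(np.sqrt(np.pi) / 2.0 * h)


def forward(x, U, Ws):
    """Unrolled forward pass (hypothetical update rule, see lead-in).

    h^0 = U x^0 and h^t = W^(t) phi(h^(t-1)) + U x^t for t = 1..T-1.
    x: (T, D) inputs, U: (N, D) input weights,
    Ws: list of T-1 hidden weight matrices of shape (N, N).
    Returns the hidden states stacked into shape (T, N).
    """
    h = U @ x[0]
    states = [h]
    for t in range(1, x.shape[0]):
        h = Ws[t - 1] @ phi(h) + U @ x[t]
        states.append(h)
    return np.stack(states)


rng = np.random.default_rng(0)
N, D, T = 64, 1, 10
U = rng.normal(size=(N, D))
x = np.zeros((T, D))
x[0, 0] = 1.0  # impulse input x^t = delta^{t0}, as in Figure 2

# RNN: one matrix W reused at every step (weight sharing).
W = rng.normal(size=(N, N)) / np.sqrt(N)
H_rnn = forward(x, U, [W] * (T - 1))

# DNN: an independent matrix W^(t) for every layer (no sharing).
H_dnn = forward(x, U, [rng.normal(size=(N, N)) / np.sqrt(N) for _ in range(T - 1)])

# Tying all DNN layers to the same matrix recovers the RNN exactly.
H_tied = forward(x, U, [W.copy() for _ in range(T - 1)])
print(np.allclose(H_rnn, H_tied))  # True
```

The only difference between the RNN and DNN calls is whether the list of hidden weight matrices repeats a single object; this is the sense in which an RNN is a weight-shared DNN.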

Paper Structure

This paper contains 45 sections, 88 equations, and 6 figures.

Figures (6)

  • Figure 1: Graphical abstract: RNNs resemble DNNs after unrolling in time but differ functionally depending on signal strength and task. a) Recurrent neural network (RNN, left) and its unrolled-in-time representation (right): the same recurrent weights $\bm{W}$ are shared across timesteps. b) Endpoint-supervised tasks: the supervision target at the last layer (red) shapes the hidden-layer representations (gray), but only in the RNN, and only for sufficient signal strength, does it induce a phase transition towards temporal coherence. c) Sequential tasks: the RNN's weight sharing induces a temporally coherent inductive bias that facilitates generalization from supervised (red) to unsupervised timepoints (gray), whereas the DNN exhibits “regression to the mean” at unsupervised points due to an effectively white prior.
  • Figure 2: Kernels in trained RNNs converge to theory predictions for large network width $N$. We train an RNN as described in eq. (SGD) with nonlinear activation function $\phi(\circ)=\text{erf}(\tfrac{\sqrt{\pi}}{2}\circ)$ to produce a sinusoidal target sequence $y^{t}=\cos(\frac{2\pi}{T}t)$, $T=10$, in response to a scalar input at time $t=0$, i.e. $x^{t}=\delta^{t0}$. a) Centered kernel alignment (CKA) between the kernel $\Phi_{\text{exp.}}(\Theta)=\tfrac{1}{N}\sum_{i}^{N}\phi(h_{i})\phi(h_{i})^{\intercal}$ from explicit weight SGLD experiments at different network widths $N$ and the kernel $\Phi_{\text{theory}}=\langle\phi(h)\phi(h)^{\intercal}\rangle_{P(h|y,\boldsymbol{x})}$ predicted by the theory, eq. (C_MFT), for $N\rightarrow\infty$. b) Temporal structure of the kernels $\Phi_{\text{exp.}}$ for different finite network widths compared to $\Phi_{\text{theory}}$ at infinite width $N\rightarrow\infty$. c) Autocorrelation function of the network measured along the anti-diagonal $\Phi^{t-t'}$ of the kernel (dashed: theory; solid: numerics), marked in green in panel b.
  • Figure 3: RNNs and DNNs learn similar spatial representations, while only RNNs learn temporal coherence for a strong learning signal. a) Binary classification task: $P=4$ pairwise orthogonal inputs $\boldsymbol{x}_{p}\in\mathbb{R}^{D}$ with $D=4$ map to labels $y_{p}\in\{-1,1\}$, as summarized by the kernels $\mathbb{Y}_{00}^{tt'}\in\mathbb{R}^{T_{-}\times T_{-}}$ and $\mathbb{Y}_{pp'}^{TT}\in\mathbb{R}^{P\times P}$. b-e) Kernel and weight structure of DNNs (panels b, c) and RNNs (panels d, e) trained on the task; lower-left insets show the predictions of the kernel theory. We consider a weak (upper row) and a strong (lower row) learning signal. From left to right in each cell (panels b-e): the temporal kernel $\mathbb{H}_{00}^{tt'}$ for fixed pattern $p=0$; the sample kernel $\mathbb{H}_{pp'}^{T_{-}T_{-}}\coloneqq\tfrac{1}{N}\boldsymbol{h}_{p}^{T_{-}}\cdot\boldsymbol{h}_{p'}^{T_{-}}$ at the last timestep $T_{-}\coloneqq T-1=7$; and the eigenspectrum in the complex plane of the hidden weights $\bm{W}$ (RNN) or $\{\bm{W}^{(t)}\}_{t}$ (DNN; eigenspectra of all $\{\bm{W}^{(t)}\}_{t}$ plotted on the same axes), with $N=2048$. Other parameters: $w=v=u=1$, $\kappa=0.1$.
  • Figure 4: Second-order phase transition in linear RNNs with $T=4$ and critical exponent $\frac{1}{2}$. a) Off-diagonal kernel order parameter $\mathbb{H}^{T-1,T-2}$ as a function of the control variable $\lambda$. Solid curves: kernel theory for different network depths. Diamonds: empirical kernel from weight SGLD. b) Negative log-probability of the off-diagonal order parameter $\mathbb{H}^{-}$, obtained by maximizing $P(\mathbb{H}|\mathbb{H}^{T-1,T-2},y,\boldsymbol{x})$, as a function of the signal-strength control variable $\lambda$. Network width $N=2048$ (error bars: residual fluctuation at the equilibrium of the update eq. (SGD)).
  • Figure 5: RNNs have better sample efficiency in sequential tasks due to task-model alignment induced by weight sharing. a) Sequence regression task, with supervision signal at a variable number of timesteps $t$. Top left: label kernel $\mathbb{Y}$. Top right: spectrum of the teacher weights $\bm{W}^{\star}$. Bottom: sinusoidal output of the teacher. b) Generalization error $\mathcal{L}=\frac{1}{2T}\sum_{t=2}^{T}(y^{t}-f^{t})^{2}$ over all timesteps as a function of the number of supervised timesteps. c-d) Representations $\mathbb{H}$ and outputs $f^{t}$ underlying the differences in generalization, for the DNN (c) and the RNN (d) with different numbers of supervised timesteps (rows). Each cell displays the time-by-time kernel for pattern $p$ (top left), the weight eigenspectra ($\{\bm{W}^{(t)}\}_{t}$ for the DNN, $\bm{W}$ for the RNN; top right), and, at the bottom, the target function $y^{t}$ (black line) versus the network prediction $f^{t}$ (orange), with supervised timepoints shown as black markers.
  • ...and 1 more figure
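The measurements named in the captions of Figures 2, 3 and 5 can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the authors' code: the function names are hypothetical; hidden states are stored as a `(T, N)` array whose column `i` is neuron `i`'s trajectory $h_i$; CKA is implemented in the standard centered-alignment form $\langle CKC, CLC\rangle_F / (\lVert CKC\rVert_F\,\lVert CLC\rVert_F)$ with $C = I - \mathbf{1}\mathbf{1}^{\intercal}/n$, since the captions do not state which CKA variant was used; and the generalization error assumes the first array entry is the first timestep, which the sum $\sum_{t=2}^{T}$ in Figure 5 skips.

```python
import math
import numpy as np

# Elementwise activation from Figure 2: phi(x) = erf(sqrt(pi)/2 * x).
_erf = np.vectorize(math.erf)


def phi(h):
    return _erf(np.sqrt(np.pi) / 2.0 * h)


def empirical_kernel(H):
    """Phi_exp = (1/N) sum_i phi(h_i) phi(h_i)^T for H of shape (T, N); returns T x T."""
    F = phi(H)
    return F @ F.T / H.shape[1]


def sample_kernel(h_last):
    """H_{pp'} = (1/N) h_p . h_p' for last-step preactivations of shape (P, N)."""
    return h_last @ h_last.T / h_last.shape[1]


def cka(K, L):
    """Centered kernel alignment between two symmetric n x n kernels."""
    n = K.shape[0]
    C = np.eye(n) - np.ones((n, n)) / n  # centering matrix
    Kc, Lc = C @ K @ C, C @ L @ C
    return np.sum(Kc * Lc) / (np.linalg.norm(Kc) * np.linalg.norm(Lc))


def generalization_error(y, f):
    """L = 1/(2T) * sum_{t=2}^{T} (y^t - f^t)^2, skipping the first array entry."""
    T = len(y)
    diff = np.asarray(y[1:], dtype=float) - np.asarray(f[1:], dtype=float)
    return np.sum(diff ** 2) / (2 * T)


# Toy usage on random states (stand-ins, not trained networks).
rng = np.random.default_rng(1)
T, N = 10, 512
H = rng.normal(size=(T, N))
K = empirical_kernel(H)  # 10 x 10 temporal kernel
print(round(cka(K, K), 6))  # identical kernels -> 1.0
y = np.cos(2 * np.pi * np.arange(T) / T)  # sinusoidal target as in Figure 2
print(round(generalization_error(y, np.zeros(T)), 6))  # constant-zero predictor -> 0.2
```

In Figure 2a the second argument to `cka` would be the theory kernel $\Phi_{\text{theory}}$, which this sketch does not compute. The constant-zero predictor mimics the "regression to the mean" behavior that Figure 1c attributes to the DNN at unsupervised timepoints.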