Learning Hierarchical Polynomials of Multiple Nonlinear Features with Three-Layer Networks

Hengyu Fu, Zihao Wang, Eshaan Nichani, Jason D. Lee

TL;DR

This work studies the learning of hierarchical polynomials of multiple nonlinear features using three-layer neural networks, and shows that a three-layer neural network trained via layerwise gradient descent suffices for efficient feature learning.

Abstract

In deep learning theory, a critical question is to understand how neural networks learn hierarchical features. In this work, we study the learning of hierarchical polynomials of multiple nonlinear features using three-layer neural networks. We examine a broad class of functions of the form $f^{\star}=g^{\star}\circ \mathbf{p}$, where $\mathbf{p}:\mathbb{R}^{d} \rightarrow \mathbb{R}^{r}$ represents multiple quadratic features with $r \ll d$ and $g^{\star}:\mathbb{R}^{r}\rightarrow \mathbb{R}$ is a polynomial of degree $p$. This can be viewed as a nonlinear generalization of the multi-index model \citep{damian2022neural}, and also an extension of previous work that focused only on a single nonlinear feature, i.e. $r = 1$ \citep{nichani2023provable,wang2023learning}. Our primary contribution shows that a three-layer neural network trained via layerwise gradient descent suffices for (i) complete recovery of the space spanned by the nonlinear features and (ii) efficient learning of the target function $f^{\star}=g^{\star}\circ \mathbf{p}$, or transfer learning of $f=g\circ \mathbf{p}$ with a different link function, within $\widetilde{\mathcal{O}}(d^4)$ samples and polynomial time. For such hierarchical targets, our result substantially improves on the sample complexity $\Theta(d^{2p})$ of kernel methods, demonstrating the power of efficient feature learning. Our results leverage novel techniques and thus go beyond all prior settings, such as single-index and multi-index models as well as models depending on just one nonlinear feature, contributing to a more comprehensive understanding of feature learning in deep learning.
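To make the function class in the abstract concrete, the following is a minimal NumPy sketch of a target $f^{\star}=g^{\star}\circ \mathbf{p}$ with $r$ quadratic features $p_i(\mathbf{x})=\mathbf{x}^{\top}\mathbf{A}_i\mathbf{x}$. The random symmetric traceless matrices, the Gaussian inputs, the particular cubic link function, and all function names are illustrative assumptions, not the paper's construction. The last lines compare the orders $d^4$ and $d^{2p}$ quoted in the abstract, ignoring constants and logarithmic factors.

```python
import numpy as np

def make_quadratic_features(d, r, rng):
    """Return r symmetric, traceless d x d matrices (an illustrative choice)."""
    mats = []
    for _ in range(r):
        g = rng.standard_normal((d, d))
        a = (g + g.T) / 2.0                      # symmetrize
        a -= (np.trace(a) / d) * np.eye(d)       # zero trace, so E[x^T A x] = 0 for x ~ N(0, I)
        a /= np.linalg.norm(a)                   # unit Frobenius norm
        mats.append(a)
    return np.stack(mats)                        # shape (r, d, d)

def features(x, mats):
    """p(x) in R^r with p_i(x) = x^T A_i x, for a batch x of shape (n, d)."""
    return np.einsum("nj,ijk,nk->ni", x, mats, x)

def link(z):
    """Hypothetical degree-3 link function g on R^3 (assumption for illustration)."""
    return z[:, 0] * z[:, 1] * z[:, 2] + np.sum(z ** 2, axis=1)

def target(x, mats):
    """Hierarchical target f(x) = g(p(x))."""
    return link(features(x, mats))

rng = np.random.default_rng(0)
d, r, p = 16, 3, 3                               # p is the degree of the link function
mats = make_quadratic_features(d, r, rng)
x = rng.standard_normal((1000, d))               # Gaussian inputs (assumption)
y = target(x, mats)

# Orders of magnitude only: constants and log factors are ignored.
feature_learning_samples = d ** 4
kernel_samples = d ** (2 * p)
print(feature_learning_samples, kernel_samples)
```

Swapping `link` for another degree-$p$ polynomial while keeping `mats` fixed mirrors the transfer-learning setting $f=g\circ \mathbf{p}$ described above.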

Paper Structure

This paper contains 74 sections, 46 theorems, 286 equations, 4 figures, 1 algorithm.

Key Result

Theorem 1

Suppose $n_1, m_2 =\widetilde{\Omega}(d^4)$. Let $\hat{\theta}$ be the output of Algorithm 1 after $T={\rm poly}(n_1,n_2,m_1,m_2,d)$ steps. Then, there exists a set of hyper-parameters $(\epsilon, \eta_1, \eta_2, \lambda_1, \lambda_2)$ such that, with high probability over the init Moreover, for any other degree-$p$ polynomial $g:\mathbb{R}^{r}\rightarrow \mathbb{R}$ with $\left\

Figures (4)

  • Figure 1: The proof idea of Proposition 1. Block 1 characterizes the constant and linear terms of $g^{\star}$, which are approximately equivalent to the low-order terms $\mathcal{P}_{< 4}(f^{\star})$ by our universality theory and result in biases in the learned weights $\mathbf{h}^{(1)}(\mathbf{x}')$ after Stage 1. This bias vanishes as $d \rightarrow \infty$ by our assumptions on $\mathcal{P}_0(f^{\star})$ and $\mathcal{P}_2(f^{\star})$. Block 2 describes the second-order information of $g^{\star}$ (approximately $\mathcal{P}_4(f^{\star})$), which is the most important: it is captured by the quadratic component $c_2Q_2(\cdot)$ in the inner activation $\sigma_2(\cdot)$ and converted into quantities spanned by the $r$ quadratic features $\mathbf{p}$. Block 3 represents the remaining terms of $f^{\star}$, which lead to high-order nuisance terms in the learned weights; these are still dominated by the second term due to Assumption \ref{assump::target} when $d$ is large, which enables us to use the terms in blue (resulting from Block 2) to reconstruct the features efficiently.
  • Figure 2: For the left panel, Algorithm 1 uses two equally sized datasets, while the random feature model uses the full dataset. For the right panel, we conduct transfer learning with $n_1=2^{16}$ pretraining samples and plot the dependence on $n_2$. The figure reports the mean and normalized standard error of the test error using 10,000 fresh samples, based on $5$ independent experimental instances.
  • Figure 3: Test error of Algorithm 1 and the naive random feature models, with the x-axis showing the relative sample complexity $\log_d n$. We plot the test error of $5$ independent instances for each $d\in\{8,16,32\}$.
  • Figure 4: The linear correlation between the three true features and their corresponding reconstructed features for varying first-stage sample sizes $n_1$. The reconstructed features are standardized to match the variance of the true features. For $i=1,2,3$, the $i$-th scatter plot shows 10,000 test sample points of $([\mathbf{B}^{\star}\mathbf{h}^{(1)}(\mathbf{x})]_i,\mathbf{x}^{\top}\mathbf{A}_i\mathbf{x})$ for $n_1 \in \{d^2, d^3, d^4\}$, where $d=16$.
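
The kind of comparison reported in Figure 4 can be sketched as follows: rescale each reconstructed feature to the variance of its true counterpart, then measure their linear correlation. This is only an illustration. The arrays below are synthetic placeholders (a noisy rescaling of random "true" features), not outputs of the paper's algorithm, and the function names are hypothetical.

```python
import numpy as np

def match_variance(recon, true):
    """Center each reconstructed column and rescale it to the mean/std of the true column."""
    recon_centered = recon - recon.mean(axis=0)
    scale = true.std(axis=0) / recon_centered.std(axis=0)
    return recon_centered * scale + true.mean(axis=0)

def columnwise_correlation(a, b):
    """Pearson correlation between matching columns of a and b."""
    a_c = a - a.mean(axis=0)
    b_c = b - b.mean(axis=0)
    num = (a_c * b_c).sum(axis=0)
    den = np.sqrt((a_c ** 2).sum(axis=0) * (b_c ** 2).sum(axis=0))
    return num / den

rng = np.random.default_rng(0)
n_test, r = 10_000, 3
true = rng.standard_normal((n_test, r))          # placeholder for x^T A_i x
noise = rng.standard_normal((n_test, r))
recon = 0.5 * true + 0.1 * noise                 # synthetic stand-in for reconstructed features

recon_std = match_variance(recon, true)
corr = columnwise_correlation(recon_std, true)
print(corr)
```

Because Pearson correlation is invariant to positive rescaling and shifts, the standardization step changes how the scatter plot looks but not the correlation value itself.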

Theorems & Definitions (81)

  • Remark 1
  • Remark 2
  • Theorem 1
  • Proposition 1: Reconstruct the feature
  • Lemma 1: Universality of vector-valued functions
  • Proposition 2: Expressivity of the second-stage model
  • Definition 1: high probability events
  • Example 1
  • Lemma 2: Corollary 9.12 in vanprobability
  • Lemma 3: Lemma 9.21 in vanprobability
  • ...and 71 more