Learning sum of diverse features: computational hardness and efficient gradient-based training for ridge combinations

Kazusato Oko, Yujin Song, Taiji Suzuki, Denny Wu

TL;DR

This work analyzes learning a high-dimensional additive model $f_*(x)=\frac{1}{\sqrt{M}}\sum_{m=1}^M f_m(v_m^\top x)$ with $M$ growing with dimension and diverse near-orthogonal directions. It shows that gradient-based training of a two-layer neural network can efficiently learn a large subset of polynomial target functions with sample complexity $n=\tilde{O}(Md^{p-1})$, leveraging localization of hidden units to task directions and convex training of the second layer (ridge or LASSO). The paper also establishes statistical-query lower bounds (CSQ and full SQ) that imply computational hardness that scales with $M$ and $d$, highlighting a gap between achievable gradient-based learning and SQ-based limits in the large-$M$ regime. The results illuminate how additive structure and task diversification enable efficient representation learning and fine-tuning, while clarifying fundamental limits of SQ methods in large-scale pretraining-like settings. Overall, the work advances understanding of scalable learning when the number of diverse tasks grows with dimension and has implications for transfer and fine-tuning in multi-skill neural networks.

Abstract

We study the computational and sample complexity of learning a target function $f_*:\mathbb{R}^d\to\mathbb{R}$ with additive structure, that is, $f_*(x) = \frac{1}{\sqrt{M}}\sum_{m=1}^M f_m(\langle x, v_m\rangle)$, where $f_1,f_2,\dots,f_M:\mathbb{R}\to\mathbb{R}$ are nonlinear link functions of single-index models (ridge functions) with diverse and near-orthogonal index features $\{v_m\}_{m=1}^M$, and the number of additive tasks $M$ grows with the dimensionality as $M\asymp d^\gamma$ for $\gamma\ge 0$. This problem setting is motivated by the classical additive model literature, the recent representation learning theory of two-layer neural networks, and large-scale pretraining where the model simultaneously acquires a large number of "skills" that are often localized in distinct parts of the trained network. We prove that a large subset of polynomial $f_*$ can be efficiently learned by gradient descent training of a two-layer neural network, with polynomial statistical and computational complexity that depends on the number of tasks $M$ and the information exponent of $f_m$, even though the link functions are unknown and $M$ grows with the dimensionality. We complement this learnability guarantee with a computational hardness result by establishing statistical query (SQ) lower bounds for both correlational SQ and full SQ algorithms.
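
To make the additive target concrete, the following is a minimal sketch of how data from $f_*(x) = \frac{1}{\sqrt{M}}\sum_{m=1}^M f_m(\langle x, v_m\rangle)$ could be simulated. Several choices are assumptions made only for illustration and are not taken from the abstract: inputs are standard Gaussian, the index features are exactly orthonormal (a special case of "near-orthogonal" that requires $M \le d$), every task shares the same link $f_m = \mathrm{He}_k$ (a probabilists' Hermite polynomial), and the function name `sample_additive_model` is hypothetical.

```python
import numpy as np
from numpy.polynomial import hermite_e as H


def sample_additive_model(n, d, M, degree=3, seed=0):
    """Draw (X, y, V) from a toy additive ridge model (illustrative assumptions only)."""
    rng = np.random.default_rng(seed)
    # Assumption: exactly orthonormal index features, obtained from a reduced QR
    # factorization of a random d x M Gaussian matrix (needs M <= d).
    Q, _ = np.linalg.qr(rng.standard_normal((d, M)))
    V = Q.T                                  # rows are v_1, ..., v_M
    # Assumption: isotropic Gaussian inputs.
    X = rng.standard_normal((n, d))
    Z = X @ V.T                              # Z[i, m] = <x_i, v_m>
    # Assumption: every link f_m is the Hermite polynomial He_degree.
    basis = np.zeros(degree + 1)
    basis[degree] = 1.0
    links = H.hermeval(Z, basis)             # f_m(<x_i, v_m>), shape (n, M)
    y = links.sum(axis=1) / np.sqrt(M)       # 1/sqrt(M) normalization of f_*
    return X, y, V
```

Calling `sample_additive_model(n=1000, d=50, M=10)` returns inputs of shape `(1000, 50)`, noiseless labels of shape `(1000,)`, and the `(10, 50)` matrix of ground-truth directions, which can be compared against learned first-layer weights.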

Paper Structure (55 sections, 41 theorems, 224 equations, 3 figures, 1 table, 1 algorithm)

Key Result

Theorem 1

Under the assumptions on additive structure, feature diversity, and activation (source labels assump:additive, assumption:diversity, and assumption:activation), take the number of neurons $J=\tilde{\Theta}(M^{C_p+\frac{1}{2}}\varepsilon^{-1})$, the number of steps for first-layer training $T_{1}=\tilde{\Theta}(Md^{p-1}\lor Md\varepsilon^{-2}\lor M^{\frac{5}{2}}\varepsilon^{-3})$
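
To get a feel for which term drives the first-layer step count $T_1$, the rates quoted above can be evaluated numerically. This is only a rough sketch: it sets every hidden constant and logarithmic factor in $\tilde{\Theta}(\cdot)$ to 1, and it treats $C_p$ as a caller-supplied number because its value is not given in this excerpt. The function name `theorem1_scalings` is hypothetical.

```python
def theorem1_scalings(M, d, p, eps, C_p):
    """Evaluate the Theorem 1 rates with constants and log factors set to 1."""
    # The three terms joined by "max" (written as \lor) in the expression for T_1.
    terms = {
        "M d^(p-1)": M * d ** (p - 1),
        "M d eps^-2": M * d / eps ** 2,
        "M^(5/2) eps^-3": M ** 2.5 / eps ** 3,
    }
    dominant = max(terms, key=terms.get)
    return {
        "J": M ** (C_p + 0.5) / eps,      # width rate M^(C_p + 1/2) / eps
        "T1": terms[dominant],            # largest of the three terms
        "dominant_T1_term": dominant,     # which term attains the maximum
    }
```

For example, `theorem1_scalings(M=4, d=10, p=3, eps=1.0, C_p=1.0)` reports the $Md^{p-1}$ term as dominant, while shrinking `eps` eventually hands dominance to the $M^{5/2}\varepsilon^{-3}$ term.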

Figures (3)

  • Figure 1: Alignment between student neurons and true signals $v_1$ and $v_2$, before (blue) and after (purple) SGD training. Left: neural network optimized by online SGD (Algorithm 1). Right: neural network in the NTK regime.
  • Figure 2: Illustration of the proof for $I=3$. For each $i=1,2,3$, a $2$-dimensional curved surface $\pi_{i}$ on which $A_i(a)=0$ divides the hypercube. First, we take $\lambda_1=\pi_1$. Then, we take the intersection $\lambda_2=\lambda_{1}\cap\pi_{2}$, which is a curved line connecting a boundary point on $S_3^+$ to one on $S_3^-$. Finally, we consider the intersection of $\lambda_{2}$ and $\pi_{3}$. Because $\lambda_2$ connects points in $S_3^+$ and $S_3^-$ while $\pi_3$ divides the hypercube into a part containing $S_3^+$ and a part containing $S_3^-$, the set $\lambda_3=\lambda_2\cap \pi_3$ is nonempty, and $A_1(a)=A_2(a)=A_3(a)=0$ holds on $\lambda_3$.
  • Figure 3: Approximation via piecewise constant function. Figure 3(a): shifting the right end of each indicator function proportionally to $f(x_i)$ is approximately equivalent to subtracting $O(f(x))$ from $(h^*_a(x))^k$ in the sense of integral value. Figure 3(b): By considering the staircase function, we can simultaneously modify the contribution of the different exponents of $h^*_a(x)$ to the integral.

Theorems & Definitions (53)

  • Definition 1: Information exponent
  • Remark
  • Remark
  • Theorem 1
  • Lemma 2
  • Lemma 3
  • Lemma 4: Informal
  • Remark
  • Theorem 5: CSQ lower bound
  • Remark
  • ...and 43 more
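
As an illustration of the quantity named in Definition 1, the sketch below estimates an information exponent numerically. It relies on the standard definition (the smallest degree $k\ge 1$ whose Hermite coefficient of the link function under a standard Gaussian input is nonzero), which is assumed here because the definition's text is not reproduced in this listing. The function names, the truncation degree, and the tolerance are all hypothetical choices.

```python
import math

import numpy as np
from numpy.polynomial import hermite_e as H


def hermite_coefficients(f, max_degree=8, quad_points=40):
    """Normalized Hermite coefficients E[f(z) He_k(z)] / sqrt(k!) for z ~ N(0, 1)."""
    # Gauss-Hermite nodes/weights for the weight exp(-z^2 / 2); dividing by
    # sqrt(2 * pi) turns the quadrature into an expectation under N(0, 1).
    nodes, weights = H.hermegauss(quad_points)
    weights = weights / np.sqrt(2.0 * np.pi)
    values = f(nodes)
    coeffs = []
    for k in range(max_degree + 1):
        basis = np.zeros(k + 1)
        basis[k] = 1.0                       # selects He_k
        he_k = H.hermeval(nodes, basis)
        coeffs.append(np.sum(weights * values * he_k) / math.sqrt(math.factorial(k)))
    return np.array(coeffs)


def information_exponent(f, max_degree=8, tol=1e-8):
    """Smallest k >= 1 with a nonzero Hermite coefficient, or None if none is found."""
    coeffs = hermite_coefficients(f, max_degree)
    for k in range(1, max_degree + 1):
        if abs(coeffs[k]) > tol:
            return k
    return None
```

For instance, the link $z\mapsto z^3-3z$ (which equals $\mathrm{He}_3$) has no degree-1 or degree-2 component, whereas $z\mapsto z^3$ does contain a linear component, so the two differ in information exponent even though both are cubic.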