Learning sum of diverse features: computational hardness and efficient gradient-based training for ridge combinations
Kazusato Oko, Yujin Song, Taiji Suzuki, Denny Wu
TL;DR
This work analyzes learning a high-dimensional additive model $f_*(x)=\frac{1}{\sqrt{M}}\sum_{m=1}^M f_m(v_m^\top x)$, where the number of terms $M$ grows with the dimension $d$ and the index directions $v_m$ are diverse and near-orthogonal. It shows that gradient-based training of a two-layer neural network can efficiently learn a large subset of polynomial target functions with sample complexity $n=\tilde{O}(Md^{p-1})$, where $p$ is the information exponent of the link functions $f_m$. The method first localizes hidden units to the task directions and then trains the second layer by convex optimization (ridge or LASSO). The paper also establishes correlational SQ (CSQ) and full SQ lower bounds, which imply a computational hardness that scales with both $M$ and $d$ and highlight a gap between achievable gradient-based learning and SQ-based limits in the large-$M$ regime. The results illuminate how additive structure and task diversity enable efficient representation learning and fine-tuning, while clarifying the fundamental limits of SQ methods in large-scale, pretraining-like settings. Overall, the work advances the understanding of scalable learning when the number of diverse tasks grows with the dimension, with implications for transfer and fine-tuning in multi-skill neural networks.
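The second-stage "convex training of the second layer" can be illustrated with a small NumPy sketch. It is not the paper's algorithm: it assumes the localization step has already succeeded in an idealized way, so that every hidden unit points exactly along one task direction $v_m$ with a random sign and a random bias. Under that assumption, fitting the output weights is an ordinary ridge regression on fixed ReLU features. The helper names (`fit_second_layer_ridge`, `relu_features`, `he3_additive_target`), the constant output-bias column, the choice $f_m(z)=z^3-3z$ for every task, and all sizes are hypothetical choices made for illustration.

```python
import numpy as np

def fit_second_layer_ridge(Phi, y, lam):
    """Closed-form ridge fit of the output weights a:
    minimize (1/n) * ||Phi @ a - y||^2 + lam * ||a||^2."""
    n, N = Phi.shape
    A = Phi.T @ Phi / n + lam * np.eye(N)
    return np.linalg.solve(A, Phi.T @ y / n)

def relu_features(X, W, b):
    """Hidden activations ReLU(X W^T + b) plus a constant column,
    which acts as a (hypothetical) output bias."""
    H = np.maximum(X @ W.T + b, 0.0)
    return np.hstack([H, np.ones((X.shape[0], 1))])

def he3_additive_target(X, V):
    """f_*(x) = M^{-1/2} * sum_m He_3(<x, v_m>), with He_3(z) = z^3 - 3z."""
    Z = X @ V.T
    return (Z**3 - 3.0 * Z).sum(axis=1) / np.sqrt(V.shape[0])

rng = np.random.default_rng(0)
d, M, units_per_task, n = 200, 10, 40, 4000

# Random unit-norm index features v_m (rows of V).
V = rng.standard_normal((M, d))
V /= np.linalg.norm(V, axis=1, keepdims=True)

# Idealized "localized" first layer: each unit is +/- v_m with a random bias.
# This is an assumption standing in for the gradient-based first-layer training.
W = np.repeat(V, units_per_task, axis=0) * rng.choice([-1.0, 1.0], size=(M * units_per_task, 1))
b = rng.uniform(-3.0, 3.0, size=M * units_per_task)

X_train, X_test = rng.standard_normal((n, d)), rng.standard_normal((n, d))
y_train, y_test = he3_additive_target(X_train, V), he3_additive_target(X_test, V)

# Convex second-layer training: ridge regression on the fixed features.
a = fit_second_layer_ridge(relu_features(X_train, W, b), y_train, lam=1e-4)
pred = relu_features(X_test, W, b) @ a
rel_err = np.mean((pred - y_test) ** 2) / np.var(y_test)
print(f"relative test MSE: {rel_err:.4f}")
```

Because each block of features depends on $x$ only through one projection $\langle x, v_m\rangle$, the ridge problem only needs to fit $M$ one-dimensional functions. This is the intuition behind why localized hidden units make the second stage easy. A LASSO penalty could be substituted for the ridge penalty, but it has no closed form and would need an iterative solver.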
Abstract
We study the computational and sample complexity of learning a target function $f_*:\mathbb{R}^d\to\mathbb{R}$ with additive structure, that is, $f_*(x) = \frac{1}{\sqrt{M}}\sum_{m=1}^M f_m(\langle x, v_m\rangle)$, where $f_1,f_2,\dots,f_M:\mathbb{R}\to\mathbb{R}$ are nonlinear link functions of single-index models (ridge functions) with diverse and near-orthogonal index features $\{v_m\}_{m=1}^M$, and the number of additive tasks $M$ grows with the dimensionality as $M\asymp d^{\gamma}$ for $\gamma\ge 0$. This problem setting is motivated by the classical additive model literature, the recent representation learning theory of two-layer neural networks, and large-scale pretraining, where the model simultaneously acquires a large number of "skills" that are often localized in distinct parts of the trained network. We prove that a large subset of polynomial $f_*$ can be efficiently learned by gradient descent training of a two-layer neural network, with polynomial statistical and computational complexity that depends on the number of tasks $M$ and the information exponent of $f_m$, despite the link functions being unknown and $M$ growing with the dimensionality. We complement this learnability guarantee with computational hardness results by establishing statistical query (SQ) lower bounds for both correlational SQ and full SQ algorithms.
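To make the setting concrete, the following NumPy sketch samples one hypothetical instance of the target class. It assumes $M=\mathrm{round}(d^{\gamma})$ index features drawn uniformly on the unit sphere; such random directions are near-orthogonal, with pairwise inner products of order $d^{-1/2}$. It also assumes every link function is the same probabilists' Hermite polynomial $\mathrm{He}_p$, which has information exponent exactly $p$. The paper allows diverse links $f_m$. The names `hermite_he` and `make_additive_target`, and the specific values of $d$, $\gamma$ and $p$, are illustrative assumptions only.

```python
import numpy as np

def hermite_he(p, z):
    """Probabilists' Hermite polynomial He_p, evaluated elementwise via the
    recurrence He_{k+1}(z) = z * He_k(z) - k * He_{k-1}(z)."""
    z = np.asarray(z, dtype=float)
    if p == 0:
        return np.ones_like(z)
    h_prev, h = np.ones_like(z), z
    for k in range(1, p):
        h_prev, h = h, z * h - k * h_prev
    return h

def make_additive_target(d, gamma, p, seed=0):
    """Hypothetical instance of f_*(x) = M^{-1/2} * sum_m f_m(<x, v_m>) with
    M = round(d^gamma) random unit directions and f_m = He_p for every m."""
    rng = np.random.default_rng(seed)
    M = max(1, int(round(d ** gamma)))
    V = rng.standard_normal((M, d))
    V /= np.linalg.norm(V, axis=1, keepdims=True)   # rows are the index features v_m

    def f_star(X):
        Z = X @ V.T                                  # (n, M) projections <x, v_m>
        return hermite_he(p, Z).sum(axis=1) / np.sqrt(M)

    return f_star, V

d = 400
f_star, V = make_additive_target(d=d, gamma=0.5, p=3)   # M = 20 tasks
X = np.random.default_rng(1).standard_normal((2000, d))  # Gaussian inputs
y = f_star(X)
overlap = np.abs(V @ V.T - np.eye(V.shape[0])).max()      # worst pairwise |<v_i, v_j>|
print(V.shape, round(overlap, 3), round(y.var(), 2))
```

The $1/\sqrt{M}$ normalization keeps the output variance of order one as $M$ grows. Since $\mathbb{E}[\mathrm{He}_p(z)^2]=p!$ for standard Gaussian $z$, nearly orthogonal directions give $\mathrm{Var}(f_*(x))\approx p!$ (about 6 for $p=3$).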
