On Learning Gaussian Multi-index Models with Gradient Flow
Alberto Bietti, Joan Bruna, Loucas Pillaud-Vivien
TL;DR
The paper addresses learning multi-index models F(x) = P_{W*}f*(x) = f*(W*^⊤x) in high dimensions with Gaussian data. It proposes a two-timescale gradient-flow framework: fast nonparametric learning of the low-dimensional link function f, and slow optimization of the subspace W on the Grassmannian. It establishes global convergence of the population gradient flow, characterizes saddle-to-saddle dynamics via a harmonic (Hermite) decomposition of the target, and derives explicit timescales governed by information exponents. It also analyzes a planted variant, in which the link function is known and fixed, where learning can fail due to symmetric target structures. The work introduces the averaging operator A_M as a dimension-free summary that couples f and W through the PSD matrix M = W*^⊤W. It also develops a practical implementation of the fast-learning step based on an isotropic Hermite RKHS with random features. Together, these results explain how incremental subspace learning emerges in high dimensions and offer a bridge to sample-complexity considerations. They also highlight limitations of the planted model and suggest extensions to finite samples and broader data distributions. The insights bear on progressive feature learning in neural networks and on the design of kernel-based two-timescale algorithms with convergence guarantees in structured, high-dimensional settings.
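The timescales above are tied to a Hermite decomposition of the target and to information exponents. The sketch below is only a one-dimensional illustration; the paper's harmonic decomposition of a multi-index link is more refined. It computes the normalized probabilists' Hermite coefficients c_k = E[f(Z) He_k(Z)] / sqrt(k!), with Z ~ N(0, 1), by Gauss-Hermite quadrature. It then reads off the single-index information exponent as the smallest k ≥ 1 with c_k ≠ 0. The function names, quadrature size, and tolerance are illustrative choices, not taken from the paper.

```python
import math

import numpy as np
from numpy.polynomial import hermite_e as He


def hermite_coefficients(f, max_degree, n_nodes=60):
    """Return c_k = E[f(Z) h_k(Z)] for k = 0..max_degree, Z ~ N(0, 1).

    h_k = He_k / sqrt(k!) are the orthonormal (probabilists') Hermite polynomials.
    """
    # Gauss-Hermite rule for the weight exp(-x^2 / 2); the weights sum to sqrt(2*pi).
    nodes, weights = He.hermegauss(n_nodes)
    weights = weights / math.sqrt(2.0 * math.pi)  # now integrates against N(0, 1)
    fx = f(nodes)
    coeffs = np.empty(max_degree + 1)
    for k in range(max_degree + 1):
        one_hot = np.zeros(k + 1)
        one_hot[k] = 1.0
        h_k = He.hermeval(nodes, one_hot) / math.sqrt(math.factorial(k))
        coeffs[k] = np.sum(weights * fx * h_k)
    return coeffs


def information_exponent(coeffs, tol=1e-8):
    """Smallest degree k >= 1 whose Hermite coefficient exceeds tol (None if all vanish)."""
    for k in range(1, len(coeffs)):
        if abs(coeffs[k]) > tol:
            return k
    return None


if __name__ == "__main__":
    c = hermite_coefficients(lambda z: z**3 - 3 * z, max_degree=5)
    print(information_exponent(c))  # → 3  (z^3 - 3z is exactly He_3)
    print(information_exponent(hermite_coefficients(np.tanh, 5)))  # → 1  (odd, c_1 > 0)
```

A quadrature rule is used instead of Monte Carlo so that polynomial links are integrated exactly, which makes the vanishing low-order coefficients visible up to rounding error.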
Abstract
We study gradient flow on the multi-index regression problem for high-dimensional Gaussian data. Multi-index functions consist of a composition of an unknown low-rank linear projection and an arbitrary, unknown, low-dimensional link function. As such, they constitute a natural template for feature learning in neural networks. We consider a two-timescale algorithm, whereby the low-dimensional link function is learnt with a non-parametric model infinitely faster than the subspace parametrizing the low-rank projection. By appropriately exploiting the matrix semigroup structure arising over the subspace correlation matrices, we establish global convergence of the resulting Grassmannian population gradient flow dynamics, and provide a quantitative description of its associated "saddle-to-saddle" dynamics. Notably, the timescales associated with each saddle can be explicitly characterized in terms of an appropriate Hermite decomposition of the target link function. In contrast with these positive results, we also show that the related "planted" problem, where the link function is known and fixed, in fact has a rough optimization landscape, in which gradient flow dynamics might get trapped with high probability.
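The following is a minimal finite-sample, discrete-time sketch of the two-timescale idea. It rests on several assumptions not taken from the paper, which studies the population gradient flow. The fast step refits the link exactly by ridge regression on random Fourier features of z = W^⊤x, a Gaussian-kernel stand-in for the paper's isotropic Hermite RKHS. The slow step holds the link fixed and takes one Riemannian gradient step on the Grassmannian, using a tangent-space projection followed by a QR retraction. The ridge penalty does not depend on W, so by an envelope-theorem argument this partial gradient equals the gradient of the regularized loss after minimizing over the link coefficients. Progress is tracked through the singular values of M = W*^⊤W, which are the cosines of the principal angles between the learned and target subspaces. All function names, step sizes, feature counts, and the target link f*(z) = z_1 + z_1 z_2 are hypothetical.

```python
import numpy as np


def orthonormalize(a):
    """Orthonormal basis of the column span of a (thin QR), i.e. a point on the Grassmannian."""
    q, _ = np.linalg.qr(a)
    return q


def random_features(z, omega, bias):
    """Random Fourier features for a Gaussian kernel on R^r (stand-in for the paper's RKHS)."""
    m = omega.shape[1]
    return np.sqrt(2.0 / m) * np.cos(z @ omega + bias)


def fit_link(phi, y, ridge):
    """Fast timescale: ridge regression of y on the features, solved exactly."""
    n, m = phi.shape
    return np.linalg.solve(phi.T @ phi + ridge * n * np.eye(m), phi.T @ y)


def link_gradient(z, coef, omega, bias):
    """Gradient in z of the fitted link f(z) = sum_j coef_j phi_j(z); shape (n, r)."""
    m = omega.shape[1]
    dphi = -np.sqrt(2.0 / m) * np.sin(z @ omega + bias)  # derivative of each cosine feature
    return (dphi * coef) @ omega.T


def two_timescale_flow(x, y, r, steps=100, lr=0.5, m=200, ridge=1e-3, seed=0):
    """Slow timescale: Riemannian gradient descent for the subspace W (d x r)."""
    rng = np.random.default_rng(seed)
    n, d = x.shape
    omega = rng.standard_normal((r, m))  # unit-bandwidth Gaussian kernel (assumption)
    bias = rng.uniform(0.0, 2.0 * np.pi, m)
    w = orthonormalize(rng.standard_normal((d, r)))  # random initial subspace
    for _ in range(steps):
        z = x @ w                            # projected inputs W^T x, one row per sample
        phi = random_features(z, omega, bias)
        coef = fit_link(phi, y, ridge)       # link refit to optimality at every step
        residual = phi @ coef - y
        grad_z = link_gradient(z, coef, omega, bias)
        euclid = x.T @ (residual[:, None] * grad_z) / n  # dL/dW for L = mean(residual^2) / 2
        riem = euclid - w @ (w.T @ euclid)               # project onto the Grassmann tangent space
        w = orthonormalize(w - lr * riem)                # QR retraction back to orthonormal W
    return w


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    d, r, n = 20, 2, 2000
    w_star = np.eye(d)[:, :r]                          # target subspace: first two coordinates
    x = rng.standard_normal((n, d))                    # isotropic Gaussian inputs
    z_star = x @ w_star
    y = z_star[:, 0] + z_star[:, 0] * z_star[:, 1]     # hypothetical link f*(z) = z1 + z1*z2
    w = two_timescale_flow(x, y, r)
    # Singular values of M = W*^T W are the cosines of the principal angles.
    print(np.linalg.svd(w_star.T @ w, compute_uv=False))
```

If a run recovers span(W*), both printed cosines approach 1. This sketch is untuned, however, and offers no such guarantee. It only mirrors the structure "solve for the link, then move the subspace" rather than the paper's continuous-time population analysis.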
