Learning Gaussian Multi-Index Models with Gradient Flow: Time Complexity and Directional Convergence
Berfin Şimşek, Amire Bendjeddou, Daniel Hsu
TL;DR
We analyze gradient flow for a decoupled two-layer network trained with correlation loss to learn a high-dimensional multi-index function on Gaussian data, formalizing the dynamics on a reduced subspace and linking fixed points to tensor eigenvectors in the orthogonal case. The results quantify the time scale of the dynamics, showing $T=\Theta(d^{p^*/2-1})$ for $p^*\ge 3$ and $T=\Theta(\log d)$ for $p^*=2$, and prove that a single neuron converges to the nearest index vector under orthogonality, with a complete fixed-point classification in that setting. For equiangular index sets, a saddle-to-minimum transition emerges, with a sharp threshold $\beta_c=(p^*-2)/(k+p^*-2)$ at which the average fixed point becomes a local minimum. Mild overparameterization ($n \asymp k\log k$) suffices to recover all index directions with high probability, while constant-factor overparameterization may fail; simulations illustrate both the potential and the limits of correlation loss. Overall, the work connects gradient-flow dynamics with tensor decomposition, showing how geometry and overparameterization shape learnability in high dimensions.
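To get a feel for the threshold, the formula $\beta_c=(p^*-2)/(k+p^*-2)$ can be evaluated directly. The snippet below is a minimal sketch that takes the formula at face value; the function name `beta_c` and its argument names are illustrative and not taken from the paper.

```python
from fractions import Fraction


def beta_c(p_star: int, k: int) -> Fraction:
    """Evaluate beta_c = (p* - 2) / (k + p* - 2) exactly.

    p_star and k are hypothetical argument names for the symbols p^* and k
    in the summary; the formula is copied as stated there.
    """
    if p_star < 2 or k < 1:
        raise ValueError("expected p_star >= 2 and k >= 1")
    return Fraction(p_star - 2, k + p_star - 2)


if __name__ == "__main__":
    # The threshold shrinks as the number of index vectors k grows.
    for k in (2, 5, 10, 100):
        print(k, beta_c(3, k), float(beta_c(3, k)))
```

For fixed $p^*$ the threshold decays like $1/k$, so under this formula the range of dot products $\beta$ below the threshold narrows as more index vectors are added.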
Abstract
This work focuses on the gradient flow dynamics of a neural network model that uses correlation loss to approximate a multi-index function on high-dimensional standard Gaussian data. Specifically, the multi-index function we consider is a sum of neurons $f^*(x) = \sum_{j=1}^k \sigma^*(v_j^T x)$, where $v_1, \dots, v_k$ are unit vectors and $\sigma^*$ lacks the first and second Hermite polynomials in its Hermite expansion. It is known that, for the single-index case ($k=1$), overcoming the search phase requires polynomial time complexity. We first generalize this result to multi-index functions whose index vectors point in arbitrary directions. After the search phase, it is not clear whether the network neurons converge to the index vectors or get stuck at a sub-optimal solution. When the index vectors are orthogonal, we give a complete characterization of the fixed points and prove that neurons converge to the nearest index vectors. Therefore, using $n \asymp k \log k$ neurons ensures that gradient flow finds the full set of index vectors with high probability over random initialization. When $v_i^T v_j = \beta \geq 0$ for all $i \neq j$, we prove the existence of a sharp threshold $\beta_c = c/(c+k)$ at which the fixed point that computes the average of the index vectors transitions from a saddle point to a minimum. Numerical simulations show that using a correlation loss and mild overparameterization suffices to learn all of the index vectors when they are nearly orthogonal; however, the correlation loss fails when the dot product between the index vectors exceeds a certain threshold.
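The $n \asymp k \log k$ requirement can be read as a coupon-collector effect: if each neuron ends up at the index vector nearest to its random initialization, all $k$ directions are recovered only when every index vector is nearest to at least one neuron. The following sketch illustrates this under stated assumptions that are not from the paper: the index vectors are the first $k$ standard basis vectors, "nearest" means largest dot product at initialization, and the gradient flow itself is not simulated. The function name `coverage_probability` and its parameters are hypothetical.

```python
import math

import numpy as np


def coverage_probability(k, n, d=64, trials=200, seed=0):
    """Monte Carlo estimate of P(every index vector is nearest to some neuron).

    Assumptions (illustrative only): index vectors are e_1, ..., e_k in R^d,
    neurons are initialized with i.i.d. Gaussian entries, and a neuron's
    nearest index vector is the one with the largest dot product.
    """
    if d < k:
        raise ValueError("need d >= k for k orthogonal index vectors")
    rng = np.random.default_rng(seed)
    covered = 0
    for _ in range(trials):
        # Row normalization would not change the argmax, so it is skipped.
        w = rng.standard_normal((n, d))
        # The dot product of neuron i with e_j is simply w[i, j].
        nearest = np.argmax(w[:, :k], axis=1)
        covered += int(len(np.unique(nearest)) == k)
    return covered / trials


if __name__ == "__main__":
    k = 20
    for n in (k, 2 * k, math.ceil(3 * k * math.log(k))):
        print(f"n={n:4d}  estimated coverage={coverage_probability(k, n):.3f}")
```

By symmetry each neuron's nearest index vector is uniform over the $k$ candidates, so with $n = k$ neurons some direction is almost always missed, while $n$ on the order of $k \log k$ covers all of them with high probability in this toy model.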
