Learning Gaussian Multi-Index Models with Gradient Flow: Time Complexity and Directional Convergence

Berfin Şimşek, Amire Bendjeddou, Daniel Hsu

TL;DR

We analyze gradient flow for a decoupled two-layer network trained with correlation loss to learn a high-dimensional multi-index function on Gaussian data, formalizing the dynamics on a reduced subspace and linking fixed points to tensor eigenvectors in the orthogonal case. The results quantify the time needed to escape the search phase, $T=\Theta(d^{p^*/2-1})$ for $p^*\ge3$ and $T=\Theta(\log d)$ for $p^*=2$. They also prove that under orthogonality a single neuron converges to the nearest index vector, with a complete fixed-point classification in that setting. For equiangular index sets, a saddle-to-minimum transition emerges at a sharp threshold $\beta_c=(p^*-2)/(k+p^*-2)$, above which the average fixed point becomes a local minimum. Mild overparameterization ($n \approx k\log k$) suffices to recover all index directions with high probability, while constant-factor overparameterization may fail; simulations illustrate both the potential and the limits of correlation loss. Overall, the work connects gradient-flow dynamics with tensor decomposition and clarifies how geometry and overparameterization shape learnability in high dimensions.
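
The contrast between $k\log k$ and constant-factor overparameterization has a coupon-collector flavor, which the following Python sketch illustrates. It assumes orthogonal index vectors, taken as the first $k$ coordinate axes. It also takes as given the orthogonal-case result that each neuron converges to the index vector nearest to it at initialization. Under these assumptions, all directions are recovered exactly when every index vector is the nearest one for at least one neuron. Several choices here are illustrative and are not the paper's code: measuring nearness by $|v_j^T w|$ (natural for an even activation), the function names, and the constant 2 in the widths.

```python
import numpy as np

def recovers_all(n, k, d, rng):
    """Hypothetical check: do n randomly initialized neurons cover all k index vectors?

    Assumes v_j = e_j (orthonormal) and that each neuron converges to the index
    vector with the largest |v_j^T w| at initialization.
    """
    W = rng.standard_normal((n, d))      # Gaussian init: each row has a uniformly random direction
    overlaps = np.abs(W[:, :k])          # |v_j^T w| up to a positive per-neuron scale
    nearest = overlaps.argmax(axis=1)    # index vector each neuron is assumed to reach
    return np.unique(nearest).size == k

def success_rate(n, k, d=100, trials=500, seed=0):
    """Monte Carlo estimate of P(all k index vectors are recovered) with n neurons."""
    rng = np.random.default_rng(seed)
    return float(np.mean([recovers_all(n, k, d, rng) for _ in range(trials)]))

if __name__ == "__main__":
    k = 20
    print("n = 2k:       ", success_rate(2 * k, k))
    print("n = 2k log k: ", success_rate(int(np.ceil(2 * k * np.log(k))), k))
```

By symmetry, each neuron's nearest index vector is uniform over the $k$ choices. With $n = 2k$ neurons, some directions are therefore typically left uncovered, whereas $n \approx 2k\log k$ covers all of them in most trials.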

Abstract

This work focuses on the gradient flow dynamics of a neural network model that uses correlation loss to approximate a multi-index function on high-dimensional standard Gaussian data. Specifically, the multi-index function we consider is a sum of neurons $f^*(x) \!=\! \sum_{j=1}^k \! σ^*(v_j^T x)$, where $v_1, \dots, v_k$ are unit vectors and $σ^*$ lacks the first and second Hermite polynomials in its Hermite expansion. It is known that, for the single-index case ($k\!=\!1$), overcoming the search phase requires polynomial time complexity. We first generalize this result to multi-index functions characterized by vectors in arbitrary directions. After the search phase, it is not clear whether the neurons converge to the index vectors or get stuck at a sub-optimal solution. When the index vectors are orthogonal, we give a complete characterization of the fixed points and prove that neurons converge to the nearest index vectors. Therefore, using $n \! \asymp \! k \log k$ neurons ensures that gradient flow finds the full set of index vectors with high probability over random initialization. When $v_i^T v_j \!=\! β\! \geq \! 0$ for all $i \neq j$, we prove the existence of a sharp threshold $β_c \!=\! c/(c+k)$ at which the fixed point that computes the average of the index vectors transitions from a saddle point to a minimum. Numerical simulations show that a correlation loss with mild overparameterization suffices to learn all of the index vectors when they are nearly orthogonal; however, the correlation loss fails when the dot product between the index vectors exceeds a certain threshold.
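
One way to see where the saddle-to-minimum threshold comes from is to check the curvature of the correlation objective at the averaged fixed point numerically. The Python sketch below rests on two assumptions. First, the student and teacher activations are both the normalized Hermite polynomial $h_{p^*}$, so on the unit sphere the population correlation reduces to $\phi(w)=\sum_j (v_j^T w)^{p^*}$. Second, $c = p^*-2$, which matches the form $\beta_c=(p^*-2)/(k+p^*-2)$ quoted in the TL;DR. All helper names are hypothetical. Minimizing the correlation loss means maximizing $\phi$, so a strict local minimum of the loss shows up as all tangent Hessian eigenvalues of $\phi$ being negative.

```python
import numpy as np

def beta_c(k, p):
    """Threshold (p - 2) / (k + p - 2), i.e. c / (c + k) with c = p - 2 (assumed)."""
    return (p - 2) / (k + p - 2)

def equiangular_frame(k, beta, d):
    """Unit vectors v_1..v_k in R^d (d >= k) with v_i^T v_j = beta for i != j."""
    G = (1 - beta) * np.eye(k) + beta * np.ones((k, k))   # target Gram matrix
    L = np.linalg.cholesky(G)                             # G = L @ L.T (needs -1/(k-1) < beta < 1)
    V = np.zeros((d, k))
    V[:k, :] = L.T                                        # then V.T @ V = L @ L.T = G
    return V

def average_point(V):
    """Normalized average of the index vectors."""
    w = V.sum(axis=1)
    return w / np.linalg.norm(w)

def tangent_hessian_eigs(V, w, p):
    """Eigenvalues of the Riemannian Hessian of phi(w) = sum_j (v_j^T w)^p at unit w."""
    u = V.T @ w
    grad = V @ (p * u ** (p - 1))
    hess = V @ np.diag(p * (p - 1) * u ** (p - 2)) @ V.T
    P = np.eye(len(w)) - np.outer(w, w)                   # projector onto the tangent space
    H = P @ hess @ P - (w @ grad) * P                     # Riemannian Hessian on the sphere
    U, _, _ = np.linalg.svd(w[:, None])                   # U[:, 1:] is an orthonormal tangent basis
    Q = U[:, 1:]
    return np.linalg.eigvalsh(Q.T @ H @ Q)

if __name__ == "__main__":
    k, p = 2, 3
    for beta in (0.2, 0.5):                               # beta_c(2, 3) = 1/3
        V = equiangular_frame(k, beta, d=k + 1)
        top = tangent_hessian_eigs(V, average_point(V), p).max()
        print(f"beta={beta}: largest tangent eigenvalue {top:+.3f}")
```

A positive largest eigenvalue gives a descent direction of the loss, so the average is a saddle. In this simplified model, the eigenvalue along directions inside the span of the index vectors works out to $p\,a^{p-2}\,[(p-2)-\beta(k+p-2)]$ with $a^2=(1+(k-1)\beta)/k$. It therefore changes sign exactly at $\beta_c$. Directions orthogonal to the span always have negative curvature.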

Paper Structure

This paper contains 26 sections, 12 theorems, 106 equations, and 6 figures.

Key Result

Lemma 1.1

Assume that $\mathbf w(t)$ solves the single-neuron gradient flow ODE with initial condition $\mathbf w(0) \!=\! \mathbf w_0$. Then the vector of dot products $\mathbf u(t) = V^T \mathbf w(t)$ solves the following ODE with initial condition $\mathbf u(0) = \mathbf u_0 = V^T \mathbf w_0$.
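
For intuition about why such a reduction exists, here is a derivation under a simplifying assumption that may differ from the paper's general setting. It takes both the student activation and $\sigma^*$ to be the normalized Hermite polynomial $h_p$ with $p=p^*$, and keeps $\mathbf w$ on the unit sphere via the spherical gradient of the correlation $\phi$, which is gradient flow on the correlation loss $-\phi$. Normalized Hermite polynomials satisfy $\mathbb E[h_p(\mathbf a^T x)h_p(\mathbf b^T x)]=(\mathbf a^T\mathbf b)^p$ for unit vectors $\mathbf a,\mathbf b$. The objective therefore depends on $\mathbf w$ only through $\mathbf u=V^T\mathbf w$. Below, $G=V^TV$ and $\mathbf u^{\circ(p-1)}$ is the entrywise power.

```latex
\begin{aligned}
\phi(\mathbf w) &= \sum_{j=1}^k \mathbb E_x\big[h_p(\mathbf v_j^T x)\,h_p(\mathbf w^T x)\big]
  = \sum_{j=1}^k (\mathbf v_j^T \mathbf w)^p = \sum_{j=1}^k u_j^p, \\
\dot{\mathbf w} &= (I - \mathbf w \mathbf w^T)\,\nabla \phi(\mathbf w)
  = (I - \mathbf w \mathbf w^T)\, V\big(p\,\mathbf u^{\circ(p-1)}\big), \\
\dot{\mathbf u} &= V^T \dot{\mathbf w}
  = p\, G\, \mathbf u^{\circ(p-1)} - p \Big(\sum_{j=1}^k u_j^p\Big)\, \mathbf u .
\end{aligned}
```

In the orthogonal case $G=I$. At a random initialization $|u_j(0)|=\Theta(d^{-1/2})$, so during the search phase the second term is smaller than the first by a factor of order $d^{-1}$, and $\dot u_j \approx p\,u_j^{p-1}$. Integrating this heuristic approximation gives the following escape-time estimates. They hold for a coordinate with $u_j(0)>0$; for even $p$ the same holds for $|u_j|$. They are a heuristic, not the paper's proof.

```latex
\begin{aligned}
p \ge 3:&\quad u_j(t)^{-(p-2)} = u_j(0)^{-(p-2)} - p(p-2)\,t
  \;\Rightarrow\; T \approx \frac{u_j(0)^{-(p-2)}}{p(p-2)} = \Theta\big(d^{p/2-1}\big), \\
p = 2:&\quad u_j(t) = u_j(0)\,e^{2t}
  \;\Rightarrow\; T \approx \tfrac12 \log\frac{1}{u_j(0)} = \Theta(\log d).
\end{aligned}
```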

Figures (6)

  • Figure 1: Dot products during training; $n=1, d=1000$. We run gradient descent by updating the unit-norm vector with the spherical gradient, using learning rate $\eta=0.1$ and normalizing the vector after each update. Here $\sigma^* = h_{p^*}$ and $k=2$. For $p^* \! \in \! \{3,4\}$, the unit vector (neuron) converges to the direction of the index vector that is nearest at initialization ($j \!=\! 1$). For $p^* \!=\! 2$, a linear combination of the two directions is learned due to rotational symmetry. The larger the information exponent, the longer the maximum dot product takes to reach a constant value.
  • Figure 2: Bifurcation diagram for index vectors forming an equiangular frame with pairwise dot product $\beta$; $k=2$, $\sigma^*=h_{p^*}$. At a critical value $\beta_f \!\in\! (0,1)$, the infinite-time behavior of the student vector abruptly changes from monotonic convergence to the nearest direction to (non-monotonic) convergence to the average of the directions. The red dashed line indicates the saddle-to-minimum threshold $\beta_c$ given in the saddle-to-minimum theorem. Observe the small gap between $\beta_c$ and $\beta_f$.
  • Figure 3: MSE loss helps with neuron allocation; fixed initialization in both panels, $k \!=\! 2$, $n \!=\! 10$. If, at initialization, one of the index vectors is not the nearest index vector of any neuron, gradient flow with the correlation loss fails to find it (left panel). The MSE loss fixes this issue thanks to the repulsion between neurons (right panel).
  • Figure 4: Gradient flow trajectories projected onto the subspace of the orthogonal index vectors for the MSE and correlation losses, together with loss curves; odd activation (a), even activation (b). Initialization and the number of neurons are fixed; $k \!=\! 2$, $n \!=\! 10$, $d \!=\! 1000$. Under the MSE loss as well, the winning neurons move toward the closest index vectors, but the other neurons move non-trivially due to interactions between them. Adding the repulsion term (MSE loss) appears to decrease the time complexity (bottom row); however, the improvement may be only up to a constant factor.
  • Figure 5: Maximum dot product at convergence as the number of index vectors $k$ increases from $1$ to $20$; $\beta=0.2$ (left) and $\beta=0.3$ (right). Increasing the number of index vectors pulls the flow away from preferring a single index vector and toward the average of the index vectors, indicated by the black dashed line.
  • ...and 1 more figure
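
The training loop described for Figure 1 takes a spherical-gradient step with learning rate $\eta=0.1$ and then renormalizes. It can be sketched in Python as follows. This is not the authors' code. It uses the population objective $\phi(w)=\sum_j (v_j^T w)^{p}$, which is an assumption that holds when both activations are $h_{p^*}$. It also assumes orthonormal index vectors $v_j=e_j$ and a hypothetical function name. The usage takes an even $p^*=4$, so nearness is measured by $|v_j^T w|$.

```python
import numpy as np

def spherical_gd(V, w0, p, lr=0.1, steps=2000):
    """Gradient ascent on phi(w) = sum_j (v_j^T w)^p over the unit sphere.

    Assumes the population correlation reduces to phi (both activations h_p).
    Each step follows the spherical (tangent) gradient, then renormalizes.
    Returns the final neuron and the history of dot products V^T w.
    """
    w = w0 / np.linalg.norm(w0)
    history = [V.T @ w]
    for _ in range(steps):
        u = V.T @ w                          # dot products with the index vectors
        grad = V @ (p * u ** (p - 1))        # Euclidean gradient of phi
        grad -= (w @ grad) * w               # project onto the tangent space at w
        w = w + lr * grad                    # ascent on correlation = descent on the loss
        w /= np.linalg.norm(w)               # renormalize to unit norm
        history.append(V.T @ w)
    return w, np.array(history)

if __name__ == "__main__":
    d, k, p = 1000, 2, 4
    V = np.zeros((d, k))
    V[:k, :k] = np.eye(k)                    # orthonormal index vectors e_1, e_2
    w0 = np.random.default_rng(0).standard_normal(d)
    w, hist = spherical_gd(V, w0, p, steps=5000)
    print("initial |u|:", np.abs(hist[0]), "final |u|:", np.abs(hist[-1]))
```

With a random Gaussian $w_0$ in $d=1000$ dimensions, the initial overlaps are of order $d^{-1/2}$. The neuron therefore spends many steps in the search phase before the larger overlap grows; the caption describes this timescale as growing with the information exponent.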

Theorems & Definitions (25)

  • Lemma 1.1
  • Theorem 2.1: Time complexity
  • Proof sketch
  • Remark 2.1
  • Proposition 2.1: Fixed Points $\leftrightarrow$ Eigenvectors
  • Proof
  • Proposition 2.2: Directional Convergence
  • Proof sketch
  • Remark 2.2: $p^*=2$
  • Remark 2.3: $p^*=1$
  • ...and 15 more
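
Proposition 2.1 relates fixed points to tensor eigenvectors. The Python sketch below shows this correspondence under a simplifying assumption that is not the proposition's exact statement: the population correlation on the unit sphere is $\phi(w)=\sum_j (v_j^T w)^{p}$. In that case, the spherical gradient vanishes exactly when $\sum_j (v_j^T w)^{p-1} v_j = \lambda w$. This is the eigenvector equation $T(w,\dots,w,\cdot)=\lambda w$ for $T=\sum_j v_j^{\otimes p}$. For orthonormal $v_j$, normalized sums over any nonempty subset $S$ satisfy it, while unequal mixtures do not. The helper names are hypothetical, and the check does not claim to reproduce the paper's full classification.

```python
import numpy as np

def tangent_grad(V, w, p):
    """Spherical gradient of phi(w) = sum_j (v_j^T w)^p at a unit vector w.

    V @ (V.T @ w) ** (p - 1) equals T(w, ..., w, .) for the tensor T = sum_j v_j^{(x) p}.
    """
    g = p * (V @ (V.T @ w) ** (p - 1))
    return g - (w @ g) * w

def is_fixed_point(V, w, p, tol=1e-10):
    """w is a fixed point iff T(w, ..., w, .) is parallel to w (a tensor eigenvector)."""
    return bool(np.linalg.norm(tangent_grad(V, w, p)) < tol)

def subset_average(V, S):
    """Hypothetical helper: normalized sum of the index vectors v_j, j in S."""
    w = V[:, list(S)].sum(axis=1)
    return w / np.linalg.norm(w)

if __name__ == "__main__":
    d, k, p = 10, 4, 3
    V = np.eye(d)[:, :k]                                  # orthonormal index vectors
    subsets = [[j for j in range(k) if mask >> j & 1] for mask in range(1, 2 ** k)]
    print(all(is_fixed_point(V, subset_average(V, S), p) for S in subsets))  # True
    w = 0.8 * V[:, 0] + 0.6 * V[:, 1]                     # unequal weights: not an eigenvector
    print(is_fixed_point(V, w, p))                        # False
```

With $k$ orthonormal index vectors, this yields $2^k-1$ positively weighted fixed points. For even $p$, sign-flipped versions $\sum_{j\in S} \pm v_j/\sqrt{|S|}$ are fixed points as well.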