On Learning Gaussian Multi-index Models with Gradient Flow

Alberto Bietti, Joan Bruna, Loucas Pillaud-Vivien

TL;DR

The paper addresses learning multi-index models F(x) = P_{W*}f*(x) in high dimensions with Gaussian data by proposing a two-timescale gradient-flow framework: fast nonparametric learning of the low-dimensional link function f and slow optimization of the subspace W on the Grassmannian. It establishes global convergence of the population gradient flow, characterizes the saddle-to-saddle dynamics via a harmonic (Hermite) decomposition of the target, and derives explicit timescales governed by information exponents. It also analyzes a planted variant, where the link function is known and fixed, in which learning can fail because of symmetric target structures. The work introduces the averaging operator A_M as a dimension-free summary that couples f and W through the correlation matrix M = W*^⊤W ∈ R^{q×r}, and it develops a practical implementation of the fast-learning step based on an isotropic Hermite RKHS with random features. Together, these results explain how incremental subspace learning emerges in high dimensions and offer a bridge to sample-complexity considerations, while highlighting limitations of the planted model and suggesting extensions to finite samples and broader data distributions. The insights bear on progressive feature learning in neural networks and on the design of kernel-based two-timescale algorithms with convergence guarantees in structured, high-dimensional settings.
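Why M is a dimension-free summary can be checked numerically. For x ~ N(0, I_d) and W*, W with orthonormal columns, the pair (W*^⊤x, W^⊤x) is jointly Gaussian with identity marginal covariances and cross-covariance M = W*^⊤W, so the correlation E[f*(W*^⊤x) f(W^⊤x)] depends on W*, W only through M. The minimal sketch below estimates it once in dimension d and once from M alone, by sampling z ~ N(0, I_r) and u = Mz + (I_q − MM^⊤)^{1/2}ξ; averaging over ξ is the kind of operation A_M summarizes. The toy setup (d = 20, q = 2, r = 1, the polynomial f* and f) and the helper names are hypothetical, and whether this matches the paper's exact convention for A_M (for instance M versus M^⊤) is an assumption.

```python
import numpy as np

def gaussian_pair_correlation(f_star, f, W_star, W, n, rng):
    """Monte Carlo estimate of E_x[f_star(W_star^T x) f(W^T x)] with x ~ N(0, I_d)."""
    x = rng.standard_normal((n, W_star.shape[0]))
    return np.mean(f_star(x @ W_star) * f(x @ W))

def low_dim_correlation(f_star, f, M, n, rng):
    """Same expectation computed from M alone, with no reference to d.

    Samples z ~ N(0, I_r) and u = M z + (I_q - M M^T)^{1/2} xi, so that
    Cov(u) = I_q and Cov(u, z) = M, matching (W_star^T x, W^T x).
    """
    q, r = M.shape
    evals, evecs = np.linalg.eigh(np.eye(q) - M @ M.T)
    # Symmetric square root; clip tiny negative eigenvalues caused by round-off.
    sqrt_cov = evecs @ np.diag(np.sqrt(np.clip(evals, 0.0, None))) @ evecs.T
    z = rng.standard_normal((n, r))
    u = z @ M.T + rng.standard_normal((n, q)) @ sqrt_cov
    return np.mean(f_star(u) * f(z))

# Hypothetical toy setup: d = 20, q = 2 target directions, r = 1 learned direction.
rng = np.random.default_rng(0)
d = 20
Q, _ = np.linalg.qr(rng.standard_normal((d, 3)))  # d x 3, orthonormal columns
W_star = Q[:, :2]
W = Q @ np.array([[0.6], [0.5], [np.sqrt(1 - 0.6**2 - 0.5**2)]])  # unit-norm column
M = W_star.T @ W  # equals [[0.6], [0.5]] up to round-off

f_star = lambda u: u[:, 0] * u[:, 1] + u[:, 0] ** 2 - 1  # He_1(u1) He_1(u2) + He_2(u1)
f = lambda z: z[:, 0] ** 2 - 1                            # He_2(z1)

high = gaussian_pair_correlation(f_star, f, W_star, W, 200_000, rng)
low = low_dim_correlation(f_star, f, M, 200_000, rng)
# Closed form from Isserlis' (Wick's) theorem for this particular f_star and f.
exact = 2 * M[0, 0] * M[1, 0] + 2 * M[0, 0] ** 2
print(high, low, exact)
```

Both estimates agree with the closed form up to Monte Carlo error, even though only the first one ever touches the ambient dimension d.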

Abstract

We study gradient flow on the multi-index regression problem for high-dimensional Gaussian data. Multi-index functions consist of a composition of an unknown low-rank linear projection and an arbitrary unknown, low-dimensional link function. As such, they constitute a natural template for feature learning in neural networks. We consider a two-timescale algorithm, whereby the low-dimensional link function is learnt with a non-parametric model infinitely faster than the subspace parametrizing the low-rank projection. By appropriately exploiting the matrix semigroup structure arising over the subspace correlation matrices, we establish global convergence of the resulting Grassmannian population gradient flow dynamics, and provide a quantitative description of its associated 'saddle-to-saddle' dynamics. Notably, the timescales associated with each saddle can be explicitly characterized in terms of an appropriate Hermite decomposition of the target link function. In contrast with these positive results, we also show that the related planted problem, where the link function is known and fixed, in fact has a rough optimization landscape, in which gradient flow dynamics might get trapped with high probability.
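The Hermite decomposition and the semigroup structure interact in a simple way in the scalar case q = r = 1. A minimal sketch, assuming the scalar averaging operator (A_m f)(z) = E_ξ[f(mz + √(1 − m²)ξ)] with ξ ~ N(0, 1) (an illustrative assumption; the paper's Definition 1.1 is matrix-valued) and the single-index convention that the information exponent is the index of the first non-zero Hermite coefficient with k ≥ 1 (the paper works with a tensorized basis, Definition 2.1). Under these assumptions A_m He_k = m^k He_k and A_{m1} A_{m2} = A_{m1 m2}. The hypothetical helpers below compute Hermite coefficients and A_m with probabilists' Gauss-Hermite quadrature, which is exact for the polynomial degrees used here.

```python
import math
import numpy as np
from numpy.polynomial import hermite_e as H

# Probabilists' Gauss-Hermite rule for weight exp(-x^2/2), normalized to N(0, 1).
# With 50 nodes it integrates polynomials of degree <= 99 exactly.
NODES, WEIGHTS = H.hermegauss(50)
WEIGHTS = WEIGHTS / math.sqrt(2 * math.pi)

def he(k):
    """Probabilists' Hermite polynomial He_k as a vectorized function."""
    return lambda x: H.hermeval(x, [0] * k + [1])

def hermite_coefficients(f, k_max):
    """c_k = E[f(z) He_k(z)] / k! for z ~ N(0, 1), so that f = sum_k c_k He_k."""
    fz = f(NODES)
    return np.array([np.sum(WEIGHTS * fz * he(k)(NODES)) / math.factorial(k)
                     for k in range(k_max + 1)])

def information_exponent(f, k_max=10, tol=1e-8):
    """Index of the first non-zero Hermite coefficient with k >= 1 (None if all vanish)."""
    c = hermite_coefficients(f, k_max)
    nonzero = [k for k in range(1, k_max + 1) if abs(c[k]) > tol]
    return nonzero[0] if nonzero else None

def average(f, m):
    """Assumed scalar averaging operator: (A_m f)(z) = E_xi[f(m z + sqrt(1 - m^2) xi)]."""
    def A_m_f(z):
        z = np.asarray(z, dtype=float)
        # One quadrature node per trailing axis entry; works for nested calls too.
        pts = m * z[..., None] + math.sqrt(1.0 - m * m) * NODES
        return np.sum(f(pts) * WEIGHTS, axis=-1)
    return A_m_f

if __name__ == "__main__":
    print(hermite_coefficients(lambda x: x**3, 4))  # x^3 = He_3 + 3 He_1
    print(information_exponent(lambda x: x**2))
```

Because A_m shrinks the degree-k component by a factor m^k, components with large k are nearly invisible when the correlation m is small, as it is for a random unit vector in high dimension (m of order d^{-1/2}); this is the usual intuition behind the information exponent.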

Paper Structure (129 sections, 75 theorems, 357 equations, 8 figures)

Key Result

Proposition 1.2

Let $M = W_*^\top W \in \mathbb{R}^{q \times r}$. We have the following representation:

Figures (8)

  • Figure 1: Cartoon illustrations summarizing Theorems [thm:critical_points] and [thm:coarse-grained] on the evolution and the geometry of the population loss during the dynamics. Left: the octahedron represents the geometry of the critical points given by Theorem [thm:critical_points], and the displayed trajectory shows the successive learning of the subspaces $W_1, W_2, W_3$, which lie on the polytope. Right: the learning curve illustrates the associated timescales symbolically, as a succession of three saddles with plateaus of increasing orders $\mathcal{O}(d)$, $\mathcal{O}(d^2)$ and $\mathcal{O}(d^3)$, before eventual convergence.
  • Figure 2: Illustration of the geometry of the domain $\mathcal{C}_q$ of summary statistics underlying Theorem [thm:critical_points]. Its boundary contains the matrices $G$ such that either $\lambda_{\min}(G) = 0$ or $\lambda_{\max}(G) = 1$, but we illustrate here only the boundaries of the form $\lambda_i(G) \in \{0,1\}$, which are relevant to the presence of critical points. Each facet should be thought of as a Grassmann manifold, e.g. ${\color{cyan} \mathcal{G}(q,1)}$ represents the cyan facet, ${\color{olive} \mathcal{G}(q,2)}$ the olive facet, etc. Two particular Grassmannians are represented differently: $\mathcal{G}(q,0) \simeq \{0\}$ and $\mathcal{G}(q,q) \simeq \{\mathbb{R}^q\}$ are naturally drawn as dots. The critical points of $L$ (illustrated for simplicity as dots), expressed in terms of their canonical summary statistics $G$, lie on the facets of this domain. Moreover, for generic targets (where the condition $\|\mathsf{A}_{\mathrm{Sp}(G_W)}f\| = \|\mathsf{A}_{\mathrm{Jt}(G_W)}f\|$ can be assumed to hold), the inclusion also holds in the other direction: any critical point of $L$ restricted to the Grassmann manifold $G \in \mathcal{G}(q,\tau)$ is also a critical point of $L(G)$ in $\mathcal{C}_q$.
  • Figure 3: Cartoon illustration of the 'saddle-to-saddle' optimization dynamics in the octahedron representation of the critical points, for a case where $\tilde{K}=4$. The trajectory of the dynamics selects some of the critical points, as specified by Theorem [thm:coarse-grained]. These are marked as the subspaces $\tilde{W}_k$, for $k \in \llbracket 0,4 \rrbracket$.
  • Figure 4: Cartoon illustration of the evolution of the population loss during the dynamics. The plot shows the loss evolution for the problem described in Example [ex:true_sequential_cascade_polynomial]. In this example, after the initialization point, which is a first saddle associated with the timescale $\mathcal{O}(d)$, the loss dynamics has two further plateaus with timescales $\mathcal{O}(d^2)$ and $\mathcal{O}(d^3)$ respectively, before eventually converging.
  • Figure 5: Example of a target function $f\in L^2_{\gamma_2}$ that leads to failures of the Stiefel gradient flow, corresponding to Theorem [thm:planted_failure].
  • ...and 3 more figures

Theorems & Definitions (165)

  • Definition 1.1: Averaging Operator
  • Proposition 1.2
  • Theorem 1.3: Optimization Landscape of Fast-Slow Joint Learning (informal version of Theorem [thm:critical_points])
  • Theorem 1.4: Saddle-to-Saddle Dynamics (informal version of Theorem [thm:coarse-grained])
  • Theorem 1.5: Failures of Stiefel Gradient Flow for the Planted Problem (informal version of Theorem [thm:planted_failure] and Proposition [prop:badmaxima])
  • Definition 2.1: Tensorized Hermite Basis
  • Lemma 2.2
  • Definition 2.3: Spectral Thresholding
  • Proposition 2.4
  • Proof
  • ...and 155 more