
Fundamental computational limits of weak learnability in high-dimensional multi-index models

Emanuele Troiani, Yatin Dandi, Leonardo Defilippis, Lenka Zdeborová, Bruno Loureiro, Florent Krzakala

TL;DR

This work characterizes the computational limits of weak learnability for Gaussian multi-index models in high dimensions, using Bayes-optimal AMP as a baseline for first-order algorithms. It identifies a trivial subspace that can be learned in a single AMP step. When this subspace is empty, it defines easy directions with a computable threshold $\alpha_c$ above which learning becomes possible; for $\alpha<\alpha_c$, the no-recovery fixed point is stable. It also reveals hierarchical, grand staircase learning, where directions are learned sequentially when coupled to easier ones. It provides concrete examples with phase transitions, including cases where $\alpha_c$ diverges (hard directions such as certain parities). The results bridge statistical and computational limits, connect AMP optimality to gradient-based methods, and suggest new avenues for understanding feature learning in deep nets via structured, multi-index targets. The work further offers practical guidance on when efficient learning is possible and how interactions among directions shape the learning trajectory.
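The threshold $\alpha_c$ is tied to the stability of the no-recovery fixed point $M=0$. The sketch below is a minimal Monte Carlo illustration, not the authors' formula. It assumes that, near $M=0$, state evolution linearises to $M\mapsto\alpha\,\mathbb{E}_y[G(y)\,M\,G(y)]$ with $G(y)=\mathbb{E}[zz^\top\mid y]-I_p$, so that $\alpha_c$ is the inverse of the operator's largest eigenvalue. The names `linearised_operator` and `alpha_c_estimate` are hypothetical. Under this assumption, $g={\rm sign}(z_1z_2)$ reproduces the $\alpha_c=\pi^2/4$ quoted in Figure 1. For the 3-parity ${\rm sign}(z_1z_2z_3)$, $G(y)$ vanishes and the estimate diverges, matching the hard-direction case.

```python
import numpy as np

def linearised_operator(g, p, n=400_000, seed=0):
    """Monte Carlo estimate of the p^2 x p^2 matrix E_y[G(y) kron G(y)], where
    G(y) = E[z z^T | y] - I_p, for a link g taking finitely many values.
    Assumption (not from the paper): near M = 0, state evolution acts as
    M -> alpha * E_y[G M G], and vec(G M G) = kron(G, G) vec(M) since G is symmetric."""
    rng = np.random.default_rng(seed)
    Z = rng.standard_normal((n, p))
    y = g(Z)
    op = np.zeros((p * p, p * p))
    for v in np.unique(y):
        Zv = Z[y == v]
        G = Zv.T @ Zv / len(Zv) - np.eye(p)   # E[z z^T | y = v] - I_p
        op += len(Zv) / n * np.kron(G, G)     # weighted by P(y = v)
    return op

def alpha_c_estimate(g, p, tol=1e-2, **kw):
    """1 / (largest eigenvalue) of the linearised map; inf when it is numerically zero."""
    lam = np.linalg.eigvalsh(linearised_operator(g, p, **kw))[-1]
    return np.inf if lam < tol else 1.0 / lam

parity2 = lambda Z: np.sign(Z[:, 0] * Z[:, 1])   # g = sign(z1 z2), Figure 1 (left)
parity3 = lambda Z: np.sign(Z.prod(axis=1))      # g = sign(z1 z2 z3), a 3-parity
print(alpha_c_estimate(parity2, 2), np.pi ** 2 / 4)  # both close to 2.47
print(alpha_c_estimate(parity3, 3))                  # inf: G(y) vanishes, no instability
```

For ${\rm sign}(z_1z_2)$, the off-diagonal conditional moment is $\pm 2/\pi$, so the largest eigenvalue is $4/\pi^2$. This is why the estimate lands on $\pi^2/4$. The paper's actual criterion may restrict to particular subspaces or matrix classes.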

Abstract

Multi-index models, functions which depend on the covariates only through a non-linear transformation of their projection onto a subspace, are a useful benchmark for investigating feature learning with neural nets. This paper examines the theoretical boundaries of efficient learnability in this hypothesis class. It focuses on the minimum sample complexity required for weakly recovering their low-dimensional structure with first-order iterative algorithms, in the high-dimensional regime where the number of samples $n=\alpha d$ is proportional to the covariate dimension $d$. Our findings unfold in three parts. (i) We identify under which conditions a trivial subspace can be learned with a single step of a first-order algorithm for any $\alpha>0$. (ii) If the trivial subspace is empty, we provide necessary and sufficient conditions for the existence of an easy subspace. This subspace consists of directions that can be learned only above a certain sample complexity $\alpha>\alpha_c$, where $\alpha_c$ marks a computational phase transition. For a limited but interesting set of really hard directions, akin to the parity problem, $\alpha_c$ is found to diverge. Finally, (iii) we show that interactions between different directions can result in an intricate hierarchical learning phenomenon, where directions can be learned sequentially when coupled to easier ones. We discuss in detail the grand staircase picture associated with these functions and contrast it with the original staircase one. Our theory builds on the optimality of approximate message passing among first-order iterative methods. It delineates the fundamental learnability limit across a broad spectrum of algorithms, including neural networks trained with gradient descent, which we discuss in this context.
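To make the setting concrete, the following is a minimal NumPy sketch of sampling from a Gaussian multi-index model at sample complexity $\alpha=n/d$. It also measures the overlap matrix $M$ in which weak recovery is phrased. The normalisation is an assumption chosen for illustration, not taken from the paper: standard Gaussian covariates, a hidden $p\times d$ matrix with i.i.d. $\mathcal{N}(0,1)$ entries, and preactivations $z=W^\star x/\sqrt{d}$. The helper names `sample_multi_index` and `overlap` are hypothetical.

```python
import numpy as np

def sample_multi_index(g, d, alpha, p, rng):
    """Draw n = alpha * d pairs (x_i, y_i) with y_i = g(W_star x_i / sqrt(d)).

    Normalisation is an assumption: x ~ N(0, I_d) and W_star has i.i.d. N(0, 1)
    entries, so each preactivation z = W_star x / sqrt(d) is close to N(0, 1)."""
    n = int(alpha * d)
    W_star = rng.standard_normal((p, d))   # hidden p x d weights
    X = rng.standard_normal((n, d))        # Gaussian covariates, one row per sample
    Z = X @ W_star.T / np.sqrt(d)          # n x p projections on the hidden subspace
    return X, g(Z), W_star                 # labels see x only through Z

def overlap(W_hat, W_star):
    """Overlap matrix M = W_hat W_star^T / d; weak recovery means M stays O(1) as d grows."""
    return W_hat @ W_star.T / W_star.shape[1]

rng = np.random.default_rng(0)
sign_prod = lambda Z: np.sign(Z[:, 0] * Z[:, 1])   # g(z1, z2) = sign(z1 z2), Figure 1 (left)
X, y, W_star = sample_multi_index(sign_prod, d=500, alpha=4.0, p=2, rng=rng)

W_rand = rng.standard_normal((2, 500))             # uninformed estimate, as at random init
print(X.shape, y.shape)                            # (2000, 500) (2000,)
print(np.round(overlap(W_star, W_star), 2))        # close to the 2 x 2 identity
print(np.abs(overlap(W_rand, W_star)).max())       # O(1/sqrt(d)): no recovery
```

A random estimate has overlaps of order $1/\sqrt{d}$, which vanish as $d\to\infty$. Weak recovery asks for an estimator whose overlap stays bounded away from zero in that limit.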

Paper Structure

This paper contains 29 sections, 16 theorems, 135 equations, 4 figures, and 1 algorithm.

Key Result

Lemma 2.1

Let $(\boldsymbol{x}_{i},y_{i})_{i\in[n]}$ denote $n$ i.i.d. samples from the multi-index model of Definition 1. Run AMP from random initialization $\hat{\boldsymbol{W}}^{0}\in\mathbb{R}^{p\times d}$ with $\hat{\boldsymbol{w}}_{k}^{0}\overset{\text{i.i.d.}}{\sim}\mathcal{N}(\mathbf{0},\ldots)$ … with $\boldsymbol{M}^{t}$ satisfying the state evolution equations from initial condition …

Figures (4)

  • Figure 1: Weak learnability phase transitions for $g(z_{1},z_{2})={\rm sign}(z_{1}z_{2})$ (left) and $g(z_{1},z_{2},z_{3})=z_{1}^{2}+{\rm sign}(z_{1}z_{2}z_{3})$ (center and right). Given the permutation symmetry of the models, we display the optimal permutation of the overlap matrix elements reached by AMP. (Left) Overlap with the two directions, $\frac{1}{2}(M_{11}+M_{22})$, as a function of the sample complexity $\alpha=n/d$, with the phase transition at $\alpha_c=\pi^2/4$. The solid black line is the asymptotic theory from state evolution, while crosses are averages over $72$ AMP runs with $d=500$. (Center) Overlap with the first direction, $|M_{11}|$ (blue), and with the second and third ones, $\frac{1}{2}(M_{22}+M_{33})$ (red), as a function of $\alpha=n/d$. Solid lines are the state evolution curves, and crosses/dots are AMP runs with $d=500$ averaged over $72$ seeds. All other overlaps are zero (black). The two black dots indicate the critical thresholds $\alpha_1\approx 0.575$ and $\alpha_2=\pi^2/4$. (Right) Corresponding generalization error as a function of $\alpha=n/d$. The figure can be reproduced using the code provided (see also the numerics appendix).
  • Figure 2: Trajectories of a single finite-size run of AMP with $d=500$ at $\alpha=4$ for $g(z_1,z_2,z_3)=z_1^2+{\rm sign}(z_1z_2z_3)$. (Left) Evolution of the overlaps: $M_{11}$ in blue, $\frac{1}{2}(M_{22}+M_{33})$ in red, and the off-diagonal overlaps in black. (Right) Evolution of the generalisation error.
  • Figure 3: Specialisation as an example of a grand staircase with $g={\rm sign}(z_{1})+{\rm sign}(z_{2})+{\rm sign}(z_{3})$. The direction spanned by $z_1+z_2+z_3$ is learned first, for any $\alpha>0$. The remaining ones are learned only at the specialisation transition, which occurs at $\alpha\approx 4.3$. Symmetries are handled as described in the S.I. (numerics appendix). Crosses denote AMP runs with $d=100$ averaged over $72$ seeds. The overlap is shown as a function of the sample complexity. Because of the symmetry, the elements of $M$ take one of two values: one on the diagonal and one off it. We display $(M_{11}+M_{22}+M_{33})/3$ in blue and $(M_{12}+M_{13}+M_{23})/3$ in red.
  • Figure 4: Trajectories of $16$ AMP runs with random initialisation at $\alpha=4$ for $g(z_1,z_2)={\rm sign}(z_1z_2)$. The shaded areas represent the error on the mean. As the dimension increases, the algorithm becomes increasingly slow.
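The captions above contrast a direction learned for any $\alpha>0$ (Figure 3) with directions that need $\alpha$ above a threshold (Figure 1). The sketch below makes one reading of the trivial subspace concrete and treats it as an assumption: it is spanned by the conditional means $\mathbb{E}[z\mid y]$, with $z\sim\mathcal{N}(0,I_p)$ and $y=g(z)$. It estimates these means by Monte Carlo for the three link functions in the captions. The function `conditional_means` is hypothetical, and quantile binning is only a crude stand-in for conditioning on a continuous label.

```python
import numpy as np

def conditional_means(g, p, n=400_000, n_bins=20, seed=0):
    """Monte Carlo estimate of E[z | y] for z ~ N(0, I_p) and y = g(z).

    Labels with few distinct values are grouped exactly; continuous labels are
    grouped into n_bins quantile bins, a crude stand-in for conditioning on y."""
    rng = np.random.default_rng(seed)
    Z = rng.standard_normal((n, p))
    y = g(Z)
    values = np.unique(y)
    if len(values) <= n_bins:
        groups = np.searchsorted(values, y)            # index of each exact label value
    else:
        edges = np.quantile(y, np.linspace(0, 1, n_bins + 1)[1:-1])
        groups = np.digitize(y, edges)                 # quantile-bin index in 0..n_bins-1
    return np.array([Z[groups == k].mean(axis=0) for k in np.unique(groups)])

links = {  # link functions from the figure captions, as (g, p)
    "sign(z1 z2)": (lambda Z: np.sign(Z[:, 0] * Z[:, 1]), 2),
    "z1^2 + sign(z1 z2 z3)": (lambda Z: Z[:, 0] ** 2 + np.sign(Z.prod(axis=1)), 3),
    "sign(z1) + sign(z2) + sign(z3)": (lambda Z: np.sign(Z).sum(axis=1), 3),
}
means = {name: conditional_means(g, p) for name, (g, p) in links.items()}
for name, m in means.items():
    print(f"{name:32s} max |E[z|y]| = {np.abs(m).max():.3f}")
```

The first two links are invariant under flipping the signs of two coordinates at once, so $\mathbb{E}[z\mid y]=0$ and their estimates sit at the Monte Carlo noise level. This is consistent with thresholds $\alpha_c>0$ in Figure 1. For the sum of signs, $y=\pm 3$ gives $\mathbb{E}[z\mid y]=\pm\sqrt{2/\pi}\,(1,1,1)\approx\pm 0.80\,(1,1,1)$, and every conditional mean has equal components. This matches the direction $z_1+z_2+z_3$ being learned for any $\alpha>0$ in Figure 3.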

Theorems & Definitions (29)

  • Definition 1: Gaussian multi-index models
  • Definition 2: Weak subspace recovery
  • Lemma 2.1: State evolution Aubin2018Gerbelot
  • Lemma 3.1: Existence of uninformed fixed point
  • Definition 3: Trivial subspace
  • Theorem 3.2
  • Lemma 3.3
  • Lemma 4.1
  • Definition 4: Easy subspace $E^{\star}$
  • Theorem 4.2
  • ...and 19 more