Stochastic gradient descent in high dimensions for multi-spiked tensor PCA

Gérard Ben Arous, Cédric Gerbelot, Vanessa Piccolo

TL;DR

The paper analyzes online stochastic gradient descent (SGD) for multi-spiked tensor PCA, reducing the high-dimensional dynamics to a low-dimensional system of correlations between estimator directions and the true spikes. It establishes sharp sample-complexity thresholds: for p ≥ 3, full spike recovery (up to permutation) occurs with M ≳ log(N) N^{p-2}, and for p = 2, exact recovery with separated SNRs is achieved with M ≳ log(N)^2 N^{(1 − λ_r^2/λ_1^2)/2}, while equal SNRs in the p = 2 case yield subspace recovery via the eigenstructure of the Gram matrix. The authors develop nonasymptotic bounds and a robust proof strategy based on Neumann-series expansions of the retraction, discrete-time population dynamics, and martingale-concentration arguments, revealing a sequential-elimination recovery mechanism in which spikes are uncovered one by one as cross-correlations are suppressed. For p ≥ 3, the resulting sample complexity matches the algorithmic threshold identified in the rank-one case up to logarithmic factors, and the analysis provides a detailed, tractable framework for multi-index high-dimensional nonconvex optimization problems. The results clarify how gradient-based algorithms behave in multi-spike tensor estimation and related multi-index models, giving precise conditions under which online SGD efficiently recovers the hidden structure.
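The retraction mentioned above maps a gradient step back to vectors with orthonormal columns. As a minimal sketch, assume the polar retraction $Y \mapsto Y(Y^\top Y)^{-1/2}$ (the paper's exact retraction and expansion are not reproduced in this summary). After a small step, $Y^\top Y = I + \Delta$ with $\Delta$ small, and $(I+\Delta)^{-1/2}$ can be expanded as the Neumann-type binomial series $\sum_{k \ge 0} \binom{-1/2}{k} \Delta^k$, which converges for $\|\Delta\| < 1$. The helper names below are hypothetical.

```python
import numpy as np


def inv_sqrt_exact(S):
    """S^{-1/2} for a symmetric positive-definite matrix S, via eigendecomposition."""
    w, U = np.linalg.eigh(S)
    return (U / np.sqrt(w)) @ U.T  # U diag(w^{-1/2}) U^T


def inv_sqrt_series(Delta, order):
    """Truncated series (I + Delta)^{-1/2} ~ sum_{k<=order} binom(-1/2, k) Delta^k.

    Valid when the spectral norm of Delta is below 1.
    """
    r = Delta.shape[0]
    out = np.eye(r)
    term = np.eye(r)
    coeff = 1.0
    for k in range(1, order + 1):
        coeff *= (-0.5 - (k - 1)) / k  # binom(-1/2, k) obtained from binom(-1/2, k-1)
        term = term @ Delta
        out = out + coeff * term
    return out


def polar_retraction(X, G, eta, order=None):
    """Retract X + eta * G to orthonormal columns: Y (Y^T Y)^{-1/2}.

    order=None uses the exact inverse square root; an integer uses the truncated series.
    """
    Y = X + eta * G
    S = Y.T @ Y
    if order is None:
        return Y @ inv_sqrt_exact(S)
    return Y @ inv_sqrt_series(S - np.eye(S.shape[0]), order)
```

For a step of size $10^{-3}$ from an orthonormal $X$, $\|\Delta\|$ is of order $10^{-2}$, so the truncation error after the $\Delta^3$ term is of order $\|\Delta\|^4$ and the truncated retraction is practically indistinguishable from the exact one.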

Abstract

We study the high-dimensional dynamics of online stochastic gradient descent (SGD) for the multi-spiked tensor model. This multi-index model arises from the tensor principal component analysis (PCA) problem with multiple spikes, where the goal is to estimate $r$ unknown signal vectors within the $N$-dimensional unit sphere through maximum likelihood estimation from noisy observations of a $p$-tensor. We determine the number of samples and the conditions on the signal-to-noise ratios (SNRs) required to efficiently recover the unknown spikes from natural random initializations. We show that full recovery of all spikes is possible provided the number of samples scales as $N^{p-2}$, matching the algorithmic threshold identified in the rank-one case [Ben Arous, Gheissari, Jagannath 2020, 2021]. Our results are obtained through a detailed analysis of a low-dimensional system that describes the evolution of the correlations between the estimators and the spikes, while controlling the noise in the dynamics. We find that the spikes are recovered sequentially in a process we term "sequential elimination": once a correlation exceeds a critical threshold, all correlations sharing a row or column index become sufficiently small, allowing the next correlation to grow and become macroscopic. The order in which correlations become macroscopic depends on their initial values and the corresponding SNRs, leading to either exact recovery or recovery of a permutation of the spikes. In the matrix case, when $p=2$, if the SNRs are sufficiently separated, we achieve exact recovery of the spikes, whereas equal SNRs lead to recovery of the subspace spanned by them.
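The online SGD procedure described in the abstract can be sketched for $p = 3$ as follows: every step draws a fresh noisy tensor, takes a gradient step of size $\delta/N$ on a likelihood-type objective, and retracts back to orthonormal columns. Several details are assumptions, not taken from the text above: the normalization $Y = W + \sqrt{N}\sum_i \lambda_i v_i^{\otimes 3}$ with i.i.d. standard Gaussian $W$, the objective $\sum_j \lambda_j \langle Y, x_j^{\otimes 3}\rangle$, orthonormal spikes, a Euclidean gradient, and a polar retraction. All function names are hypothetical.

```python
import numpy as np


def sample_tensor(V, lambdas, rng):
    """One fresh observation Y = W + sqrt(N) * sum_i lambda_i v_i^{(x)3} (assumed normalization)."""
    N = V.shape[0]
    Y = rng.standard_normal((N, N, N))
    for lam, v in zip(lambdas, V.T):
        Y += np.sqrt(N) * lam * np.einsum("a,b,c->abc", v, v, v)
    return Y


def grad_objective(Y, X, lambdas):
    """Euclidean gradient of sum_j lambda_j <Y, x_j^{(x)3}> with respect to each column x_j."""
    G = np.empty_like(X)
    for j, lam in enumerate(lambdas):
        x = X[:, j]
        # Derivative of a trilinear form: one term per tensor slot (Y need not be symmetric).
        g = (np.einsum("abc,b,c->a", Y, x, x)
             + np.einsum("abc,a,c->b", Y, x, x)
             + np.einsum("abc,a,b->c", Y, x, x))
        G[:, j] = lam * g
    return G


def polar(Y):
    """Polar retraction Y (Y^T Y)^{-1/2}, computed as U V^T from the thin SVD of Y."""
    U, _, Vt = np.linalg.svd(Y, full_matrices=False)
    return U @ Vt


def online_sgd(V, lambdas, M, delta, rng):
    """M online steps, one fresh sample each, with step size delta / N."""
    N, r = V.shape
    X = polar(rng.standard_normal((N, r)))  # random orthonormal initialization
    for _ in range(M):
        Y = sample_tensor(V, lambdas, rng)
        X = polar(X + (delta / N) * grad_objective(Y, X, lambdas))
    return X, V.T @ X  # estimator and correlations m_ij = <v_i, x_j>
```

The parameters quoted for Figure 3 ($N = 500$, $M = 3200$, $\delta/N = 0.0003$) could be passed in directly, but this naive version stores each sample as a dense array of $N^3$ float64 entries, about 1 GB at $N = 500$, so a practical run would need a cheaper way to apply the noise.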

Paper Structure

This paper contains 24 sections, 28 theorems, 474 equations, 8 figures.

Key Result

Theorem 1.3

Suppose that $M = M(N)$ grows at most polynomially in $N$ and satisfies $M \gg \log(N) N^{p-2}$. Assume further that $\delta$ satisfies $M^{-1}N^{\frac{p-1}{2}} \ll \delta \ll N^{1/2}(\log(N)M)^{-1/2}$. Then there exists a permutation $\pi^\ast \in S_r$ such that, for every $\varepsilon > 0$ and every … Furthermore, if there exists $\eta > 1$ such that … for some universal constant $C>0$, then for every …
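The two conditions in Theorem 1.3 are linked. Comparing the endpoints of the step-size window, $M^{-1}N^{(p-1)/2} < N^{1/2}(\log(N)M)^{-1/2}$ holds exactly when $N^{p-2} < M/\log(N)$, that is, when $M > \log(N) N^{p-2}$. So, up to the constants hidden in $\ll$, the sample condition is precisely what makes the window for $\delta$ nonempty. The helper below (a hypothetical name, with all constants dropped) evaluates these scales.

```python
import math


def theorem_1_3_scales(N, p, M):
    """Order-of-magnitude scales from Theorem 1.3, with all constants dropped.

    Returns (sample_scale, delta_low, delta_high), where the theorem asks for
    M >> sample_scale and delta_low << delta << delta_high.
    """
    sample_scale = math.log(N) * N ** (p - 2)
    delta_low = N ** ((p - 1) / 2) / M
    delta_high = math.sqrt(N) / math.sqrt(math.log(N) * M)
    return sample_scale, delta_low, delta_high
```

With $N = 500$ and $p = 3$ (the dimension used in Figure 3), $\log(N) N^{p-2} \approx 3107$; for $M = 3200$ the window for $\delta$ is roughly $[0.156, 0.159]$, while for $M = 3000$ it is empty.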

Figures (8)

  • Figure 1: Evolution of the correlations $\{m_{ij}\}_{1 \le i,j \le 2}$ under the population dynamics for $p=3$, $r=2$, $\lambda_1=3$, and $\lambda_2=1$. The SNRs are sufficiently separated to ensure exact recovery of both spikes $\boldsymbol{v}_1$ and $\boldsymbol{v}_2$. Once $m_{11}$ reaches a sufficiently large microscopic threshold, $m_{12}$ and $m_{21}$ begin to decrease, allowing the recovery of $\boldsymbol{v}_2$ after they become negligible in the evolution of $m_{22}$.
  • Figure 2: Evolution of the correlations $\{m_{ij}\}_{1 \le i,j \le 4}$ under the population dynamics for $p=3$, $r=4$, and equal SNRs $\lambda_1= \cdots = \lambda_4 =1$. Since the SNRs are identical, the order in which the correlations become macroscopic is determined by their initial values. The simulation illustrates recovery of a permutation of the four spikes $\boldsymbol{v}_1, \ldots, \boldsymbol{v}_4$ via the sequential elimination phenomenon. The fourth spike is not visible in the plotted time window, as its recovery occurs slightly later.
  • Figure 3: Evolution of the correlations $\{m_{ij}\}_{1 \le i,j \le 3}$ under the online SGD dynamics for $p=3, r=3$, and SNRs $\lambda_1 = 3, \lambda_2 =2, \lambda_3 = 1$. The top panels correspond to $M = 3200$ samples, step size $\delta/N = 0.0003$, and dimension $N=500$. The bottom panels show the corresponding noiseless dynamics, obtained with the same step size and number of steps. This simulation illustrates the recovery of a permutation of the three spikes $\boldsymbol{v}_1, \boldsymbol{v}_2, \boldsymbol{v}_3$ through the sequential elimination phenomenon.
  • Figure 4: Evolution of the correlations $\{m_{ij}\}_{1 \le i,j \le 2}$ under the population dynamics for $p=2$, $r=2$, $\lambda_1=3$, and $\lambda_2=1$. The SNRs are sufficiently separated to ensure exact recovery of both spikes.
  • Figure 5: Evolution of the correlations $\{m_{ij}\}_{1 \le i, j \le 3}$ under the online SGD dynamics for $p=2, r=3$, and SNRs $\lambda_1 = 3, \lambda_2 =2, \lambda_3 = 1$. The top panels correspond to $M = 230$ samples, step size $\delta/N = 0.002$, and dimension $N=500$. The bottom panels show the corresponding noiseless dynamics, obtained with the same step size and number of steps. The simulation shows exact recovery of the three spikes. The decrease in the yellow curve representing the correlation $m_{33}$ is quantitatively controlled in the proof.
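The population (noiseless) dynamics plotted in Figures 1, 2 and 4 can be imitated by running gradient ascent without noise. This is a minimal sketch under assumptions not stated in the captions: the population objective $\Phi(X) = \sum_{i,j} \lambda_i \lambda_j \langle v_i, x_j \rangle^p$, spikes placed on the first $r$ coordinate vectors, a fixed step $\eta$ in place of the paper's scaling, and a polar retraction; `population_dynamics` is a hypothetical name. Because $v_i = e_i$, the correlation matrix $m_{ij} = \langle v_i, x_j \rangle$ is simply the top $r \times r$ block of $X$.

```python
import numpy as np


def population_dynamics(lambdas, p, N, steps, eta, seed=0):
    """Noiseless gradient ascent on Phi(X) = sum_{i,j} lambda_i lambda_j <v_i, x_j>^p
    over N x r matrices with orthonormal columns, with spikes v_i = e_i (assumed objective).

    Returns the history of correlation matrices, shape (steps, r, r).
    """
    rng = np.random.default_rng(seed)
    lam = np.asarray(lambdas, dtype=float)
    r = lam.size
    X = np.linalg.qr(rng.standard_normal((N, r)))[0]  # random start, correlations of order N^{-1/2}
    history = []
    for _ in range(steps):
        corr = X[:r, :]  # m_ij = <e_i, x_j>
        # dPhi/dx_j = p * lambda_j * sum_i lambda_i m_ij^{p-1} e_i, nonzero only in the first r rows
        G = np.zeros_like(X)
        G[:r, :] = p * np.outer(lam, lam) * corr ** (p - 1)
        U, _, Vt = np.linalg.svd(X + eta * G, full_matrices=False)
        X = U @ Vt  # polar retraction back to orthonormal columns
        history.append(X[:r, :].copy())
    return np.array(history)
```

For $p = 2$ this objective equals $\sum_j \lambda_j x_j^\top A x_j$ with $A = \sum_i \lambda_i v_i v_i^\top$, so with separated SNRs such as $(3, 1)$ its maximizers are $x_j = \pm v_j$ (exact recovery), whereas with equal SNRs any orthonormal basis of $\mathrm{span}(v_1, \dots, v_r)$ is optimal (subspace recovery), in line with the $p = 2$ results described above. Plotting `history[:, i, j]` against the step index gives curves of the kind shown in the figures, under the stated assumptions.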

Theorems & Definitions (61)

  • Definition 1.1: Correlation
  • Definition 1.2: Recovery of all spikes
  • Theorem 1.3
  • Definition 1.4: Greedy maximum selection
  • Theorem 1.5: Recovery up to a permutation for $p \geq 3$
  • Remark 1.6
  • Definition 1.7: Sequential elimination
  • Theorem 1.8: Asymptotic strong-recovery theorem for online SGD with $p > 2$, revisited
  • Theorem 1.9: Exact recovery for $p=2$
  • Definition 1.10: Subspace recovery
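Definition 1.4 is not reproduced in this summary, but a greedy maximum selection of the usual kind can be sketched as follows: pick the largest entry of an $r \times r$ score matrix, match that row (spike) to that column (estimator), delete both, and repeat. Deleting the row and column mirrors the sequential elimination described in the abstract. Which scores the paper ranks is an assumption here; one natural choice for $p \ge 3$ is $\lambda_i \lambda_j m_{ij}(0)^{p-2}$, suggested by the toy equation $\dot m = c\, m^{p-1}$ with $c = p \lambda_i \lambda_j$, whose solution from $m_0 > 0$ blows up at time $1/(c(p-2) m_0^{p-2})$, so larger scores escape first. The function name is hypothetical.

```python
import numpy as np


def greedy_maximum_selection(scores):
    """Greedily match rows (spikes) to columns (estimators) by repeated maximum selection.

    Returns perm with perm[j] = i, meaning estimator j is matched to spike i.
    """
    A = np.array(scores, dtype=float)
    r = A.shape[0]
    perm = [-1] * r
    for _ in range(r):
        i, j = np.unravel_index(np.argmax(A), A.shape)
        perm[int(j)] = int(i)
        A[i, :] = -np.inf  # spike i is now taken
        A[:, j] = -np.inf  # estimator j is now taken
    return perm
```

For instance, `greedy_maximum_selection([[1, 8, 3], [7, 2, 9], [4, 6, 5]])` first takes the entry 9 (spike 1, estimator 2), then 8 (spike 0, estimator 1), and finally 4 (spike 2, estimator 0), returning `[2, 0, 1]`.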