Stochastic gradient descent in high dimensions for multi-spiked tensor PCA
Gérard Ben Arous, Cédric Gerbelot, Vanessa Piccolo
TL;DR
The paper analyzes online stochastic gradient descent (SGD) for multi-spiked tensor PCA, reducing the high-dimensional dynamics to a low-dimensional system of correlations between the estimators and the true spikes. It establishes sample-complexity thresholds in terms of the number of samples M: for p ≥ 3, full recovery of the spikes (up to permutation) occurs with M ≳ log(N) N^{p-2}; for p = 2 with separated SNRs, exact recovery requires M ≳ log(N)^2 N^{(1 − λ_r^2/λ_1^2)/2}; and for p = 2 with equal SNRs, SGD recovers the subspace spanned by the spikes, via the eigenstructure of the Gram matrix. The authors develop nonasymptotic bounds and a proof strategy based on Neumann-series expansions of the retraction, discrete-time population dynamics, and martingale-concentration arguments. The analysis reveals a sequential-elimination mechanism in which the spikes are recovered one by one: once a correlation becomes large, the correlations sharing a row or column index with it are suppressed, which lets the next one grow. The N^{p-2} scaling matches, up to logarithmic factors, the algorithmic threshold identified in the rank-one case. The work also provides a detailed, tractable analysis framework for multi-index high-dimensional nonconvex optimization problems, and gives precise conditions under which online SGD efficiently recovers the hidden structure in multi-spiked tensor estimation and related multi-index models.
Abstract
We study the high-dimensional dynamics of online stochastic gradient descent (SGD) for the multi-spiked tensor model. This multi-index model arises from the tensor principal component analysis (PCA) problem with multiple spikes, where the goal is to estimate $r$ unknown signal vectors within the $N$-dimensional unit sphere through maximum likelihood estimation from noisy observations of a $p$-tensor. We determine the number of samples and the conditions on the signal-to-noise ratios (SNRs) required to efficiently recover the unknown spikes from natural random initializations. We show that full recovery of all spikes is possible provided the number of samples scales as $N^{p-2}$, matching the algorithmic threshold identified in the rank-one case [Ben Arous, Gheissari, Jagannath 2020, 2021]. Our results are obtained through a detailed analysis of a low-dimensional system that describes the evolution of the correlations between the estimators and the spikes, while controlling the noise in the dynamics. We find that the spikes are recovered sequentially in a process we term "sequential elimination": once a correlation exceeds a critical threshold, all correlations sharing a row or column index become sufficiently small, allowing the next correlation to grow and become macroscopic. The order in which correlations become macroscopic depends on their initial values and the corresponding SNRs, leading to either exact recovery or recovery of a permutation of the spikes. In the matrix case, when $p=2$, if the SNRs are sufficiently separated, we achieve exact recovery of the spikes, whereas equal SNRs lead to recovery of the subspace spanned by them.
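
The "sequential elimination" picture can be illustrated with a small toy sketch. It does not simulate SGD and is not the authors' procedure. It only mimics the bookkeeping described above: correlations are arranged in an $r \times r$ matrix (rows for estimators, columns for spikes), the largest remaining entry is declared macroscopic, and every entry sharing its row or column is then discarded. The function names `correlation_matrix` and `elimination_order`, and the idea of a generic `score` matrix combining initial correlations and SNRs, are assumptions made for illustration.

```python
import numpy as np


def correlation_matrix(X, V):
    """Return m with m[i, j] = <x_i, v_j>.

    X and V are (N, r) arrays whose columns are the estimators x_i and
    the spikes v_j (hypothetical layout chosen for this sketch).
    """
    return X.T @ V


def elimination_order(score):
    """Greedy toy model of sequential elimination.

    `score` is an (r, r) array of hypothetical scores, one per
    (estimator, spike) pair. At each step the largest remaining entry is
    selected, then its whole row and column are removed from play.
    Returns the selected (estimator, spike) index pairs in order.
    """
    score = np.array(score, dtype=float)  # work on a copy
    r = score.shape[0]
    if score.shape != (r, r):
        raise ValueError("score must be a square matrix")
    order = []
    for _ in range(r):
        i, j = np.unravel_index(np.argmax(score), score.shape)
        order.append((int(i), int(j)))
        score[i, :] = -np.inf  # estimator i is now matched
        score[:, j] = -np.inf  # spike j is now recovered
    return order


# Toy scores: estimator 0 is best aligned with spike 1, and so on.
toy_score = [[0.2, 0.9, 0.1],
             [0.8, 0.3, 0.4],
             [0.5, 0.6, 0.7]]
print(elimination_order(toy_score))  # [(0, 1), (1, 0), (2, 2)]
```

In this toy run the estimators end up matched to spikes 1, 0 and 2 respectively, that is, to a permutation of the spikes rather than the identity, which is the kind of outcome the abstract describes when the order depends on the initial values and the SNRs.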
