On the Benefits of Weight Normalization for Overparameterized Matrix Sensing

Yudong Wei, Liang Zhang, Bingcong Li, Niao He

TL;DR

The paper addresses recovering a low-rank PSD matrix $\mathbf{A}$ from linear measurements by applying generalized weight normalization (WN) to a matrix factorization via polar decomposition. It develops a Riemannian gradient descent (RGD) scheme on the Stiefel manifold for the direction $\mathbf{X}$, paired with gradient-based updates for the magnitude $\mathbf{\Theta}$, and shows that convergence proceeds in two phases: a saddle-escape phase followed by linear convergence. The main contributions are (i) an exponential improvement in convergence rate over standard gradient methods without WN, (ii) polynomial improvements in iteration and sample complexity as the level of overparameterization increases, and (iii) numerical validation on synthetic and real data, including image reconstruction. The results provide theoretical and empirical evidence that overparameterization, when combined with weight normalization, can be leveraged to accelerate nonconvex matrix sensing and potentially other learning problems.
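The scheme described above can be illustrated with a small NumPy sketch. It is not the paper's algorithm listing: the loss $\tfrac12\|\mathcal{M}(\mathbf{X}\mathbf{\Theta}\mathbf{X}^\top) - \mathbf{y}\|^2$, the symmetric Gaussian measurements, the polar retraction, the Euclidean-metric tangent projection, the alternating update order, and the step sizes `eta` and `mu` are all assumptions chosen for illustration, and the helper names (`wn_rgd`, `polar_retract`, `measure`) are hypothetical.

```python
import numpy as np

def sym(Z):
    """Symmetric part of a square matrix."""
    return 0.5 * (Z + Z.T)

def polar_retract(Y):
    """Map Y to the closest matrix with orthonormal columns (polar factor)."""
    U, _, Vt = np.linalg.svd(Y, full_matrices=False)
    return U @ Vt

def measure(Ms, Z):
    """Linear measurements <M_k, Z> for a stack of matrices Ms of shape (n, m, m)."""
    return np.einsum("kij,ij->k", Ms, Z)

def wn_rgd(Ms, y, m, r, eta=0.05, mu=0.2, iters=500, seed=0):
    """Alternate a Riemannian step on X (Stiefel) with a gradient step on Theta.

    Assumed loss: f(X, Theta) = 0.5 * ||M(X Theta X^T) - y||^2.
    """
    rng = np.random.default_rng(seed)
    X = polar_retract(rng.standard_normal((m, r)))  # random point on St(m, r)
    Theta = np.zeros((r, r))                        # satisfies ||Theta_0|| <= 2
    for _ in range(iters):
        # Adjoint of the measurement map applied to the residual (symmetric).
        R = np.einsum("k,kij->ij", measure(Ms, X @ Theta @ X.T) - y, Ms)
        G = 2.0 * R @ X @ Theta                     # Euclidean gradient in X
        rgrad = G - X @ sym(X.T @ G)                # project onto tangent space
        X = polar_retract(X - eta * rgrad)          # retract back to the manifold
        # Gradient step on the magnitude variable with the updated X.
        R = np.einsum("k,kij->ij", measure(Ms, X @ Theta @ X.T) - y, Ms)
        Theta = Theta - mu * sym(X.T @ R @ X)
    return X, Theta

# Small synthetic instance (all sizes are illustrative assumptions).
m, r_A, r, n = 15, 2, 4, 1000
rng = np.random.default_rng(1)
U = polar_retract(rng.standard_normal((m, r_A)))
A = U @ np.diag([1.0, 0.5]) @ U.T                   # rank-2 PSD ground truth
G_raw = rng.standard_normal((n, m, m))
Ms = 0.5 * (G_raw + G_raw.transpose(0, 2, 1)) / np.sqrt(n)  # symmetric Gaussian
y = measure(Ms, A)

X, Theta = wn_rgd(Ms, y, m, r)
err_init = np.linalg.norm(A, "fro")                 # error at Theta_0 = 0
err_final = np.linalg.norm(X @ Theta @ X.T - A, "fro")
print(f"initial error {err_init:.3e}, final error {err_final:.3e}")
```

Here `r = 4` exceeds the true rank `r_A = 2`, so the factorization is overparameterized. The polar retraction keeps `X` exactly on the Stiefel manifold after every step, which is what separates the direction variable from the magnitude `Theta`.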

Abstract

While normalization techniques are widely used in deep learning, their theoretical understanding remains relatively limited. In this work, we establish the benefits of (generalized) weight normalization (WN) applied to the overparameterized matrix sensing problem. We prove that WN with Riemannian optimization achieves linear convergence, yielding an exponential speedup over standard methods that do not use WN. Our analysis further demonstrates that both iteration and sample complexity improve polynomially as the level of overparameterization increases. To the best of our knowledge, this work provides the first characterization of how WN leverages overparameterization for faster convergence in matrix sensing.

Paper Structure

This paper contains 39 sections, 26 theorems, 173 equations, 4 figures, 1 table, 1 algorithm.

Key Result

Theorem 3.2

Consider solving the WN-aided sensing problem (problem) initialized with random $\mathbf{X}_0 \in \mathsf{St}(m,r)$ and $\mathbf{\Theta}_0 \in \mathbb{S}^r$ satisfying $\|\mathbf{\Theta}_0\|\leq 2$. Assume that $r_A \le \frac{m}{2}$ and ${\mathcal{M}}(\cdot)$ is $(r\!+\!r_A\!+\!1,\delta)$-RIP with $

Figures (4)

  • Figure 1: The saddle-to-saddle (i.e., sequential learning) behaviors in WN. The x-axis corresponds to the iteration number, and the y-axis follows the subfigure title. (a) Each plateau signifies a saddle; (b) the gradient norm at saddles drops by orders of magnitude; (c) saddles strongly relate to the best rank-$\rho$ approximation of $\mathbf{A}$; (d) sequential learning in the alignment between $\mathbf{X}_t$ and $\mathbf{U}$; (e) sequential learning in the alignment between $\mathbf{X}_t$ and $\mathbf{U}_\perp$; and (f) sequential pattern in the magnitude variable $\mathbf{\Theta}_t$.
  • Figure 2: Convergence comparison of RGD on WN and GD on the Burer-Monteiro factorization under varying problem conditions (squared reconstruction error vs. iteration). (a) WN enables RGD to converge linearly regardless of $\kappa$; (b) with WN, larger $r$ leads to a shorter initial phase and a faster convergence rate.
  • Figure 3: Additional numerical results of WN (squared reconstruction error vs. iteration).
  • Figure 4: The advantages of WN on image reconstruction.

Theorems & Definitions (28)

  • Definition 3.1: Restricted Isometry Property (RIP)
  • Theorem 3.2
  • Lemma 4.1
  • Lemma 4.2
  • Definition A.1
  • Lemma A.2
  • Lemma C.1
  • Lemma C.2
  • Lemma C.3
  • Lemma C.4
  • ...and 18 more
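
For readers unfamiliar with Definition 3.1 listed above, the restricted isometry property is commonly stated in the following form. This is the standard textbook statement, written with a generic rank bound $r'$ (a symbol introduced here for illustration); the paper's exact normalization and constants may differ. A linear map $\mathcal{M}(\cdot)$ is said to be $(r',\delta)$-RIP if

$$(1-\delta)\,\|\mathbf{Z}\|_F^2 \;\le\; \|\mathcal{M}(\mathbf{Z})\|_2^2 \;\le\; (1+\delta)\,\|\mathbf{Z}\|_F^2 \quad \text{for all symmetric } \mathbf{Z} \text{ with } \operatorname{rank}(\mathbf{Z}) \le r'.$$

In Theorem 3.2 the rank bound is $r + r_A + 1$, so the measurement map is required to act as a near-isometry on matrices slightly larger in rank than the sum of the factorization rank and the true rank.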