Projected Forward Gradient-Guided Frank-Wolfe Algorithm via Variance Reduction

M. Rostami, S. S. Kia

TL;DR

This work tackles the high computational and memory cost of gradient evaluation in Frank-Wolfe (FW) optimization for constrained DNN training by using the Projected Forward Gradient (Projected-FG) as a memory-efficient gradient estimator. Naively applying Projected-FG within FW introduces a non-vanishing convergence error due to gradient noise; to counter this, the authors propose a variance-reduction scheme that averages historical Projected-FG directions, yielding an Averaged PF-FW algorithm. They prove that the variance-reduced method achieves exact convergence for convex objectives and convergence to stationary points for non-convex objectives, with established rates (e.g., $O(1/\sqrt{k})$ in the convex case under suitable parameter tuning, and $O(1/\log K)$ for the FW gap in the non-convex case). Numerical experiments on MNIST show substantial memory savings (about 1000 MiB per epoch vs 2150 MiB with full backprop) and competitive accuracy under sparsity constraints, supporting both the practical efficiency and the theoretical guarantees of the proposed approach.

Abstract

This paper aims to enhance the use of the Frank-Wolfe (FW) algorithm for training deep neural networks (DNNs). Similar to any gradient-based optimization algorithm, FW suffers from high computational and memory costs when computing gradients for DNNs. This paper introduces the application of the recently proposed projected forward gradient (Projected-FG) method to the FW framework, offering reduced computational cost similar to backpropagation and low memory utilization akin to forward propagation. Our results show that trivial application of the Projected-FG introduces non-vanishing convergence error due to the stochastic noise that the Projected-FG method introduces in the process. This noise results in a non-vanishing variance in the Projected-FG estimated gradient. To address this, we propose a variance reduction approach by aggregating historical Projected-FG directions. We demonstrate rigorously that this approach ensures convergence to the optimal solution for convex functions and to a stationary point for non-convex functions. These convergence properties are validated through a numerical example, showcasing the approach's effectiveness and efficiency.
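
The idea of aggregating historical Projected-FG directions inside FW can be illustrated with a small sketch. Everything specific below is an assumption made for illustration and is not taken from the paper: the toy objective $f(\theta) = \tfrac{1}{2}\|\theta - t\|^2$, the $\ell_1$-ball constraint, the weights `rho = 1/(k+1)^(2/3)` and `eta = 2/(k+2)`, and all function names. The estimator is assumed to have the form $(\mathbf{g}^\top \mathbf{u})\,\mathbf{u}$ with $\mathbf{u} \sim \mathcal{N}(0, I)$, as in Baydin et al. (2022); in a real network the scalar $\mathbf{g}^\top \mathbf{u}$ would come from a single forward-mode pass rather than from the exact gradient.

```python
import random

def f(theta, target):
    # Toy convex objective (assumption): 0.5 * ||theta - target||^2
    return 0.5 * sum((t - c) ** 2 for t, c in zip(theta, target))

def grad_f(theta, target):
    # Exact gradient of the toy objective, used only to form g . u below
    return [t - c for t, c in zip(theta, target)]

def projected_fg(g, rng):
    # Assumed estimator form: (g . u) u with u ~ N(0, I)
    u = [rng.gauss(0.0, 1.0) for _ in g]
    directional = sum(gi * ui for gi, ui in zip(g, u))
    return [directional * ui for ui in u]

def lmo_l1(d, radius):
    # Linear minimization oracle over the l1 ball:
    # argmin_{||s||_1 <= radius} <d, s> is a signed, scaled basis vector
    i = max(range(len(d)), key=lambda j: abs(d[j]))
    s = [0.0] * len(d)
    s[i] = -radius if d[i] > 0 else radius
    return s

def averaged_pf_fw(theta0, target, radius=1.0, iters=2000, seed=0):
    rng = random.Random(seed)
    theta = list(theta0)
    d = [0.0] * len(theta)  # running average of Projected-FG directions
    for k in range(iters):
        rho = 1.0 / (k + 1) ** (2.0 / 3.0)  # averaging weight (assumption)
        eta = 2.0 / (k + 2)                 # FW step size (assumption)
        g_hat = projected_fg(grad_f(theta, target), rng)
        d = [(1 - rho) * di + rho * gi for di, gi in zip(d, g_hat)]
        s = lmo_l1(d, radius)
        # Convex combination of feasible points, so theta stays in the ball
        theta = [(1 - eta) * t + eta * si for t, si in zip(theta, s)]
    return theta

target = [0.5, -0.2, 0.0, 0.0, 0.0]   # minimizer lies inside the unit l1 ball
theta0 = [0.0, 0.0, 0.0, 0.0, 1.0]    # feasible starting vertex
theta = averaged_pf_fw(theta0, target)
print("initial loss:", f(theta0, target))
print("final loss:", f(theta, target))
```

Setting `rho = 1.0` at every iteration recovers the naive variant that feeds each noisy estimate straight into the linear minimization step, which is a convenient way to compare the two behaviours on this toy problem.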

Paper Structure (6 sections, 8 theorems, 32 equations, 3 figures, 2 algorithms)


Key Result

Lemma 2.1

The PF-Gradient in Definition 1 is an unbiased estimate of $\mathbf{g}(\boldsymbol{\theta})$, i.e., $\mathbb{E}[\hat{\mathbf{g}}(\boldsymbol{\theta})] = \mathbf{g}(\boldsymbol{\theta})$. $\Box$
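
A quick Monte Carlo check can make the unbiasedness claim concrete. The sketch below assumes the estimator has the form $\hat{\mathbf{g}} = (\mathbf{g}^\top \mathbf{u})\,\mathbf{u}$ with $\mathbf{u} \sim \mathcal{N}(0, I)$, as in Baydin et al. (2022), in which case $\mathbb{E}[\hat{\mathbf{g}}] = \mathbb{E}[\mathbf{u}\mathbf{u}^\top]\,\mathbf{g} = \mathbf{g}$. The vector `g`, the sample count, and the function names are hypothetical choices for illustration.

```python
import random

def pf_gradient_sample(g, rng):
    # One draw of the assumed estimator (g . u) u with u ~ N(0, I)
    u = [rng.gauss(0.0, 1.0) for _ in g]
    directional = sum(gi * ui for gi, ui in zip(g, u))
    return [directional * ui for ui in u]

def monte_carlo_mean(g, n_samples=20000, seed=0):
    # Empirical mean of the estimator over independent draws of u
    rng = random.Random(seed)
    total = [0.0] * len(g)
    for _ in range(n_samples):
        sample = pf_gradient_sample(g, rng)
        total = [t + s for t, s in zip(total, sample)]
    return [t / n_samples for t in total]

g = [1.0, -2.0, 0.5]  # stand-in for a true gradient (assumption)
mean = monte_carlo_mean(g)
max_err = max(abs(m - gi) for m, gi in zip(mean, g))
print("empirical mean:", mean)
print("max abs deviation from g:", max_err)
```

The empirical mean approaches `g` as the sample count grows, while any single draw can be far from `g`; that per-sample spread is the variance the averaging scheme is meant to suppress.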

Figures (3)

  • Figure 1: Using forward propagation to compute $\frac{\partial f}{\partial \theta_0}$ in a network with $L-1$ hidden layers, each with $n$ nodes.
  • Figure 2: Projected Forward Gradient
  • Figure 3: Training loss and accuracy with respect to the number of epochs for Algorithm 1 (FG-FW) and Algorithm 2 (AFG-FW) over 20 epochs.

Theorems & Definitions (14)

  • Definition 1: Projected Forward Gradient
  • Lemma 2.1: Unbiasedness of the PF-Gradient (Baydin et al., 2022)
  • Lemma 2.2: Upper bound on the PF-Gradient (Baydin et al., 2022)
  • Lemma 2.3: Variance of the PF-Gradient
  • Lemma 3.1: Trajectories of Algorithm 1 (FG-FW) remain in $\mathcal{C}$
  • Theorem 3.1: Convergence bound of Algorithm 1 (FG-FW) for convex functions
  • Lemma 4.1: The variance of the averaged projected forward gradient estimator of Algorithm 2 (AFG-FW) converges to zero
  • Theorem 4.1: Convergence analysis of Algorithm 2 (AFG-FW) for convex functions
  • Theorem 4.2: Convergence analysis of Algorithm 2 (AFG-FW) for non-convex functions
  • Proof of Lemma 2.3
  • ...and 4 more