Efficient Sign-Based Optimization: Accelerating Convergence via Variance Reduction
Wei Jiang, Sifan Yang, Wenhao Yang, Lijun Zhang
TL;DR
This work addresses non-convex stochastic optimization with sign-based updates by introducing Sign-based Stochastic Variance Reduction (SSVR), which combines variance-reduced gradient estimators with sign-based updates. SSVR achieves a convergence rate of $\mathcal{O}(d^{1/2}T^{-1/3})$ for general non-convex objectives, improving on the $\mathcal{O}(d^{1/2}T^{-1/4})$ rate of signSGD, and a rate of $\mathcal{O}(m^{1/4}d^{1/2}T^{-1/2})$ for finite-sum problems. In distributed settings with heterogeneous data, the authors propose two majority-vote algorithms, including SSVR-MV, which attain rates of $\mathcal{O}(d^{1/2}T^{-1/2} + dn^{-1/2})$ and $\mathcal{O}(d^{1/4}T^{-1/4})$, compared with the previous $\mathcal{O}(dT^{-1/4} + dn^{-1/2})$ and $\mathcal{O}(d^{3/8}T^{-1/8})$. Experiments on CIFAR-10/100 corroborate the theoretical gains, showing improved convergence and accuracy with 1-bit communication, which suggests practical value for communication-efficient distributed learning.
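A minimal sketch of a sign-based variance-reduction loop may help make the combination concrete. The summary does not specify the estimator, so this sketch assumes a STORM-style recursive momentum estimator, $v_t = \nabla f(x_t;\xi_t) + (1-\beta)\,(v_{t-1} - \nabla f(x_{t-1};\xi_t))$, with a constant step size; the parameters then move by $-\eta\,\mathrm{sign}(v_t)$. The names `ssvr_sketch`, `noisy_grad`, `lr` and `beta`, and the toy quadratic objective, are hypothetical illustrations, not the authors' implementation.

```python
import random

def sign(x):
    # Elementwise sign; maps 0 to 0, the usual sign() convention.
    return [(v > 0) - (v < 0) for v in x]

def ssvr_sketch(grad, x0, steps, lr=0.01, beta=0.1, seed=0):
    """Hypothetical sign-based variance-reduction loop (assumed STORM-style estimator).

    grad(x, xi) returns a stochastic gradient at x for sample xi.
    """
    rng = random.Random(seed)
    x_prev = list(x0)
    v = grad(x_prev, rng.random())            # plain stochastic gradient at the start
    x = [xp - lr * s for xp, s in zip(x_prev, sign(v))]
    for _ in range(steps - 1):
        xi = rng.random()                     # one sample, evaluated at x_t and x_{t-1}
        g_new, g_old = grad(x, xi), grad(x_prev, xi)
        # Recursive correction: v_t = g(x_t) + (1 - beta) * (v_{t-1} - g(x_{t-1}))
        v = [gn + (1 - beta) * (vp - go) for gn, vp, go in zip(g_new, v, g_old)]
        # Only the sign of the estimator drives the update.
        x_prev, x = x, [xc - lr * s for xc, s in zip(x, sign(v))]
    return x

def noisy_grad(x, xi):
    # Gradient of f(x) = 0.5 * ||x||^2 plus sample noise in [-0.25, 0.25).
    return [xj + 0.5 * (xi - 0.5) for xj in x]

x_final = ssvr_sketch(noisy_grad, [1.0, -2.0, 0.5], steps=500)
print(x_final)
```

Because the same sample is evaluated at both $x_t$ and $x_{t-1}$, the noise in the correction term partly cancels, so the estimator's error stays small and the signs match the true gradient whenever the iterate is not too close to the optimum.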
Abstract
Sign stochastic gradient descent (signSGD) is a communication-efficient method that transmits only the signs of stochastic gradients for parameter updates. Existing literature has shown that signSGD can achieve a convergence rate of $\mathcal{O}(d^{1/2}T^{-1/4})$, where $d$ is the dimension and $T$ is the number of iterations. In this paper, we improve this rate to $\mathcal{O}(d^{1/2}T^{-1/3})$ by introducing the Sign-based Stochastic Variance Reduction (SSVR) method, which uses variance-reduced estimators to track the gradients and updates the parameters with their signs. For finite-sum problems, our method can be further enhanced to achieve a convergence rate of $\mathcal{O}(m^{1/4}d^{1/2}T^{-1/2})$, where $m$ is the number of component functions. Furthermore, we study majority vote in distributed settings with heterogeneous data and introduce two novel algorithms that attain improved convergence rates of $\mathcal{O}(d^{1/2}T^{-1/2} + dn^{-1/2})$ and $\mathcal{O}(d^{1/4}T^{-1/4})$, respectively, outperforming the previous results of $\mathcal{O}(dT^{-1/4} + dn^{-1/2})$ and $\mathcal{O}(d^{3/8}T^{-1/8})$, where $n$ is the number of nodes. Numerical experiments across different tasks validate the effectiveness of our proposed methods.
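The majority-vote setting can be illustrated with a small simulation. This is a hypothetical sketch, not either of the two algorithms in the paper: it assumes each node runs the same STORM-style estimator on its own local data, sends only the sign vector (1 bit per coordinate) to a server, and the server broadcasts the sign of the summed votes. The helper names `majority_vote_sketch` and `make_local_grad`, the node centers, and the quadratic local objectives are all assumptions chosen for illustration.

```python
import random

def sign(x):
    # Elementwise sign; maps 0 to 0.
    return [(v > 0) - (v < 0) for v in x]

def vote(vs):
    # Each node contributes sign(v_i); the server returns the sign of the coordinate-wise sum.
    return sign([sum(col) for col in zip(*(sign(v) for v in vs))])

def majority_vote_sketch(local_grads, x0, steps, lr=0.01, beta=0.1, seed=0):
    """Hypothetical distributed sign update with per-node variance-reduced estimators."""
    rng = random.Random(seed)
    x_prev = list(x0)
    vs = [g(x_prev, rng.random()) for g in local_grads]
    x = [xp - lr * s for xp, s in zip(x_prev, vote(vs))]
    for _ in range(steps - 1):
        new_vs = []
        for g, v in zip(local_grads, vs):
            xi = rng.random()                 # each node draws its own local sample
            g_new, g_old = g(x, xi), g(x_prev, xi)
            new_vs.append([gn + (1 - beta) * (vp - go)
                           for gn, vp, go in zip(g_new, v, g_old)])
        vs = new_vs
        x_prev, x = x, [xc - lr * s for xc, s in zip(x, vote(vs))]
    return x

def make_local_grad(center):
    # Node objective 0.5 * ||x - center||^2 plus sample noise in [-0.25, 0.25).
    return lambda x, xi: [xj - cj + 0.5 * (xi - 0.5) for xj, cj in zip(x, center)]

# Heterogeneous nodes: different local optima per node.
centers = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
x_final = majority_vote_sketch([make_local_grad(c) for c in centers], [3.0, -1.0], steps=1000)
print(x_final)
```

In this toy run the iterate settles near the coordinate-wise median of the node optima (1.0) rather than the minimizer of the average objective (2.0), a small illustration of the bias that heterogeneous data can introduce into a plain majority vote.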
