GRAWA: Gradient-based Weighted Averaging for Distributed Training of Deep Learning Models

Tolga Dimlioglu, Anna Choromanska

TL;DR

A new algorithm is proposed that periodically pulls workers towards a center variable computed as a weighted average of the workers. The weights are inversely proportional to the workers' gradient norms, so that recovering flat regions of the optimization landscape is prioritized.

Abstract

We study distributed training of deep learning models in time-constrained environments. We propose a new algorithm that periodically pulls workers towards the center variable computed as a weighted average of workers, where the weights are inversely proportional to the gradient norms of the workers such that recovering the flat regions in the optimization landscape is prioritized. We develop two asynchronous variants of the proposed algorithm that we call Model-level and Layer-level Gradient-based Weighted Averaging (resp. MGRAWA and LGRAWA), which differ in terms of the weighting scheme that is either done with respect to the entire model or is applied layer-wise. On the theoretical front, we prove the convergence guarantee for the proposed approach in both convex and non-convex settings. We then experimentally demonstrate that our algorithms outperform the competitor methods by achieving faster convergence and recovering better quality and flatter local optima. We also carry out an ablation study to analyze the scalability of the proposed algorithms in more crowded distributed training environments. Finally, we report that our approach requires less frequent communication and fewer distributed updates compared to the state-of-the-art baselines.
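The weighting scheme described in the abstract can be sketched as follows. This is a minimal illustration, not the authors' implementation: it assumes the weights are $\beta_i \propto 1/\|\nabla f(x_i)\|$, normalized to sum to one, and that the pull is a simple convex interpolation with a hypothetical strength `lam`. The function names, the `eps` stabilizer, and the list-of-arrays model representation are all assumptions made for the example.

```python
import numpy as np

def grawa_weights(grad_norms, eps=1e-12):
    """Weights inversely proportional to gradient norms, normalized to sum to 1.

    eps is an assumed stabilizer to avoid division by zero.
    """
    inv = 1.0 / (np.asarray(grad_norms, dtype=float) + eps)
    return inv / inv.sum()

def model_level_center(workers, grads):
    """Model-level averaging (MGRAWA-style sketch): one weight per worker,
    computed from the gradient norm over the entire model.

    workers, grads: one entry per worker, each a list of per-layer arrays.
    """
    norms = [np.sqrt(sum(np.sum(g ** 2) for g in layers)) for layers in grads]
    beta = grawa_weights(norms)
    n_layers = len(workers[0])
    return [sum(b * w[l] for b, w in zip(beta, workers)) for l in range(n_layers)]

def layer_level_center(workers, grads):
    """Layer-level averaging (LGRAWA-style sketch): a separate set of weights
    for every layer, computed from that layer's gradient norms."""
    n_layers = len(workers[0])
    center = []
    for l in range(n_layers):
        beta = grawa_weights([np.linalg.norm(g[l]) for g in grads])
        center.append(sum(b * w[l] for b, w in zip(beta, workers)))
    return center

def pull_towards(worker, center, lam=0.5):
    """Move a worker towards the center; lam is a hypothetical pull strength."""
    return [(1.0 - lam) * w + lam * c for w, c in zip(worker, center)]

# Two workers, two layers each. Worker A has a small gradient in layer 0,
# worker B has a small gradient in layer 1; whole-model norms are equal.
workers = [[np.array([0.0]), np.array([0.0])],
           [np.array([4.0]), np.array([4.0])]]
grads = [[np.array([1.0]), np.array([3.0])],
         [np.array([3.0]), np.array([1.0])]]

center_model = model_level_center(workers, grads)
center_layer = layer_level_center(workers, grads)
pulled = pull_towards(workers[0], center_layer, lam=0.5)
```

In this toy setting the two workers have equal whole-model gradient norms, so the model-level center is the plain average, while the layer-level center leans towards whichever worker has the smaller gradient in each layer. That is the distinction between the two variants named in the abstract.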

Paper Structure (36 sections, 8 theorems, 40 equations, 14 figures, 7 tables, 7 algorithms)

Key Result

Theorem 1

Let $x_C = \sum_{i=1}^{M} \beta_i x_i$, where the $\beta_i$'s are calculated as in equation prelim:grad_weights. For an $L$-Lipschitz differentiable, real-valued, continuous convex function $f$ with minimizer $x^*$ that also satisfies $\|\nabla f(x)\| \geq \mu \left(f(x) - f(x^*)\right)$ and is also boun

Figures (14)

  • Figure 1: Contour plot of the loss landscape of the Vincent function and the final optima obtained by running LSGD, EASGD, and GRAWA optimizers.
  • Figure 2: Plot of standardized gradient norms in log scale vs. the generalization gap $(\%)$.
  • Figure 3: Illustration of MGRAWA and LGRAWA weighted averaging schemes in the case of two workers.
  • Figure 4: Test error (%) curves of different distributed training methods in 4 workers and ResNet-20 setting.
  • Figure 5: Test error (%) curves of different distributed training methods in 4 workers and PyramidNet setting.
  • ...and 9 more figures

Theorems & Definitions (14)

  • Theorem 1
  • Theorem 2
  • Theorem 3
  • Theorem 4
  • proof
  • Theorem 5
  • proof
  • Theorem 6
  • proof
  • Theorem 7
  • ...and 4 more