Non-Convex Optimization in Federated Learning via Variance Reduction and Adaptive Learning

Dipanwita Thakur, Antonella Guzzo, Giancarlo Fortino, Sajal K. Das

TL;DR

This work addresses non-convex federated optimization under non-IID data by introducing a momentum-based variance-reduction framework with adaptive learning rates for both local updates and global aggregation. The method reduces gradient variance and communication rounds, achieving convergence to an $\epsilon$-stationary point with improved $\mathcal{O}(\epsilon^{-1})$ communication complexity and strong empirical results on MNIST and CIFAR-10. The combination of momentum-based variance reduction and adaptivity mitigates client drift and accelerates convergence without extra per-client storage or communication. The study highlights practical gains for cross-device FL while noting limitations related to client participation assumptions and suggesting future work on cross-silo extensions.

Abstract

This paper proposes a novel federated algorithm that leverages momentum-based variance reduction with adaptive learning to address non-convex settings across heterogeneous data. We intend to minimize communication and computation overhead, thereby fostering a sustainable federated learning system. We aim to overcome challenges related to gradient variance, which hinders the model's efficiency, and the slow convergence resulting from learning rate adjustments with heterogeneous data. The experimental results on image classification tasks with heterogeneous data reveal the effectiveness of our suggested algorithms in non-convex settings, with an improved communication complexity of $\mathcal{O}(\epsilon^{-1})$ to converge to an $\epsilon$-stationary point, compared with the $\mathcal{O}(\epsilon^{-2})$ communication complexity of most prior works. The proposed federated version maintains the trade-off between the convergence rate, number of communication rounds, and test accuracy while mitigating client drift in heterogeneous settings. The experimental results demonstrate the efficiency of our algorithms in image classification tasks (MNIST, CIFAR-10) with heterogeneous data.
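To make the phrase "momentum-based variance reduction" concrete, the following is a minimal sketch of a generic STORM-style recursion on a toy scalar problem. It is not the paper's algorithm: the toy loss, the fixed learning rate `lr`, the momentum weight `a`, and the function names are all assumptions made for illustration, and the adaptive learning rate and server aggregation described in the abstract are left out.

```python
def stochastic_grad(x, xi):
    """Gradient of the toy per-sample loss 0.5 * (x - xi)**2 (illustrative only)."""
    return x - xi


def mvr_direction(g_new, g_old, d_prev, a):
    """Momentum-based variance-reduced direction (generic STORM-style recursion).

    g_new:  stochastic gradient at the current iterate on the current sample
    g_old:  stochastic gradient at the previous iterate on the SAME sample
    d_prev: previous direction
    a:      momentum weight in [0, 1]; a = 1 recovers plain SGD
    """
    return g_new + (1.0 - a) * (d_prev - g_old)


def local_steps(x0, samples, lr=0.1, a=0.1):
    """Run one client's local updates with a fixed (non-adaptive) step size."""
    x_prev = x0
    d = stochastic_grad(x0, samples[0])  # first step uses a plain stochastic gradient
    x = x0 - lr * d
    for xi in samples[1:]:
        g_new = stochastic_grad(x, xi)
        g_old = stochastic_grad(x_prev, xi)  # same sample, previous iterate
        d = mvr_direction(g_new, g_old, d, a)
        x_prev, x = x, x - lr * d
    return x


if __name__ == "__main__":
    import random

    random.seed(0)
    noisy_samples = [3.0 + random.gauss(0.0, 1.0) for _ in range(500)]
    print(local_steps(0.0, noisy_samples))
```

The correction term `d_prev - g_old` reuses the current sample at the previous iterate, which is what lets the estimator cancel part of the sampling noise without storing a full gradient table per client.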

Paper Structure

This paper contains 13 sections, 3 theorems, 10 equations, 12 figures, 2 tables, 2 algorithms.

Key Result

Theorem A.1

Under Assumptions asu1, asu2, asu3, and asu4, and for initial batch size $\mathcal{B} = bE$, we set $\eta_t = \frac{k}{(w_t+\sigma^2 t)^{1/3}}$ with $k = \frac{(bN)^{2/3}\sigma^{2/3}}{L}$. Also set $c = \frac{(8L)^2}{bN} + \frac{\sigma^2}{24LEk^3} = L^2\left(\frac{64}{bN}+\frac{1}{24(bN)^2E}\right)$ and $w = \max\{\dots\}$, where $v \in [0,1]$.
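The two expressions for $c$ agree because $k^3 = (bN)^2\sigma^2/L^3$, so $\frac{\sigma^2}{24LEk^3} = \frac{L^2}{24(bN)^2E}$. The sketch below checks this numerically and evaluates the step-size schedule $\eta_t$. The parameter values are arbitrary, and because the definition of $w$ is truncated above, `w_t` is treated here as a user-supplied positive number (an assumption, not the theorem's value).

```python
import math


def step_size_constants(b, N, E, L, sigma):
    """Return k and both forms of c from the theorem statement."""
    bN = b * N
    k = bN ** (2 / 3) * sigma ** (2 / 3) / L
    c_direct = (8 * L) ** 2 / bN + sigma ** 2 / (24 * L * E * k ** 3)
    c_simplified = L ** 2 * (64 / bN + 1 / (24 * bN ** 2 * E))
    return k, c_direct, c_simplified


def eta(t, k, w_t, sigma):
    """Step size eta_t = k / (w_t + sigma^2 * t)^(1/3)."""
    return k / (w_t + sigma ** 2 * t) ** (1 / 3)


if __name__ == "__main__":
    # Arbitrary illustrative values; none of these come from the paper.
    k, c_direct, c_simplified = step_size_constants(b=16, N=10, E=5, L=2.0, sigma=1.5)
    print(math.isclose(c_direct, c_simplified, rel_tol=1e-9))
    for t in (1, 10, 100):
        # w_t = 8.0 is a placeholder, since w is truncated in the source.
        print(t, eta(t, k, w_t=8.0, sigma=1.5))
```

With `w_t` held fixed, $\eta_t$ decays on the order of $t^{-1/3}$.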

Figures (12)

  • Figure 1: Train Loss vs Epochs with CIFAR-10
  • Figure 2: Train Accuracy vs Epochs with CIFAR-10
  • Figure 3: Train Accuracy vs Communications with CIFAR-10
  • Figure 4: Train Loss vs Communications with CIFAR-10
  • Figure 5: Test Accuracy with CIFAR-10
  • ...and 7 more figures

Theorems & Definitions (5)

  • Theorem A.1
  • Lemma A.2
  • proof
  • Theorem A.3
  • proof