
Cyclical Weight Consolidation: Towards Solving Catastrophic Forgetting in Serial Federated Learning

Haoyue Song, Jiacheng Wang, Liansheng Wang

TL;DR

This work targets catastrophic forgetting in serial federated learning, which arises from non-IID data across cyclic site visits. It introduces Cyclical Weight Consolidation (CWC), which regularizes local updates with a consolidation matrix $C^{k,r}$ that stores parameter importance from previous sites and attenuates this memory at the start of each new round. Empirically, CWC reduces the fluctuation seen in traditional serial FL, i.e., cyclical weight transfer (CWT). It also improves convergence on MNIST and CIFAR10 with Dirichlet heterogeneity, on ISIC2018, and under extreme partitioning, and it matches or surpasses FedAvg in several non-IID scenarios. The findings underscore CWC's potential to make serial FL competitive with parallel approaches while enabling more local computation. They also highlight the role of the consolidation factor $\sigma$ in balancing memory retention and plasticity.
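The regularized local objective can be sketched as a quadratic penalty in the style of elastic weight consolidation (EWC). The exact expression used by the authors is not reproduced in this summary, so the following are assumptions:

- the penalty weight $\lambda$;
- the diagonal importance estimate $F^{k,r}$, computed after training at site $k$ in round $r$;
- the anchor $\theta^{k-1,r}$, i.e., the weights received from the previous site, with $\theta^{0,r} \equiv \theta^{K,r-1}$ for $K$ sites;
- the reading of $\sigma$ as a multiplicative decay.

$$
\mathcal{L}^{k,r}(\theta) = \mathcal{L}_k(\theta) + \frac{\lambda}{2}\sum_i C^{k,r}_{i}\left(\theta_i - \theta^{k-1,r}_i\right)^2
$$

$$
C^{k+1,r} = C^{k,r} + F^{k,r}, \qquad C^{1,r+1} = \sigma\left(C^{K,r} + F^{K,r}\right), \qquad 0 \le \sigma \le 1
$$

Here $\mathcal{L}_k$ is the ordinary loss on site $k$'s data, and $C^{k,r}_i$ is the $i$-th diagonal entry of the consolidation matrix. Under this reading, a large $\sigma$ keeps more of the earlier sites' importance (retention), and a small $\sigma$ frees weights to absorb new information (plasticity).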

Abstract

Federated Learning (FL) has gained attention for addressing data scarcity and privacy concerns. While parallel FL algorithms like FedAvg exhibit remarkable performance, they face challenges in scenarios with diverse network speeds. They also raise concerns about centralized control, especially in multi-institutional collaborations such as those in the medical domain. Serial FL offers an alternative that circumvents these challenges by transferring model updates serially between devices in a cyclical manner. Nevertheless, it is deemed inferior to parallel FL in that (1) its performance shows undesirable fluctuations, and (2) it converges to a lower plateau, particularly when dealing with non-IID data. The observed phenomenon is attributed to catastrophic forgetting, i.e., the loss of knowledge acquired at previous sites. In this paper, to overcome the fluctuation and low efficiency of this iterative learning-and-forgetting process, we introduce cyclical weight consolidation (CWC), a straightforward yet potent approach specifically tailored for serial FL. CWC employs a consolidation matrix to regulate local optimization. This matrix tracks the significance of each parameter to the overall federation throughout the entire training trajectory, preventing abrupt changes in significant weights. During revisitation, old memory undergoes decay so that the model remains adaptable and can incorporate new information. Our comprehensive evaluations demonstrate that in various non-IID settings, CWC mitigates the fluctuation behavior of the original serial FL approach and consistently and significantly improves the converged performance. The improved performance is either comparable to or better than that of vanilla parallel FL (FedAvg).
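The serial procedure described above can be illustrated with a small toy simulation. This is not the authors' implementation. It runs serial FL on three hypothetical non-IID linear-regression sites that are visited cyclically. With `lam=0` the loop is plain cyclic weight transfer. With `lam>0`, each local update adds an assumed EWC-style quadratic penalty, weighted by an accumulated diagonal importance vector `C` that is multiplied by `sigma` at the start of each round. The following are all assumptions made for illustration: the importance measure (the diagonal Fisher of a unit-variance linear-Gaussian model), the penalty form, the data, and every function name.

```python
import numpy as np

def make_sites(num_sites=3, n=200, dim=5, seed=0):
    """Hypothetical non-IID toy data: each site has shifted inputs and its own optimum."""
    rng = np.random.default_rng(seed)
    w_shared = rng.normal(size=dim)
    sites = []
    for k in range(num_sites):
        X = rng.normal(loc=float(k), scale=1.0, size=(n, dim))
        w_k = w_shared + rng.normal(size=dim)  # site-specific optimum -> conflicting updates
        y = X @ w_k + 0.1 * rng.normal(size=n)
        sites.append((X, y))
    return sites

def mse_grad(w, X, y):
    # Gradient of 0.5 * mean((X @ w - y) ** 2) with respect to w.
    return X.T @ (X @ w - y) / len(y)

def importance(X):
    # Assumed importance: diagonal Fisher of a unit-variance linear-Gaussian model,
    # i.e., the mean of x_i^2 (it does not depend on w for this model).
    return np.mean(X ** 2, axis=0)

def local_update(w, X, y, C, anchor, lam, lr, steps):
    """Full-batch gradient descent on the site loss plus the assumed quadratic penalty."""
    w = w.copy()
    for _ in range(steps):
        grad = mse_grad(w, X, y) + lam * C * (w - anchor)
        w -= lr * grad
    return w

def serial_fl(sites, rounds=5, lam=1.0, sigma=0.5, lr=0.02, steps=100):
    """Cyclic serial FL; lam=0 reduces to plain cyclic weight transfer (CWT)."""
    dim = sites[0][0].shape[1]
    w, C = np.zeros(dim), np.zeros(dim)
    for _ in range(rounds):
        C = sigma * C                # attenuate old memory at the start of each round
        for X, y in sites:           # visit sites in a fixed cyclic order
            anchor = w.copy()        # weights received from the previous site
            w = local_update(w, X, y, C, anchor, lam, lr, steps)
            C = C + importance(X)    # consolidate this site's importance
    return w, C

def global_loss(w, sites):
    # Average of per-site losses, a stand-in for a balanced global test set.
    return float(np.mean([0.5 * np.mean((X @ w - y) ** 2) for X, y in sites]))

if __name__ == "__main__":
    sites = make_sites()
    w_cwt, _ = serial_fl(sites, lam=0.0)             # plain cyclic weight transfer
    w_cwc, _ = serial_fl(sites, lam=1.0, sigma=0.5)  # with the assumed consolidation penalty
    print("CWT global loss:", global_loss(w_cwt, sites))
    print("CWC-style global loss:", global_loss(w_cwc, sites))
```

Raising `lam` or `sigma` favors retaining what was learned at earlier sites. Setting `sigma=0` discards everything except the importance gathered so far in the current round. Whether the penalty helps on this toy problem depends on these settings.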

Paper Structure

This paper contains 14 sections, 5 equations, 8 figures, and 3 tables.

Figures (8)

  • Figure 1: A comparison between CWT and FedAvg on MNIST reveals two weaknesses of CWT on non-IID data: its performance fluctuates undesirably, and it converges to a lower plateau. We present results under Dirichlet concentration parameters $\alpha$ of 0.01, 0.1, 1.0, and 10.0 (arranged from top to bottom and left to right). A smaller $\alpha$ corresponds to greater data heterogeneity.
  • Figure 2: The procedure of our proposed cyclical weight consolidation (CWC). The red circle indicates the parameters being optimized, and the blue circle indicates the parameters being consolidated. The light blue circle with a dashed border indicates the attenuated consolidated parameters, and the rectangular box denotes the entire parameter set of the model.
  • Figure 3: Overall performance of the three algorithms on MNIST and CIFAR10 under three non-IID settings ($\alpha=0.01$, $\alpha=0.1$, and $\alpha=1.0$, from left to right). The curves plot classification accuracy on the balanced global test set against training epochs. Serial federated learning is evaluated at the end of each training epoch, whereas parallel federated learning is evaluated every four training epochs (i.e., once per communication round).
  • Figure 4: Left: classification accuracy on the balanced global test set for the model that has just been updated at each site. Middle: classification accuracy on each site's private test set across the entire training trajectory. Right: a magnified view of the initial stage of the middle plots.
  • Figure 5: Balanced accuracy on the global test set over training epochs under the setting $\alpha=0.1$. From left to right, the panels show the learning dynamics of FedAvg, CWT, and our CWC.
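Figures 1, 3, and 5 vary the Dirichlet concentration parameter $\alpha$ to control heterogeneity. One common way to build such splits is per-class Dirichlet label skew: for each class, draw client proportions from $\mathrm{Dir}(\alpha \mathbf{1}_K)$ and divide that class's samples accordingly. The sketch below assumes this scheme; the exact partitioning procedure used in the paper is not described here. The function `dirichlet_partition` is hypothetical.

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients using per-class Dirichlet(alpha) proportions."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        rng.shuffle(idx)
        # Share of class c given to each client; a small alpha gives highly skewed shares.
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Turn cumulative shares into split points over this class's indices.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return [np.array(sorted(ix), dtype=int) for ix in client_indices]

if __name__ == "__main__":
    labels = np.repeat(np.arange(10), 100)  # hypothetical balanced 10-class label vector
    for alpha in (0.01, 1.0, 10.0):
        parts = dirichlet_partition(labels, num_clients=4, alpha=alpha)
        # Per-client class histograms show how skewed each split is.
        print(alpha, [np.bincount(labels[p], minlength=10).tolist() for p in parts])
```

With $\alpha = 0.01$, most classes land almost entirely on a single client. With $\alpha = 10.0$, each client receives a near-uniform mix of classes. This scheme also makes per-client sample counts unequal.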