Cyclical Weight Consolidation: Towards Solving Catastrophic Forgetting in Serial Federated Learning
Haoyue Song, Jiacheng Wang, Liansheng Wang
TL;DR
This work targets catastrophic forgetting in serial federated learning, which arises from non-IID data as the model cyclically visits sites. It introduces Cyclical Weight Consolidation (CWC), which regularizes local updates with a consolidation matrix $C^{k,r}$ that stores parameter importance accumulated at previous sites; this memory is attenuated at the start of each new round. Empirically, CWC reduces the fluctuation seen in traditional serial FL (cyclical weight transfer, CWT), improves convergence on MNIST and CIFAR10 with Dirichlet-partitioned heterogeneity, on ISIC2018, and under extreme partitioning, and matches or surpasses FedAvg in several non-IID scenarios. The findings suggest that CWC can make serial FL competitive with parallel approaches while allowing more local computation, with the consolidation factor $\sigma$ controlling the balance between memory retention and plasticity.
Abstract
Federated Learning (FL) has gained attention for addressing data scarcity and privacy concerns. While parallel FL algorithms like FedAvg exhibit remarkable performance, they face challenges in scenarios with diverse network speeds and raise concerns about centralized control, especially in multi-institutional collaborations such as those in the medical domain. Serial FL offers an alternative that circumvents these challenges by passing model updates serially from device to device in a cyclical manner. Nevertheless, it is considered inferior to parallel FL because (1) its performance shows undesirable fluctuations, and (2) it converges to a lower plateau, particularly on non-IID data. These shortcomings are attributed to catastrophic forgetting, i.e., the loss of knowledge acquired at previous sites. In this paper, to overcome the fluctuation and low efficiency of this iterative learning-and-forgetting process, we introduce cyclical weight consolidation (CWC), a straightforward yet potent approach tailored specifically to serial FL. CWC employs a consolidation matrix to regulate local optimization. This matrix tracks the significance of each parameter to the overall federation throughout the entire training trajectory, preventing abrupt changes to significant weights. When a site is revisited, the old memory decays so that new information can be incorporated, which maintains adaptability. Our comprehensive evaluations demonstrate that in various non-IID settings, CWC mitigates the fluctuation of the original serial FL approach and consistently and significantly improves the converged performance. The improved performance is comparable to or better than that of vanilla parallel FL (FedAvg).
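The consolidation-and-decay mechanism described above can be illustrated with a toy sketch. The abstract does not give CWC's exact update rule, so everything here is an assumption: the penalty takes an EWC-style quadratic form $\frac{\lambda}{2}\sum_i C_i(\theta_i - \theta_i^{\text{anchor}})^2$ anchored at the model received from the previous site, the consolidation matrix is stored as its diagonal, importance is taken as the diagonal Hessian of each site's loss (a stand-in for the Fisher-style estimates used by EWC), and the decay by $\sigma$ is applied once at the start of every round. The names `SITES`, `local_update`, `serial_fl`, `global_optimum`, and `lam` are hypothetical, and the two-parameter quadratic sites merely stand in for non-IID clients.

```python
import numpy as np

# Hypothetical toy sites: site k has loss 0.5 * sum(curv_k * (theta - target_k)**2).
# Non-IID: each site cares mostly about a different coordinate.
SITES = [
    {"target": np.array([1.0, 0.0]), "curv": np.array([1.0, 0.1])},
    {"target": np.array([0.0, 1.0]), "curv": np.array([0.1, 1.0])},
]


def local_update(theta, C, site, lam, lr=0.1, steps=50):
    """Gradient descent on the site loss plus an assumed EWC-style penalty
    0.5 * lam * sum(C * (theta - anchor)**2), anchored at the received model."""
    anchor = theta.copy()
    theta = theta.copy()
    for _ in range(steps):
        grad = site["curv"] * (theta - site["target"]) + lam * C * (theta - anchor)
        theta -= lr * grad
    return theta


def serial_fl(rounds=10, lam=1.0, sigma=0.5):
    """Cyclic serial FL over SITES; lam=0 reduces to plain cyclic weight transfer."""
    theta = np.zeros(2)
    C = np.zeros(2)  # diagonal of the consolidation matrix
    for _ in range(rounds):
        C *= sigma  # assumed: decay old memory once at the start of each round
        for site in SITES:
            theta = local_update(theta, C, site, lam)
            # Assumed importance: diagonal Hessian of this site's loss (= curv here).
            C += site["curv"]
    return theta


def global_optimum():
    """Minimizer of the summed site losses, i.e. what parallel FL would target."""
    num = sum(s["curv"] * s["target"] for s in SITES)
    den = sum(s["curv"] for s in SITES)
    return num / den


if __name__ == "__main__":
    opt = global_optimum()
    print("CWT distance:", np.linalg.norm(serial_fl(lam=0.0) - opt))
    print("CWC-style distance:", np.linalg.norm(serial_fl(lam=1.0) - opt))
```

Setting `lam=0.0` switches the penalty off and recovers plain cyclic weight transfer, which in this toy ends each round biased toward the last site visited; with `lam=1.0` the final model lands close to the minimizer of the summed site losses. Lowering `sigma` makes the memory fade faster (more plasticity), while raising it keeps earlier importance longer (more retention). This is only a qualitative illustration, not a reproduction of the paper's method or experiments.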
