Initialization Matters: Unraveling the Impact of Pre-Training on Federated Learning
Divyansh Jhunjhunwala, Pranay Sharma, Zheng Xu, Gauri Joshi
TL;DR
This work investigates how pre-trained initialization influences Federated Learning (FL) performance under data heterogeneity. Analyzing a two-layer ReLU CNN trained with FedAvg, it introduces the notions of aligned and misaligned filters and proves a test-error bound that depends on data heterogeneity, showing that pre-training reduces misalignment and improves generalization in non-IID settings. The analysis uses a signal-noise decomposition of the CNN weights and a two-stage convergence argument, and is corroborated by synthetic experiments and by CNN training on CIFAR-10 and TinyImageNet. The findings offer practical guidance: initialize with pre-trained representations to mitigate the effects of heterogeneity, and tune the number of local steps to trade off communication cost against learning progress. They also have implications for benign overfitting in FL.
Abstract
Initializing with pre-trained models when learning on downstream tasks is becoming standard practice in machine learning. Several recent works explore the benefits of pre-trained initialization in a federated learning (FL) setting, where the downstream training is performed at edge clients with heterogeneous data distributions. These works show that starting from a pre-trained model can substantially reduce the adverse impact of data heterogeneity on the test performance of a model trained in a federated setting, with no changes to the standard FedAvg training algorithm. In this work, we provide a deeper theoretical understanding of this phenomenon. To do so, we study the class of two-layer convolutional neural networks (CNNs) and provide bounds on the training error convergence and test error of such a network trained with FedAvg. We introduce the notion of aligned and misaligned filters at initialization and show that data heterogeneity affects learning only on misaligned filters. Starting with a pre-trained model typically results in fewer misaligned filters at initialization, thus producing a lower test error even when the model is trained in a federated setting with data heterogeneity. Experiments in synthetic settings and practical FL training on CNNs verify our theoretical findings.
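For readers unfamiliar with FedAvg, the following is a minimal sketch of the training loop the abstract refers to: each client runs a few local gradient steps from the current global model, and the server averages the resulting models. It is only an illustration under simplifying assumptions. It uses a linear least-squares model in NumPy rather than the two-layer CNN studied in the paper, and the helper names (`local_sgd`, `fedavg`, `make_clients`), the synthetic data, and all hyperparameters are hypothetical choices, not taken from the paper. Because this toy problem is convex, it does not reproduce the paper's aligned/misaligned filter effect; it only shows where the initialization enters the algorithm.

```python
import numpy as np

def local_sgd(w, X, y, lr, steps):
    """Run `steps` full-batch gradient steps on a least-squares loss, starting from w."""
    w = w.copy()
    for _ in range(steps):
        grad = X.T @ (X @ w - y) / len(y)
        w -= lr * grad
    return w

def fedavg(w_init, clients, rounds=100, local_steps=5, lr=0.1):
    """Plain FedAvg: clients train locally from the global model, the server averages."""
    w = np.asarray(w_init, dtype=float).copy()
    sizes = np.array([len(y) for _, y in clients], dtype=float)
    for _ in range(rounds):
        local_models = [local_sgd(w, X, y, lr, local_steps) for X, y in clients]
        # Server step: average client models, weighted by local dataset size.
        w = np.average(local_models, axis=0, weights=sizes)
    return w

def make_clients(w_true, num_clients=4, n=50, seed=0):
    """Synthetic non-IID clients: each client's features have a different mean."""
    rng = np.random.default_rng(seed)
    clients = []
    for k in range(num_clients):
        shift = 0.5 * (k - (num_clients - 1) / 2)  # client-specific feature shift
        X = rng.normal(loc=shift, scale=1.0, size=(n, w_true.size))
        y = X @ w_true  # noiseless labels, so all clients share one optimum
        clients.append((X, y))
    return clients

w_true = np.array([1.0, -2.0, 0.5])
clients = make_clients(w_true)

# Two initializations: zeros ("from scratch") and a point near the optimum,
# standing in for a pre-trained model (an assumption made for illustration).
w_scratch_init = np.zeros(3)
w_pretrained_init = w_true + 0.01

w_from_scratch = fedavg(w_scratch_init, clients)
w_from_pretrained = fedavg(w_pretrained_init, clients)

# Distance to the optimum after a single communication round for each start.
err_scratch_1 = np.linalg.norm(fedavg(w_scratch_init, clients, rounds=1) - w_true)
err_pretrained_1 = np.linalg.norm(fedavg(w_pretrained_init, clients, rounds=1) - w_true)
```

The only place the initialization appears is the `w_init` argument of `fedavg`, which matches the abstract's point that the benefit of pre-training is obtained with no change to the FedAvg algorithm itself.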
