Initialization Matters: Unraveling the Impact of Pre-Training on Federated Learning

Divyansh Jhunjhunwala, Pranay Sharma, Zheng Xu, Gauri Joshi

TL;DR

This work investigates how pre-trained initialization influences federated learning (FL) performance under data heterogeneity. Analyzing a two-layer ReLU CNN trained with FedAvg, it introduces the notion of aligned and misaligned filters and proves a test-error bound in which data heterogeneity affects only the misaligned filters. Pre-training reduces the number of misaligned filters and thereby improves generalization in non-IID settings. The analysis uses a signal-noise decomposition of the CNN weights and a two-stage convergence argument, and is corroborated by synthetic experiments and real CNN training on CIFAR-10 and TinyImageNet. The findings offer practical guidance: initialize with pre-trained representations to mitigate the effects of heterogeneity, and tune the number of local steps to balance communication cost against learning. The results also have implications for benign overfitting in FL.
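The synthetic experiments can be pictured with a toy signal-noise data model. The sketch below is an assumption, not the paper's exact setup. Each sample has a signal patch $y\mu$ and a Gaussian noise patch orthogonal to $\mu$. A hypothetical heterogeneity knob `h` gives each client a fraction `h` of minority-class labels, so `h = 0.5` is IID and `h = 0` gives single-label clients, mirroring the $h = 1/2$ and $h = 0$ settings in Figure 3. The function name `make_client_data` and all of its parameters are illustrative.

```python
import numpy as np

def make_client_data(num_clients, n_per_client, d, h, signal_norm=1.0, noise_std=1.0, seed=0):
    """Hypothetical synthetic FL data with one signal patch and one noise patch per sample.

    h in [0, 0.5] is an assumed heterogeneity knob: each client draws a fraction h of its
    labels from its minority class, so h = 0.5 is IID and h = 0 gives single-label clients.
    Returns the signal vector mu and a list of (X, y) pairs, with X of shape (n, 2, d).
    """
    if not 0.0 <= h <= 0.5:
        raise ValueError("h must lie in [0, 0.5]")
    rng = np.random.default_rng(seed)
    mu = np.zeros(d)
    mu[0] = signal_norm  # assumed signal direction: the first coordinate
    clients = []
    for k in range(num_clients):
        majority = 1 if k % 2 == 0 else -1  # alternate the majority label across clients
        n_minority = int(round(h * n_per_client))
        y = np.array([majority] * (n_per_client - n_minority) + [-majority] * n_minority)
        rng.shuffle(y)
        noise = rng.normal(0.0, noise_std, size=(n_per_client, d))
        noise[:, 0] = 0.0  # keep the noise patch orthogonal to the signal direction
        signal = y[:, None] * mu[None, :]  # signal patch y * mu
        X = np.stack([signal, noise], axis=1)  # (n, 2 patches, d)
        clients.append((X, y))
    return mu, clients
```

With `h = 0`, client 0 sees only label +1 and client 1 only label -1, which is the extreme non-IID case.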

Abstract

Initializing with pre-trained models when learning on downstream tasks is becoming standard practice in machine learning. Several recent works explore the benefits of pre-trained initialization in a federated learning (FL) setting, where the downstream training is performed at the edge clients with heterogeneous data distribution. These works show that starting from a pre-trained model can substantially reduce the adverse impact of data heterogeneity on the test performance of a model trained in a federated setting, with no changes to the standard FedAvg training algorithm. In this work, we provide a deeper theoretical understanding of this phenomenon. To do so, we study the class of two-layer convolutional neural networks (CNNs) and provide bounds on the training error convergence and test error of such a network trained with FedAvg. We introduce the notion of aligned and misaligned filters at initialization and show that the data heterogeneity only affects learning on misaligned filters. Starting with a pre-trained model typically results in fewer misaligned filters at initialization, thus producing a lower test error even when the model is trained in a federated setting with data heterogeneity. Experiments in synthetic settings and practical FL training on CNNs verify our theoretical findings.
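A minimal NumPy sketch of FedAvg on a two-layer ReLU CNN follows. It assumes the form common in theoretical analyses: $F_j(W, x) = \frac{1}{m}\sum_{r}\sum_{p} \mathrm{ReLU}(\langle w_{j,r}, x^{(p)} \rangle)$ for $j \in \{\pm 1\}$, output $f = F_{+1} - F_{-1}$, and logistic loss. It also assumes full-batch local gradient steps and equal-weight averaging, which presumes equal-sized clients. These modeling choices and the helper names (`forward`, `loss_and_grad`, `fedavg`) are illustrative assumptions. `local_steps` plays the role of $\tau$, and `W0` is where a random or pre-trained initialization would be plugged in.

```python
import numpy as np

def forward(W, X):
    # W: (2, m, d) filters; index 0 is class j = +1, index 1 is class j = -1 (assumed layout)
    # X: (n, P, d) samples with P patches each
    pre = np.einsum('jmd,npd->njmp', W, X)               # <w_{j,r}, x^{(p)}>
    F = np.maximum(pre, 0.0).sum(axis=3).mean(axis=2)    # (n, 2): F_j = (1/m) sum_r sum_p ReLU
    return F[:, 0] - F[:, 1], pre                        # f = F_{+1} - F_{-1}

def loss_and_grad(W, X, y):
    f, pre = forward(W, X)
    margins = y * f
    loss = np.mean(np.logaddexp(0.0, -margins))          # logistic loss log(1 + e^{-y f})
    # d loss / d f_i = -y_i * sigmoid(-y_i f_i) / n, using a tanh form that cannot overflow
    dloss_df = -y * 0.5 * (1.0 - np.tanh(margins / 2.0)) / len(y)
    sign = np.array([1.0, -1.0])                         # df/dF_j for j = +1, -1
    active = (pre > 0).astype(float)                     # ReLU derivative, (n, 2, m, P)
    m = W.shape[1]
    grad = np.einsum('n,j,njmp,npd->jmd', dloss_df, sign, active, X) / m
    return loss, grad

def fedavg(W0, clients, rounds, local_steps, lr):
    """Plain FedAvg: each client runs local_steps gradient steps, then the server averages."""
    W = W0.copy()
    history = []
    for _ in range(rounds):
        local_models = []
        for X, y in clients:
            Wk = W.copy()                                # every client starts from the global model
            for _ in range(local_steps):
                _, g = loss_and_grad(Wk, X, y)
                Wk -= lr * g
            local_models.append(Wk)
        W = np.mean(local_models, axis=0)                # equal-weight model averaging
        history.append(np.mean([loss_and_grad(W, X, y)[0] for X, y in clients]))
    return W, history
```

The training algorithm itself never changes in this sketch; the only thing pre-training would alter is the starting point `W0`.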

Paper Structure

This paper contains 53 sections, 43 theorems, 183 equations, 6 figures, 1 table.

Key Result

Proposition 1

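Let $\{{\bf w}_{j,r}^{(t)}\}$, for $j \in \{\pm 1\}$ and $r \in [m]$, be the global CNN filter weights in round $t$. Then there exist unique coefficients $\Gamma_{j,r}^{(t)} \geq 0$ (signal learning) and $\{ P_{j,r,k,i}^{(t)} \}_{k,i}$ (noise memorization) such that each filter ${\bf w}_{j,r}^{(t)}$ decomposes into its initialization plus a signal component scaled by $\Gamma_{j,r}^{(t)}$ and noise components scaled by $P_{j,r,k,i}^{(t)}$, where $k \in [K]$ and $i \in [N]$ denote the client and sample index, respectively.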
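Why such a decomposition exists can be seen under the usual signal-noise data model, which is assumed here. Each sample has a signal patch $\pm\mu$ and a noise patch $\xi_{k,i}$. The gradient of the assumed CNN with respect to ${\bf w}_{j,r}$ is a linear combination of input patches, so every local step moves the filter within $\mathrm{span}\{\mu, \xi_{k,i}\}$. Because model averaging is linear, the global filter stays in that span. The coefficients are unique when these vectors are linearly independent, which holds with probability one for Gaussian noise when $d$ is large relative to $KN$. The sketch below assumes the common normalization $\mathbf{w}_{j,r}^{(t)} = \mathbf{w}_{j,r}^{(0)} + j\,\Gamma\,\mu/\|\mu\|^2 + \sum_{k,i} P_{k,i}\,\xi_{k,i}/\|\xi_{k,i}\|^2$ and recovers $\Gamma$ and $P$ by least squares. The function `decompose_filter` and the symbols $\mu$ and $\xi_{k,i}$ are hypothetical.

```python
import numpy as np

def decompose_filter(w_t, w_0, j, mu, xis):
    """Recover (Gamma, P) in the assumed decomposition
    w_t = w_0 + j * Gamma * mu / ||mu||^2 + sum_{k,i} P[k, i] * xi_{k,i} / ||xi_{k,i}||^2.

    mu: (d,) signal vector; xis: (K, N, d) noise vectors, indexed by client k and sample i.
    Exact when w_t - w_0 lies in span{mu, xi_{k,i}} and these vectors are linearly independent.
    """
    K, N, d = xis.shape
    flat = xis.reshape(K * N, d)
    basis = np.vstack([
        j * mu / np.dot(mu, mu),                          # signal direction, scaled
        flat / np.sum(flat ** 2, axis=1, keepdims=True),  # noise directions, scaled
    ])  # (1 + K*N, d): one row per basis vector
    coeffs, *_ = np.linalg.lstsq(basis.T, w_t - w_0, rcond=None)
    return coeffs[0], coeffs[1:].reshape(K, N)
```

Tracking `Gamma` and `P` over rounds in this way gives the signal-learning and noise-memorization curves of the kind shown in Figure 3.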

Figures (6)

  • Figure 1: Test accuracy ($\%$) on CIFAR-10 with the SqueezeNet model (Iandola et al., 2016) under random and pre-trained initializations for FL and centralized training. Pre-training benefits FL more than the centralized setting and significantly reduces the gap between IID and non-IID FL model performance.
  • Figure 2: Empirical results on a synthetic dataset verifying the upper bound on test error in Theorem 2 (Test Error Bound), with the training error fixed at $\epsilon = 0.1$. Test error increases with the number of misaligned filters, at a much larger rate in the non-IID setting. Test error also increases with the number of local steps and with heterogeneity when $m/2$ filters are misaligned at initialization, but remains constant when all filters are aligned.
  • Figure 3: Signal learning and noise memorization for our CNN model in the IID ($h = 1/2$) and non-IID ($h = 0$) settings after one round. Signal learning: in the IID setting, the signal learning coefficients are similar for all filters and increase with the number of local steps $\tau$, but in the non-IID setting they saturate for the misaligned filters ($r = 1,2,4,5$), as predicted by the signal-growth lemma. Noise memorization: it is similar for all filters in both settings and grows with $\tau$, as predicted by the noise-growth lemma. Ratio of signal learning to noise memorization: in the IID setting the ratio is independent of $\tau$, but in the non-IID setting it decreases to zero as $\tau$ increases for the misaligned filters ($r = 1,2,4,5$).
  • Figure 4: Initial alignment of the filters in the two-layer CNN experiments.
  • Figure 5: The percentage of misaligned filters (measured with the paper's empirical alignment criterion) and test accuracy for different initializations on CIFAR-10 and TinyImageNet. As the complexity of the signal information in the data grows from CIFAR-10 to TinyImageNet, the ratio of misaligned filters under random initialization increases sharply, explaining why pre-trained initialization offers larger improvements on TinyImageNet.
  • ...and 1 more figure
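Figure 5 reports the share of misaligned filters. A toy version of that measurement is sketched below. It assumes a hypothetical criterion in which filter $w_{j,r}$ counts as misaligned when $\langle w_{j,r}, j\mu \rangle \le 0$. The paper's empirical alignment measure for real CNNs is not reproduced in this summary, so this criterion, the `(2, m, d)` weight layout, and the name `misaligned_fraction` are all assumptions. Under random Gaussian initialization, this criterion marks about half of the filters as misaligned.

```python
import numpy as np

def misaligned_fraction(W, mu):
    """Fraction of filters with <w_{j,r}, j * mu> <= 0 (hypothetical misalignment criterion).

    W: (2, m, d) filters, index 0 for class j = +1 and index 1 for j = -1 (assumed layout).
    mu: (d,) signal vector.
    """
    j = np.array([1.0, -1.0])[:, None]           # (2, 1) class signs
    scores = j * np.einsum('jmd,d->jm', W, mu)   # <w_{j,r}, j * mu> for every filter
    return float(np.mean(scores <= 0))
```

A random initialization can then be compared against any candidate initialization by calling `misaligned_fraction` on each set of starting weights.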

Theorems & Definitions (80)

  • Definition 1
  • Proposition 1
  • Theorem 1: Training Loss Convergence
  • Theorem 2: Test Error Bound
  • Lemma 1
  • Lemma 2
  • Lemma 3: All Filters Aligned After Sufficient Training
  • Lemma 4
  • Theorem: Training Loss Convergence
  • Lemma 5
  • ...and 70 more