Minimizing Layerwise Activation Norm Improves Generalization in Federated Learning

M Yashwanth, Gaurav Kumar Nayak, Harsh Rangwani, Arya Singh, R. Venkatesh Babu, Anirban Chakraborty

TL;DR

This work addresses generalization in Federated Learning by introducing a flatness-aware objective that penalizes Hessian-based sharpness. The authors derive a computationally efficient regularizer, MAN, which minimizes the layerwise activation norms to bound the layerwise Hessian top eigenvalues, and they show that this promotes flat minima both locally and globally. The method integrates with standard FL algorithms (e.g., FedAvg, FedDC, FedDyn) and remains cost-efficient compared to SAM-based approaches, while yielding notable accuracy improvements and requiring fewer communication rounds on CIFAR-100 and Tiny-ImageNet under non-IID settings. Theoretical guarantees tie activation-norm minimization to Hessian flatness, and extensive experiments corroborate flatter minima via Hessian analysis and improved generalization across datasets and heterogeneity levels.
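
The summary does not give MAN's exact formula. As a minimal sketch, assuming the penalty is the L2 norm of each hidden layer's post-ReLU activations, averaged over the batch, summed over layers, and weighted by a coefficient `zeta` (assumed to correspond to the hyper-parameter $\zeta$ studied in Figure 5), a client's regularized local loss for a toy fully connected network could look like this. The network, the function names, and the choice of which layers are penalized are all hypothetical.

```python
import math

def relu(v):
    return [max(0.0, x) for x in v]

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))

def cross_entropy(logits, label):
    """Softmax cross-entropy using a numerically stable log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[label]

def man_client_loss(weights, batch, zeta=0.1):
    """Hypothetical MAN-style local objective (an assumption, not the paper's code):
    mean task loss + zeta * sum over hidden layers of the batch-mean ||a_l||_2.

    weights: list of weight matrices; all but the last feed ReLU hidden layers,
             the last produces logits (biases omitted for brevity).
    batch:   list of (input_vector, class_label) pairs.
    """
    task, penalty = 0.0, 0.0
    for x, y in batch:
        h = x
        for W in weights[:-1]:
            h = relu(matvec(W, h))  # hidden activation a_l for this sample
            penalty += l2_norm(h)   # accumulate ||a_l||_2
        task += cross_entropy(matvec(weights[-1], h), y)
    n = len(batch)
    return task / n + zeta * penalty / n
```

Because the penalty is just another term in the local loss, it adds one norm per layer to the forward pass (and its gradient to the backward pass), whereas SAM-style methods need a second gradient computation per step. The FL algorithm that aggregates the client models (FedAvg, FedDyn, FedDC) is left unchanged.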

Abstract

Federated Learning (FL) is an emerging machine learning framework that enables multiple clients (coordinated by a server) to collaboratively train a global model by aggregating the locally trained models without sharing any client's training data. It has been observed in recent works that learning in a federated manner may lead the aggregated global model to converge to a 'sharp minimum', thereby adversely affecting the generalizability of this FL-trained model. Therefore, in this work, we aim to improve the generalization performance of models trained in a federated setup by introducing a 'flatness'-constrained FL optimization problem. This flatness constraint is imposed on the top eigenvalue of the Hessian computed from the training loss. As each client trains a model on its local data, we further reformulate this complex problem in terms of the client loss functions and propose a new computationally efficient regularization technique, dubbed 'MAN,' which Minimizes the Activation Norm of each layer on client-side models. We also show theoretically that minimizing the activation norm reduces the top eigenvalue of the layer-wise Hessian of the client's loss, which in turn decreases the overall Hessian's top eigenvalue, ensuring convergence to a flat minimum. We apply our proposed flatness-constrained optimization to existing FL techniques and obtain significant improvements, thereby establishing a new state of the art.
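
The flatness constraint above is stated in terms of the top eigenvalue of the Hessian. To make that quantity concrete, the sketch below estimates it by power iteration on Hessian-vector products. For self-containment it approximates those products with central finite differences of a gradient function (an assumption; automatic-differentiation Hessian-vector products are the usual choice in practice). `grad_fn`, `quad_grad`, and the toy quadratic loss are hypothetical stand-ins for the gradient of a client or global model's loss.

```python
import math
import random

def hvp_fd(grad_fn, w, v, eps=1e-4):
    """Approximate H(w) v by a central finite difference of the gradient."""
    w_plus = [wi + eps * vi for wi, vi in zip(w, v)]
    w_minus = [wi - eps * vi for wi, vi in zip(w, v)]
    g_plus, g_minus = grad_fn(w_plus), grad_fn(w_minus)
    return [(a - b) / (2.0 * eps) for a, b in zip(g_plus, g_minus)]

def top_hessian_eigenvalue(grad_fn, w, iters=100, seed=0):
    """Power iteration on Hessian-vector products.

    Returns the Rayleigh quotient v^T H v of the final unit vector v. This converges
    to the eigenvalue of largest magnitude, which is the top eigenvalue when that
    eigenvalue is positive and dominant (the typical case near a minimum).
    """
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in w]
    norm = math.sqrt(sum(x * x for x in v))
    v = [x / norm for x in v]
    lam = 0.0
    for _ in range(iters):
        hv = hvp_fd(grad_fn, w, v)
        lam = sum(a * b for a, b in zip(v, hv))  # Rayleigh quotient for unit v
        norm = math.sqrt(sum(x * x for x in hv))
        if norm == 0.0:
            break
        v = [x / norm for x in hv]
    return lam

# Toy loss 0.5 * w^T A w with A = [[2, 1], [1, 2]]: gradient A w, Hessian A (eigenvalues 3 and 1).
A = [[2.0, 1.0], [1.0, 2.0]]

def quad_grad(w):
    return [sum(a * x for a, x in zip(row, w)) for row in A]
```

Calling `top_hessian_eigenvalue(quad_grad, [0.3, -0.2])` recovers the top eigenvalue of `A`, which is 3. A sharp minimum corresponds to a large value of this quantity.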

Paper Structure

This paper contains 28 sections, 9 theorems, 64 equations, 7 figures, 6 tables.

Key Result

Theorem 1

Let $\mathbf{H}_{ll} \in \mathbb{R}^{d_l \times d_l}$ denote the Hessian of layer $l$ and $\mathbf{H} \in \mathbb{R}^{d \times d}$ the overall Hessian, with $\sum_{l=1}^{L} d_l = d$, where $L$ is the total number of layers. If the Hessian entries are bounded above, then we have the following result: $\lambda(\mathbf{H}) \in \cup_{l=
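
The statement of Theorem 1 is cut off here, so its exact form is not recoverable. The sketch below assumes it is a block Gershgorin-type inclusion. A standard fact of that kind is that for a symmetric $\mathbf{H}$, every eigenvalue lies within $R_l = \sum_{k \ne l} \|\mathbf{H}_{lk}\|_2$ of some eigenvalue of some diagonal block $\mathbf{H}_{ll}$. Hence $\lambda_{\max}(\mathbf{H}) \le \max_l \big(\lambda_{\max}(\mathbf{H}_{ll}) + R_l\big)$. This is an illustration of that general bound on a toy matrix, not the paper's precise theorem. The function names are hypothetical.

```python
import math

def sym_eigvals(M, sweeps=50, tol=1e-20):
    """Eigenvalues of a small symmetric matrix via cyclic Jacobi rotations."""
    n = len(M)
    A = [list(row) for row in M]
    for _ in range(sweeps):
        off = sum(A[i][j] ** 2 for i in range(n) for j in range(n) if i != j)
        if off < tol:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if A[p][q] == 0.0:
                    continue
                # Rotation angle that zeroes the (p, q) entry of J^T A J.
                theta = 0.5 * math.atan2(2.0 * A[p][q], A[q][q] - A[p][p])
                c, s = math.cos(theta), math.sin(theta)
                for k in range(n):  # A <- A J (update columns p and q)
                    akp, akq = A[k][p], A[k][q]
                    A[k][p], A[k][q] = c * akp - s * akq, s * akp + c * akq
                for k in range(n):  # A <- J^T A (update rows p and q)
                    apk, aqk = A[p][k], A[q][k]
                    A[p][k], A[q][k] = c * apk - s * aqk, s * apk + c * aqk
    return [A[i][i] for i in range(n)]

def spectral_norm(B):
    """||B||_2 = sqrt(lambda_max(B^T B)) for a rectangular block B (list of rows)."""
    r, c = len(B), len(B[0])
    btb = [[sum(B[i][a] * B[i][b] for i in range(r)) for b in range(c)] for a in range(c)]
    return math.sqrt(max(0.0, max(sym_eigvals(btb))))

def block_gershgorin_top_bound(H, sizes):
    """Upper bound on lambda_max(H): max over layers l of
    lambda_max(H_ll) + sum_{k != l} ||H_lk||_2, for a symmetric H split into blocks."""
    starts = [0]
    for d in sizes:
        starts.append(starts[-1] + d)
    assert starts[-1] == len(H), "layer sizes d_l must sum to d"

    def block(l, k):
        return [row[starts[k]:starts[k + 1]] for row in H[starts[l]:starts[l + 1]]]

    bound = -math.inf
    for l in range(len(sizes)):
        radius = sum(spectral_norm(block(l, k)) for k in range(len(sizes)) if k != l)
        bound = max(bound, max(sym_eigvals(block(l, l))) + radius)
    return bound
```

Under this reading, shrinking each layer's top eigenvalue (which the paper controls through activation norms) tightens the bound on the overall top eigenvalue, provided the cross-layer blocks stay bounded.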

Figures (7)

  • Figure 1: Loss surface of the global model trained on CIFAR-100 using FedAvg, alongside an illustration of the MAN regularizer. Combining FedAvg with MAN (FedAvg+MAN) yields a flat loss surface with better generalization.
  • Figure 2: Convergence comparison on CIFAR-100: we compare the performance of FedAvg, FedDyn, and FedDC with the proposed FedAvg+MAN, FedDyn+MAN, and FedDC+MAN over 500 communication rounds. The proposed approach significantly improves on the existing algorithms across communication rounds.
  • Figure 3: Convergence comparison on Tiny-ImageNet: we compare the performance of FedAvg, FedDyn, and FedDC with the proposed FedAvg+MAN, FedDyn+MAN, and FedDC+MAN over 500 communication rounds. The proposed approach significantly improves on the existing algorithms.
  • Figure 4: One panel compares the top 200 Hessian eigenvalues of FedAvg vs. FedAvg+MAN and FedSAM vs. FedSAM+MAN; the MAN regularizer reduces the top eigenvalues, which explains the observed reduction in the trace. The other panel shows that negative eigenvalues contribute little to the trace.
  • Figure 5: Sensitivity of accuracy to the hyper-parameter $\zeta$. Accuracy is stable over $\zeta \in [0.1, 3.0]$.
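
Figure 4 ties the reduction in the Hessian trace to smaller top eigenvalues and notes that negative eigenvalues contribute little. The trace is the sum of all eigenvalues, negative ones included. The summary does not say how the trace was computed. One common approach, assumed here, is Hutchinson's estimator, which averages $v^\top \mathbf{H} v$ over random Rademacher vectors $v$. The helper `matrix_hvp`, which wraps an explicit matrix, is a hypothetical toy stand-in for a model's Hessian-vector product.

```python
import random

def hutchinson_trace(hvp, dim, samples=2000, seed=0):
    """Estimate tr(H) as the mean of v^T H v over Rademacher vectors v (entries +/-1).
    Unbiased because E[v v^T] = I for such v."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        hv = hvp(v)
        total += sum(a * b for a, b in zip(v, hv))
    return total / samples

def matrix_hvp(H):
    """Wrap an explicit symmetric matrix as a Hessian-vector-product function (toy stand-in)."""
    return lambda v: [sum(h * x for h, x in zip(row, v)) for row in H]
```

For a diagonal Hessian such as `[[3.0, 0, 0], [0, -0.5, 0], [0, 0, 1.0]]` every sample equals the trace exactly (3.5), and the negative eigenvalue lowers it only slightly. For non-diagonal matrices the estimate is noisy and tightens as `samples` grows.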

Theorems & Definitions (14)

  • Theorem 1
  • Theorem 2
  • Theorem 3
  • Proposition 1
  • Theorem 4
  • Proof
  • Lemma 1
  • Proof
  • Theorem 5
  • Proof