Addressing Label Shift in Distributed Learning via Entropy Regularization

Zhiyuan Wu, Changkyu Choi, Xiangcheng Cao, Volkan Cevher, Ali Ramezani-Kebrya

TL;DR

The paper addresses label shift in distributed learning where data stays on client nodes and proposes VRLS, an entropy-regularized density-ratio estimator, integrated into an IW-ERM framework to counter both intra- and inter-node shifts. VRLS improves calibration of the estimated $p^{\text{tr}}(\boldsymbol{y}|\boldsymbol{x})$, enabling more accurate density ratios $p^{\text{te}}(\boldsymbol{y})/p^{\text{tr}}(\boldsymbol{y})$ and near-optimal global risk in multi-node settings. The authors provide finite-sample error bounds for the ratio estimates and convergence guarantees for IW-ERM under convex, smooth, and nonconvex regimes, while preserving privacy and minimizing communication. Empirically, VRLS-based IW-ERM achieves up to 20% improvements in test error in imbalanced label-shift scenarios on MNIST, Fashion-MNIST, and CIFAR-10, and scales to 5–200 nodes with results approaching an upper bound that uses true density ratios, highlighting practical impact for robust distributed learning.

Abstract

We address the challenge of minimizing true risk in multi-node distributed learning. These systems are frequently exposed to both inter-node and intra-node label shifts, which present a critical obstacle to effectively optimizing model performance while ensuring that data remains confined to each node. To tackle this, we propose the Versatile Robust Label Shift (VRLS) method, which enhances the maximum likelihood estimation of the test-to-train label density ratio. VRLS incorporates Shannon entropy-based regularization and adjusts the density ratio during training to better handle label shifts at test time. In multi-node learning environments, VRLS further extends its capabilities by learning and adapting density ratios across nodes, effectively mitigating label shifts and improving overall model performance. Experiments conducted on MNIST, Fashion MNIST, and CIFAR-10 demonstrate the effectiveness of VRLS, which outperforms baselines by up to 20% in imbalanced settings. Our theoretical analysis further supports these results by establishing high-probability bounds on estimation errors.
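To make the role of the test-to-train label density ratio concrete, here is a minimal sketch of an importance-weighted training loss. It is not the authors' implementation: the function names `label_density_ratio` and `iw_cross_entropy` are hypothetical, the label marginals are assumed to be known (VRLS is what estimates the ratio in practice), and the toy numbers are made up for illustration.

```python
import math

def label_density_ratio(p_test, p_train):
    """Per-class ratio r[y] = p_test[y] / p_train[y] (assumes p_train[y] > 0)."""
    return [pt / ptr for pt, ptr in zip(p_test, p_train)]

def iw_cross_entropy(probs, labels, ratio):
    """Mean negative log-likelihood, with each example weighted by r[label].

    probs[i][c] is the model's probability of class c for training example i.
    """
    total = 0.0
    for p, y in zip(probs, labels):
        total += ratio[y] * -math.log(p[y])
    return total / len(labels)

# Toy two-class setting: class 1 is rare in training but common at test time.
p_train = [0.9, 0.1]
p_test = [0.5, 0.5]
ratio = label_density_ratio(p_test, p_train)

probs = [[0.8, 0.2], [0.8, 0.2], [0.3, 0.7]]  # hypothetical model outputs
labels = [0, 0, 1]
loss = iw_cross_entropy(probs, labels, ratio)
```

With these marginals, examples from the rare class are up-weighted by a factor of 5, so errors on that class dominate the loss in the way they would on the test distribution.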

Paper Structure

This paper contains 32 sections, 17 theorems, 54 equations, 15 figures, 10 tables, 2 algorithms.

Key Result

Proposition 4.1

Under the label shift setting described in the introduction (sec:intro), the IW-ERM objective in equation IWERM:gen;R is consistent, and the learned function $h_{\boldsymbol{w}}$ converges in probability to the optimal function that minimizes the overall true risk across nodes, $\sum_{k=1}^K R_k$.
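The standard identity behind importance weighting under label shift helps explain why such a consistency result is plausible. The derivation below is a single-node sketch, not the paper's proof. It assumes that $p(\boldsymbol{x}|\boldsymbol{y})$ is the same at training and test time and that $p^{\text{tr}}(\boldsymbol{y}) > 0$ wherever $p^{\text{te}}(\boldsymbol{y}) > 0$; the loss symbol $\ell$ and the shorthand $r(\boldsymbol{y})$ are notation introduced here for illustration.

$$
\begin{aligned}
\mathbb{E}_{p^{\text{te}}(\boldsymbol{x},\boldsymbol{y})}\big[\ell(h_{\boldsymbol{w}}(\boldsymbol{x}),\boldsymbol{y})\big]
&= \sum_{\boldsymbol{y}} p^{\text{te}}(\boldsymbol{y})\, \mathbb{E}_{p(\boldsymbol{x}|\boldsymbol{y})}\big[\ell(h_{\boldsymbol{w}}(\boldsymbol{x}),\boldsymbol{y})\big] \\
&= \sum_{\boldsymbol{y}} p^{\text{tr}}(\boldsymbol{y})\, \frac{p^{\text{te}}(\boldsymbol{y})}{p^{\text{tr}}(\boldsymbol{y})}\, \mathbb{E}_{p(\boldsymbol{x}|\boldsymbol{y})}\big[\ell(h_{\boldsymbol{w}}(\boldsymbol{x}),\boldsymbol{y})\big] \\
&= \mathbb{E}_{p^{\text{tr}}(\boldsymbol{x},\boldsymbol{y})}\big[r(\boldsymbol{y})\, \ell(h_{\boldsymbol{w}}(\boldsymbol{x}),\boldsymbol{y})\big],
\qquad r(\boldsymbol{y}) = \frac{p^{\text{te}}(\boldsymbol{y})}{p^{\text{tr}}(\boldsymbol{y})}.
\end{aligned}
$$

Minimizing the weighted empirical risk on training data therefore targets the test risk, provided the ratio $r(\boldsymbol{y})$ is estimated accurately.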

Figures (15)

  • Figure 3: MSE analysis on MNIST for MLLS baselines. Left: Performance evaluation across various alpha values, comparing different methods: MLLS_EM, MLLS_L1, MLLS_L2, and MLLS_CG. MLLS_L1 and MLLS_L2 apply convex optimization with $L_1$ and $L_2$ regularization, respectively, to the estimation problem in our limited-test-sample setting, and are solved directly with a convex solver. In contrast, MLLS_CG uses conjugate gradient descent, and MLLS_EM solves the same convex optimization problem with the EM algorithm. Both the EM and convex optimization methods (MLLS_L1, MLLS_L2) demonstrate superior and more consistent performance than MLLS_CG, especially under severe label shift. Middle: At an alpha value of 1.0, the MSE analysis shows comparable performance across most methods, with the exception of MLLS_CG, which lags behind. Right: For alpha=0.1, MLLS_CG performs significantly worse than the EM and convex optimization methods, consistent with the trends observed in the left plot.
  • Figure 4: In our detailed analysis with the MNIST dataset, we compare VRLS against MLLS [mlls], EM [bbse_2002], and RLLS [rlls].
  • Figure 5: In this experiment with Fashion MNIST, a simple MLP with dropout was employed.
  • Figure 6: The average, best-client, and worst-client accuracy, along with their standard deviations, are derived from the table referenced as app:fig:label-shift:fmnist:table. Our method exhibits the lowest standard deviation, showcasing the most robust accuracy among the compared methods.
  • Figure 7: The average, best-client, and worst-client accuracy, along with their standard deviations, are derived from the table referenced as app:fig:label-shift:cifar10:table.
  • ...and 10 more figures
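
The Figure 3 caption mentions an EM-based solver for the label-shift estimation problem. As a rough illustration of that family of methods, the sketch below implements the classic EM re-estimation of a test label prior from a classifier's predicted probabilities on unlabeled test samples. It is a generic textbook-style version, not the authors' code and not VRLS (there is no entropy regularization here); the function name `em_test_prior` and the toy inputs are hypothetical.

```python
def em_test_prior(probs, p_train, n_iter=200):
    """Estimate the test label prior q by EM.

    probs[i][c]: classifier's p_train(y=c | x_i) on unlabeled test sample i.
    p_train[c]:  training label prior (assumed > 0 for every class).
    """
    k = len(p_train)
    q = list(p_train)  # start from the training prior
    for _ in range(n_iter):
        sums = [0.0] * k
        for p in probs:
            # E-step: reweight the prediction by the current ratio q / p_train.
            unnorm = [p[c] * q[c] / p_train[c] for c in range(k)]
            z = sum(unnorm)
            for c in range(k):
                sums[c] += unnorm[c] / z
        # M-step: the new prior is the average posterior over test samples.
        q = [s / len(probs) for s in sums]
    return q

# Toy example: balanced training prior, test predictions mostly favor class 1.
p_train = [0.5, 0.5]
probs = [[0.1, 0.9]] * 8 + [[0.9, 0.1]] * 2
q = em_test_prior(probs, p_train)
ratio = [q[c] / p_train[c] for c in range(len(q))]
```

The estimated density ratio is then the estimated test prior divided by the training prior, class by class. The quality of the estimate depends on how well calibrated the predicted probabilities are, which is the aspect the entropy regularization in VRLS is described as improving.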

Theorems & Definitions (29)

  • Proposition 4.1
  • proof
  • Theorem 5.1: Ratio Estimation Error Bound
  • proof
  • Theorem 5.2: Convergence-communication
  • Theorem 5.3: Upper Bound for Convex and Smooth
  • Theorem 5.4: Lower Bound for Convex and Second-order Differentiable
  • Theorem 5.5: High-probability Bound for Nonconvex Optimization
  • proof
  • proof
  • ...and 19 more