Who to Trust? Aggregating Client Predictions in Federated Distillation

Viktor Kovalchuk, Denis Son, Arman Bolatov, Mohsen Guizani, Samuel Horváth, Maxim Panov, Martin Takáč, Eduard Gorbunov, Nikita Kotelevskii

Abstract

Under data heterogeneity (e.g., $\textit{class mismatch}$), clients may produce unreliable predictions for instances belonging to unfamiliar classes. An equally weighted combination of such predictions can corrupt the teacher signal used for distillation. In this paper, we provide a theoretical analysis of Federated Distillation and show that aggregating client predictions on a shared public dataset converges to a neighborhood of the optimum, where the neighborhood size is governed by the aggregation quality. We further propose two uncertainty-aware aggregation methods, $\mathbf{UWA}$ and $\mathbf{sUWA}$, which leverage density-based uncertainty estimates to down-weight unreliable client predictions. Experiments on image and text classification benchmarks demonstrate that our methods are particularly effective under high data heterogeneity, while matching standard averaging when heterogeneity is low.
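To make the contrast between equal weighting and uncertainty-aware down-weighting concrete, here is a minimal NumPy sketch of probability-mixing aggregation, $p_{\mathrm{agg}}(x)=\sum_i w_i(x)q_i(x)$. It is an illustration under stated assumptions, not the paper's UWA or sUWA procedure: the function name `aggregate`, the toy predictions, the `uncertainty` scores, and the inverse-uncertainty weighting rule are all hypothetical, since the density-based estimator itself is not described in this section.

```python
import numpy as np

def aggregate(client_probs, weights=None):
    """Mix client predictions: p_agg(x) = sum_i w_i(x) * q_i(x).

    client_probs: array of shape (n_clients, n_samples, n_classes),
                  each row a probability vector.
    weights:      array of shape (n_clients, n_samples) with non-negative
                  entries; None means equal weighting.
    """
    q = np.asarray(client_probs, dtype=float)
    if weights is None:
        w = np.ones(q.shape[:2])
    else:
        w = np.asarray(weights, dtype=float)
    # Normalize over clients so the weights for each sample sum to one.
    w = w / w.sum(axis=0, keepdims=True)
    return np.einsum("is,isc->sc", w, q)

# Toy setup: 2 clients, 1 public sample whose true class is 0, 3 classes.
q = np.array([
    [[0.90, 0.05, 0.05]],  # client 0 has seen class 0 during training
    [[0.10, 0.45, 0.45]],  # client 1 has never seen class 0
])

# Hypothetical per-sample uncertainty scores (higher = less reliable).
uncertainty = np.array([[0.1], [0.9]])
# Assumption: weights inversely proportional to uncertainty.
weights = 1.0 / (uncertainty + 1e-8)

p_uniform = aggregate(q)            # equal weighting
p_weighted = aggregate(q, weights)  # down-weights the unreliable client
```

In this toy case equal weighting puts only half of the teacher's mass on the true class, while the weighted mixture leans toward the client that is familiar with it. When all uncertainty scores are equal, the two aggregates coincide.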

Paper Structure

This paper contains 59 sections, 8 theorems, 79 equations, 7 figures, 2 tables, 1 algorithm.

Key Result

Lemma 4.1

Consider the probability-mixing aggregation $p_{\mathrm{agg}}(x)=\sum_{i\in\mathcal{S}} w_i(x)q_i(x)$, and assume that its bias is bounded by $B_w$ and that the variance and correlation of $q_i(x) - p^*(x)$ are bounded by $\sigma^2$ and $\rho_c\sigma^2$, respectively (see Assumption ass:teacher). Then the teacher MSE satisfies:

Figures (7)

  • Figure 1: Our two-stage training loop works as follows. Stage one. (1) Clients train on their private labeled data. (2) Clients compute predictions on the public unlabeled dataset. (3) Clients send predictions to the server. Stage two. (4) The server aggregates predictions into soft labels. (5) The soft labels are sent back to clients for refinement.
  • Figure 2: Best test accuracy vs. classes per client ($k$). Dashed line: fully informed reference model trained on an all-classes dataset of size $|\mathcal{D}_{pub}| + |\mathcal{D}_{i}|$.
  • Figure 3: Mean accuracy on local classes over communication rounds.
  • Figure 4: Yahoo Answers: global (top) and local (bottom) test accuracy across communication rounds for $k \in \{2,3,4,5,7,9\}$.
  • Figure 5: CIFAR-10: global (top) and local (bottom) test accuracy across communication rounds for $k \in \{2,3,4,5,7,9\}$.
  • ...and 2 more figures

Theorems & Definitions (18)

  • Definition 3.1
  • Lemma 4.1: Teacher MSE under bias-variance-correlation
  • Proposition 4.2: Expected-oracle conditions adaptation
  • Corollary 4.3: Big-$\mathcal{O}$ complexity for nonconvex stationarity
  • Corollary 4.4: Big-$\mathcal{O}$ complexity under PL
  • Remark C.5
  • Lemma C.7: Conditional identity
  • proof
  • proof
  • Lemma C.8: Aggregation-induced gradient discrepancy
  • ...and 8 more