
Optimal Transport Aggregation for Distributed Mixture-of-Experts

Faïcel Chamroukhi, Nhat Thien Pham

TL;DR

A principled aggregation framework based on optimal transport is proposed that constructs a reduced global MoE estimator by minimizing a transportation divergence between the collection of local estimators and the aggregated model, and an efficient majorization-minimization algorithm is derived to solve the resulting optimization problem.

Abstract

Mixture-of-experts (MoE) models provide a flexible statistical framework for modeling heterogeneity and nonlinear relationships. In many modern applications, however, datasets are naturally distributed across multiple machines due to storage, computational, or governance constraints. We consider a distributed model aggregation setting in which local MoE models are trained independently on decentralized datasets and subsequently combined into a global estimator. Aggregating MoE models is challenging because standard averaging produces models that do not preserve the MoE structure, and therefore do not yield estimates of the global model parameters. To address this issue, we propose a principled aggregation framework based on optimal transport that constructs a reduced global MoE estimator by minimizing a transportation divergence between the collection of local estimators and the aggregated model. An efficient majorization-minimization (MM) algorithm is derived to solve the resulting optimization problem. The method requires only a single communication step from local machines to a central server, making it a frugal distributed learning approach that is particularly attractive for large-scale settings where communication costs are a major bottleneck. We further establish statistical guarantees for the aggregated estimator, including consistency under standard assumptions on the local estimators. Experiments on synthetic and real datasets demonstrate that the approach achieves performance comparable to centralized training while significantly reducing computation time. The source code is publicly available on GitHub.
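To make the structural issue with averaging concrete, the sketch below assumes a simple setting that is not taken from the paper: scalar input $x$ and output $y$, local MoE models with softmax gating and linear Gaussian experts, and averaging weights $\lambda_m = N_m/N$ proportional to hypothetical local sample sizes. Under these assumptions the weighted average $\sum_m \lambda_m f_m(y\mid x)$ is still a mixture, but with $M \cdot K$ experts and gates $\lambda_m \pi_{mk}(x)$, so in general it is not a $K$-expert MoE. The helper names (`moe_density`, `weighted_average_as_mixture`) are illustrative only.

```python
import numpy as np

def softmax(a):
    """Numerically stable softmax over the last axis."""
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def moe_density(y, x, alpha, beta, a, b, var):
    """f(y | x) for a scalar Gaussian MoE (hypothetical parametrization):
    gates softmax(alpha + beta * x), experts N(a + b * x, var); all arrays have shape (K,)."""
    gates = softmax(alpha + beta * x)
    means = a + b * x
    comps = np.exp(-0.5 * (y - means) ** 2 / var) / np.sqrt(2 * np.pi * var)
    return float(np.sum(gates * comps))

def weighted_average_as_mixture(params, lam, x):
    """Rewrite sum_m lam_m f_m(. | x) as one flat mixture with M*K experts."""
    gates, means, variances = [], [], []
    for (alpha, beta, a, b, var), l in zip(params, lam):
        gates.append(l * softmax(alpha + beta * x))  # gate of expert (m, k): lam_m * pi_mk(x)
        means.append(a + b * x)
        variances.append(var)
    return np.concatenate(gates), np.concatenate(means), np.concatenate(variances)

# Hypothetical example: M = 4 machines, each with a K = 3 expert local MoE.
rng = np.random.default_rng(0)
M, K = 4, 3
params = [tuple(rng.normal(size=K) for _ in range(4)) + (rng.uniform(0.5, 1.5, K),)
          for _ in range(M)]
N_m = np.array([250, 300, 200, 250])  # hypothetical local sample sizes
lam = N_m / N_m.sum()

x, y = 0.7, 0.2
g, mu, v = weighted_average_as_mixture(params, lam, x)
avg = sum(l * moe_density(y, x, *p) for p, l in zip(params, lam))
mix = float(np.sum(g * np.exp(-0.5 * (y - mu) ** 2 / v) / np.sqrt(2 * np.pi * v)))
print(len(g), g.sum(), np.isclose(avg, mix))
```

In this setting each machine only sends its parameter tuple once; a reduction step is then needed to compress the $M \cdot K$ experts back to $K$.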

Paper Structure (37 sections, 6 theorems, 94 equations, 5 figures, 1 table, 1 algorithm)

Key Result

proposition 1

Let $\mathcal{T}_c(g)$, $\mathcal{R}_c(g)$ and $\mathcal{P}(g,\mathbf{x})$ be defined as above. Then $\underset{g\in\mathcal{M}_K}{\inf}\,\mathcal{T}_{c}(g) = \underset{g\in\mathcal{M}_K}{\inf}\,\mathcal{R}_c(g)$. The reduction solution is hence given by $\bar{f}^R = \underset{g\in\mathcal{M}_K}{\arg\inf}\ \ldots$, where $\mathcal{P}_{\ell k}(\bar{f}^R, \mathbf{x})$ denotes the entry $(\ell,k)$ of $\mathcal{P}(\bar{f}^R, \mathbf{x})$.
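The definitions of $\mathcal{T}_c$ and $\mathcal{R}_c$ are not reproduced on this page, so the following is only a hedged illustration of a mechanism often behind equalities of this kind in transport-based mixture reduction, assumed rather than confirmed here: for a fixed covariate $\mathbf{x}$, if only the source mixing weights $w_\ell$ are constrained and the target weights are left free, the cheapest transport plan sends each source component entirely to its lowest-cost target component. The sketch checks this numerically with SciPy's `linprog` on a random cost matrix, with rows indexed by source components $\ell$ and columns by target components $k$; the cost values, sizes and function names are hypothetical.

```python
import numpy as np
from scipy.optimize import linprog

def min_cost_free_target(w, C):
    """Minimize sum_{l,k} P[l,k] * C[l,k] s.t. sum_k P[l,k] = w[l], P >= 0.
    Only the source marginal is fixed; the target weights are left free."""
    L, K = C.shape
    A_eq = np.zeros((L, L * K))
    for ell in range(L):
        A_eq[ell, ell * K:(ell + 1) * K] = 1.0  # row-sum constraint for source component ell
    res = linprog(C.ravel(), A_eq=A_eq, b_eq=w, bounds=(0, None), method="highs")
    return res.fun, res.x.reshape(L, K)

def assignment_cost(w, C):
    """Cost of sending each source component wholly to its cheapest target."""
    return float(np.sum(w * C.min(axis=1)))

rng = np.random.default_rng(0)
w = rng.dirichlet(np.ones(6))        # hypothetical source weights w_ell
C = rng.uniform(0.0, 5.0, (6, 3))    # hypothetical costs c(f_ell, g_k) at a fixed x
lp_value, P = min_cost_free_target(w, C)
print(lp_value, assignment_cost(w, C))
```

Because the linear program separates across rows, each row's minimum is $w_\ell \min_k C_{\ell k}$, which is why the two printed values agree up to solver tolerance.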

Figures (5)

  • Figure 1: Performance of the Global MoE (G), Reduction (R), Middle (M) and Weighted average (W) estimators for sample size $N=10^6$ and $M$ machines.
  • Figure 2: Performance of the Global (G), Reduction (R), Middle (M) and Weighted average (W) estimators using 128 machines for different sample sizes $N$.
  • Figure 3: Performance of the Global (G), Reduction (R), Middle (M), and Weighted (W) estimators for sample size $N=10^5$ and different numbers of machines $M$.
  • Figure 4: Performance of the Global (G), Reduction (R), Middle (M), and Weighted (W) estimators for sample size $N=5\times10^5$ and different numbers of machines $M$.
  • Figure 5: Evolution of the objective function across iterations, illustrating the monotonic decrease and convergence behavior of the MM algorithm.
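Figure 5 reports a monotone decrease of the MM objective. The sketch below reproduces that qualitative behavior in a deliberately simplified setting and is not the authors' algorithm: the pooled local components are univariate Gaussians without covariates or gating networks, the cost is the Kullback-Leibler divergence $\mathrm{KL}(f_\ell \,\|\, g_k)$, and each iteration alternates a hard assignment step with a moment-matching update, a standard scheme for transport-based Gaussian mixture reduction. Both steps can only lower the objective $\sum_\ell w_\ell \min_k \mathrm{KL}(f_\ell \,\|\, g_k)$, so the recorded history is non-increasing. The function `reduce_mixture` and all data are hypothetical.

```python
import numpy as np

def kl_gauss(mu1, var1, mu2, var2):
    """KL(N(mu1, var1) || N(mu2, var2)) for univariate Gaussians, broadcast over arrays."""
    return 0.5 * (np.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0)

def reduce_mixture(w, mu, var, K, n_iter=50, seed=0):
    """Simplified MM-style reduction of an n-component Gaussian mixture to K components.
    Returns target weights, means, variances and the objective value per iteration."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(w), size=K, replace=False)  # initialize targets at K source components
    m_t, v_t = mu[idx].copy(), var[idx].copy()
    rows = np.arange(len(w))
    history = []
    for _ in range(n_iter):
        # Cost matrix: rows = source components, columns = target components
        C = kl_gauss(mu[:, None], var[:, None], m_t[None, :], v_t[None, :])
        z = C.argmin(axis=1)  # assignment step: cheapest target for each source
        history.append(float(np.sum(w * C[rows, z])))
        for k in range(K):
            mask = z == k
            if not mask.any():
                continue  # leave an empty target unchanged
            wk = w[mask] / w[mask].sum()
            # Update step: moment matching minimizes sum_l w_l KL(f_l || g) over Gaussians g
            m_t[k] = np.sum(wk * mu[mask])
            v_t[k] = np.sum(wk * (var[mask] + (mu[mask] - m_t[k]) ** 2))
    C = kl_gauss(mu[:, None], var[:, None], m_t[None, :], v_t[None, :])
    z = C.argmin(axis=1)
    history.append(float(np.sum(w * C[rows, z])))
    w_t = np.array([w[z == k].sum() for k in range(K)])
    return w_t, m_t, v_t, history

# Hypothetical pooled input: 3 machines, each sending 2 components near -2 and 3.
rng = np.random.default_rng(1)
M, K = 3, 2
mu = np.concatenate([np.array([-2.0, 3.0]) + 0.1 * rng.standard_normal(K) for _ in range(M)])
var = np.full(M * K, 0.5)
w = np.full(M * K, 1.0 / (M * K))
w_t, m_t, v_t, hist = reduce_mixture(w, mu, var, K)
print(np.round(hist[:5], 4), w_t, m_t, v_t)
```

Moment matching is used in the update because, among Gaussians, it is the exact minimizer of a weighted sum of $\mathrm{KL}(f_\ell \,\|\, g)$; with another cost, or with experts and gates that depend on $\mathbf{x}$ as in an MoE, this step would have to change.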

Theorems & Definitions (7)

  • definition 1: Expected transportation divergence
  • proposition 1
  • proposition 2
  • theorem 1: Consistency of the reduction estimator
  • proposition 3
  • proposition 4
  • proposition 5