Mixture-of-Transformers Learn Faster: A Theoretical Study on Classification Problems

Hongbo Li, Qinhang Wu, Sen Lin, Yingbin Liang, Ness B. Shroff

TL;DR

We study the Mixture-of-Transformers (MoT), a tractable theoretical framework in which each transformer block acts as an expert governed by a continuously trained gating network, and offer the first unified theoretical account of transformer-level specialization and learning dynamics.

Abstract

Mixture-of-Experts (MoE) models improve transformer efficiency but lack a unified theoretical explanation, especially when both feed-forward and attention layers are allowed to specialize. To this end, we study the Mixture-of-Transformers (MoT), a tractable theoretical framework in which each transformer block acts as an expert governed by a continuously trained gating network. This design allows us to isolate and study the core learning dynamics of expert specialization and attention alignment. In particular, we develop a three-stage training algorithm with continuous training of the gating network, and show that each transformer expert specializes in a distinct class of tasks and that the gating network accurately routes data samples to the correct expert. Our analysis shows how expert specialization reduces gradient conflicts and makes each subtask strongly convex. We prove that the training drives the expected prediction loss to near zero in $O(\log(\epsilon^{-1}))$ iteration steps, significantly improving over the $O(\epsilon^{-1})$ rate for a single transformer. We further validate our theoretical findings through extensive real-data experiments, demonstrating the practical effectiveness of MoT. Together, these results offer the first unified theoretical account of transformer-level specialization and learning dynamics, providing practical guidance for designing efficient large-scale models.
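The routing idea in the abstract, a gating network that sends each sample to one expert, can be illustrated with a minimal sketch. Everything here is an assumption made for illustration rather than the paper's construction: the names `route_top1` and `mixture_predict` are hypothetical, the gate is a plain linear softmax router, each "expert" is a linear scorer standing in for a full transformer block, and labels are taken to be $\pm 1$. The three-stage training procedure is not reproduced.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def route_top1(X, gate_W):
    """Pick one expert per sample (assumed top-1 routing).

    X:      (batch, d) feature vectors
    gate_W: (d, M) gating weights, one column per expert
    Returns the chosen expert index per sample and the gate probabilities.
    """
    probs = softmax(X @ gate_W)
    return probs.argmax(axis=-1), probs

def mixture_predict(X, gate_W, expert_Ws):
    """Predict a +/-1 label using only the expert chosen by the gate.

    expert_Ws: list of M weight vectors of shape (d,); each is a
    hypothetical linear stand-in for a transformer expert.
    """
    idx, _ = route_top1(X, gate_W)
    scores = np.empty(X.shape[0])
    for i, W in enumerate(expert_Ws):
        mask = idx == i              # samples routed to expert i
        scores[mask] = X[mask] @ W   # only that expert scores them
    return np.sign(scores), idx

# Toy usage: two samples, two experts, each sample matches one gate column.
X = np.array([[1.0, 0.0], [0.0, 1.0]])
gate_W = np.array([[5.0, 0.0], [0.0, 5.0]])
expert_Ws = [np.array([1.0, 0.0]), np.array([0.0, -1.0])]
pred, idx = mixture_predict(X, gate_W, expert_Ws)
```

In this toy setup each expert only ever scores the samples routed to it, which loosely mirrors the abstract's point that specialization keeps different classes of tasks from pulling one set of weights in conflicting directions.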

Paper Structure

This paper contains 21 sections, 20 theorems, 85 equations, 9 figures, 1 table, 1 algorithm.

Key Result

Proposition 1

Under Algorithm 1, for any epoch $t\geq T_1$, where $T_1=\mathcal{O}(\eta^{-1} \sigma_0^{-0.5} M)$ with $\sigma_0=\mathcal{O}(1)$ and $M=\Omega(N\log(N))$, with probability at least $1-o(1)$, the following holds: Moreover, for any input $\bm{X}^{(k)}$ that includes class signal $c_{n_i^*}$, the router selects each expert $i\in\mathcal{M}_{n_i^*}$ with equal probability $\frac{1}{|\mathcal{M}_

Figures (9)

  • Figure 1: Illustrations of (a) the MoT model, (b) a multi-head transformer without gating/router [yang2025transformers], and (c) an attention-absent MoE [chen2022towards, li2025theory].
  • Figure 2: Prediction error across epochs for the three-stage MoT and two baselines: (1) an attention-absent MoE model with $M=5$ experts, and (2) a multi-head transformer without gating. For CIFAR-10, we set $T_1=50,T_2=450,T=600$ and $M=5$. For CIFAR-100, we set $T_1=50,T_2=250,T=300$, and vary $M\in\{12,16\}$.
  • Figure 3: Routing history visualization for the MoT model on CIFAR-10 (left, epoch $580$, $M=5$) and CIFAR-100 (right, epoch $280$, $M=12$). Darker blue in each column indicates a higher proportion of samples from a class routed to that expert, showing expert specialization.
  • Figure 4: Prediction error across epochs on CIFAR-100 for the three-stage MoT, a multi-head transformer, and an FFN-level MoE (MoE-FFN). We set $T_1=50,T_2=300,T=500$. MoT and MoE-FFN both use $M=8$ experts in one panel and $M=12$ experts in the other.
  • Figure 5: Comparison between the three-stage MoT ($M=2,3$) and a multi-head transformer on the Amazon Polarity classification dataset. We set $T_1=50,T_2=200,T=300$.
  • ...and 4 more figures

Theorems & Definitions (37)

  • Definition 1: $N$-mixture of classification
  • Definition 2: Transformer specialization
  • Proposition 1: FFN specialization and router convergence
  • Proposition 2: Attention training
  • Theorem 1: Global Convergence of MoT
  • Lemma 1: Convergence of Attention-Absent MoE
  • Lemma 2: Expression of $\nabla_{\bm{\theta}^{(i)}} \mathcal{L}^r (\bm{X};\bm{\Theta})$
  • proof
  • Corollary 1
  • proof
  • ...and 27 more