MoEUT: Mixture-of-Experts Universal Transformers
Róbert Csordás, Kazuki Irie, Jürgen Schmidhuber, Christopher Potts, Christopher D. Manning
TL;DR
This work tackles the poor parameter-compute ratio of shared-layer Universal Transformers (UTs) by introducing MoEUT, a mixture-of-experts UT that decouples parameter count from compute. It combines sigma-MoE feedforward blocks and SwitchHead MoE attention with two UT-specific innovations, layer grouping and peri-layernorm, that make UTs scalable and parameter-efficient. Empirical results on C4, SlimPajama, peS2o, and The Stack show MoEUT matching or slightly surpassing parameter-matched dense Transformers while using less compute and memory, along with strong zero-shot performance. The paper also analyzes expert routing in detail and discusses limitations and avenues for massive scaling, suggesting a path toward reviving UT research at large scale.
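To make the decoupling of parameters from compute concrete, here is a minimal NumPy sketch of a sigma-MoE-style feedforward block (sigmoid expert scores, top-k expert selection, expert outputs weighted by their scores) reused under layer grouping, where a group of G distinct layers is applied cyclically until the full depth is reached. It is an illustration under stated assumptions, not the authors' implementation: SwitchHead attention, peri-layernorm, and any routing regularization are omitted, all sizes are hypothetical, and helper names such as `make_layer` and `grouped_ut_forward` are made up for this example.

```python
import numpy as np

def sigmoid(z):
    # tanh form avoids overflow in np.exp for large |z|
    return 0.5 * (1.0 + np.tanh(0.5 * z))

def sigma_moe_ffn(x, w_sel, w1, w2, k):
    """sigma-MoE-style FFN: sigmoid expert scores, top-k experts per token.

    x: (T, d_model); w_sel: (d_model, E); w1: (E, d_model, d_expert);
    w2: (E, d_expert, d_model); k: number of active experts per token.
    """
    scores = sigmoid(x @ w_sel)                     # (T, E), no softmax normalization
    topk = np.argsort(-scores, axis=-1)[:, :k]      # (T, k) ids of the selected experts
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        for e in topk[t]:
            h = np.maximum(x[t] @ w1[e], 0.0)       # ReLU hidden layer of expert e
            out[t] += scores[t, e] * (h @ w2[e])    # weight expert output by its score
    return out

def make_layer(rng, d_model, n_experts, d_expert):
    """Hypothetical initializer for one layer's expert weights."""
    return {
        "w_sel": rng.standard_normal((d_model, n_experts)) / np.sqrt(d_model),
        "w1": rng.standard_normal((n_experts, d_model, d_expert)) / np.sqrt(d_model),
        "w2": rng.standard_normal((n_experts, d_expert, d_model)) / np.sqrt(d_expert),
    }

def count_params(layers):
    """Total number of weights stored across the given layers."""
    return sum(w.size for layer in layers for w in layer.values())

def grouped_ut_forward(x, group, n_layers, k):
    """Apply a group of G distinct layers cyclically: 0..G-1, 0..G-1, ..."""
    assert n_layers % len(group) == 0
    for step in range(n_layers):
        layer = group[step % len(group)]            # same weights reused every G steps
        x = x + sigma_moe_ffn(x, layer["w_sel"], layer["w1"], layer["w2"], k)  # residual
    return x

rng = np.random.default_rng(0)
d_model, n_experts, d_expert, k = 16, 8, 4, 2
group = [make_layer(rng, d_model, n_experts, d_expert) for _ in range(2)]  # G = 2
x = rng.standard_normal((5, d_model))               # 5 tokens
y = grouped_ut_forward(x, group, n_layers=8, k=k)   # 8 layers of compute
print(y.shape, count_params(group))
```

In this toy setting the two-layer group stores 2,304 weights but is applied for eight layers of compute, and each token activates only 2 of the 8 experts per layer; a dense stack of eight such layers would store four times as many weights.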
Abstract
Previous work on Universal Transformers (UTs) has demonstrated the importance of parameter sharing across layers. By allowing recurrence in depth, UTs have advantages over standard Transformers in learning compositional generalizations, but layer-sharing comes with a practical limitation of parameter-compute ratio: it drastically reduces the parameter count compared to the non-shared model with the same dimensionality. Naively scaling up the layer size to compensate for the loss of parameters makes its computational resource requirements prohibitive. In practice, no previous work has succeeded in proposing a shared-layer Transformer design that is competitive in parameter count-dominated tasks such as language modeling. Here we propose MoEUT (pronounced "moot"), an effective mixture-of-experts (MoE)-based shared-layer Transformer architecture, which combines several recent advances in MoEs for both feedforward and attention layers of standard Transformers together with novel layer-normalization and grouping schemes that are specific and crucial to UTs. The resulting UT model, for the first time, slightly outperforms standard Transformers on language modeling tasks such as BLiMP and PIQA, while using significantly less compute and memory.
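The parameter-compute problem described above can be checked with back-of-the-envelope arithmetic. The sketch below assumes a standard layer with four attention projections and a 4x-wide FFN, counts only weight matrices, and estimates compute as about 2 FLOPs per weight per token; the width and depth (1024 and 24) are hypothetical values chosen for illustration, not taken from the paper.

```python
import math

def layer_params(d_model, ff_mult=4):
    """Approximate weight count of one Transformer layer (biases and norms ignored)."""
    attn = 4 * d_model * d_model            # Q, K, V and output projections
    ffn = 2 * ff_mult * d_model * d_model   # FFN up- and down-projections
    return attn + ffn

def flops_per_token(d_model, n_layers, ff_mult=4):
    """Roughly 2 FLOPs per weight per token; attention-score FLOPs are ignored."""
    return 2 * n_layers * layer_params(d_model, ff_mult)

d, L = 1024, 24                             # hypothetical width and depth
dense_params = L * layer_params(d)          # every layer has its own weights
shared_params = layer_params(d)             # one layer reused L times
d_wide = d * math.sqrt(L)                   # width a shared layer needs to match dense_params
flop_ratio = flops_per_token(d_wide, L) / flops_per_token(d, L)
print(dense_params, shared_params, round(d_wide), round(flop_ratio, 3))
```

Sharing one layer cuts the weight count by a factor of L (about 302M vs. 12.6M here), and widening the shared layer to recover those parameters multiplies per-token compute by roughly L as well (24x here). Mixture-of-experts layers sidestep this by adding parameters as experts while activating only a few of them per token.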
