UniMoD: Efficient Unified Multimodal Transformers with Mixture-of-Depths

Weijia Mao, Zhenheng Yang, Mike Zheng Shou

TL;DR

UniMoD addresses the high computational cost of unified multimodal transformers by introducing a task-aware token pruning framework. By analyzing attention patterns, layer importance, and cross-task interactions, it designs per-task MoD blocks and a layer-switching strategy guided by ARank to prune tokens adaptively. The approach yields substantial FLOP savings (e.g., ~15% in Show-o and ~40% in Emu3) with maintained or improved benchmark performance across MMU and T2I tasks and is extendable to diffusion-based generation models. This technique offers a practical path to efficient training of versatile multimodal models without sacrificing effectiveness.

Abstract

Unified multimodal transformers, which handle both generation and understanding tasks within a shared parameter space, have received increasing attention in recent research. Although various unified transformers have been proposed, training these models is costly due to redundant tokens and heavy attention computation. Previous studies on large language models have demonstrated that token pruning methods, such as Mixture of Depths (MoD), can significantly improve computational efficiency. MoD employs a router to select the most important tokens for processing within a transformer layer. However, directly applying MoD-based token pruning to unified transformers results in suboptimal performance because different tasks exhibit varying levels of token redundancy. In our work, we analyze unified transformers by (1) examining attention weight patterns, (2) evaluating layer importance and token redundancy, and (3) analyzing task interactions. Our findings reveal that token redundancy depends primarily on the task and the layer. Building on these findings, we introduce UniMoD, a task-aware token pruning method that employs a separate router for each task to determine which tokens should be pruned. We apply our method to Show-o and Emu3, reducing training FLOPs by approximately 15% in Show-o and 40% in Emu3, while maintaining or improving performance on several benchmarks. Code will be released at https://github.com/showlab/UniMoD.
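The routing idea described in the abstract can be illustrated with a small sketch. This is not the authors' implementation: the names `mod_layer`, `task_aware_mod_layer`, and `ROUTERS`, the linear scoring rule, the capacities, and the toy `block` are all assumptions made for illustration. It only shows the general pattern of a router that scores tokens, keeps a fixed fraction (the capacity) for processing, and lets the remaining tokens skip the layer, with one router per task.

```python
from typing import Callable, Dict, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def mod_layer(
    tokens: List[Vector],
    router_weights: Vector,
    capacity: float,
    block: Callable[[List[Vector]], List[Vector]],
) -> List[Vector]:
    """Send the top `capacity` fraction of tokens through `block`; skip the rest."""
    scores = [dot(t, router_weights) for t in tokens]
    k = max(1, int(capacity * len(tokens)))
    # Indices of the k highest-scoring tokens, restored to sequence order.
    ranked = sorted(range(len(tokens)), key=lambda i: scores[i], reverse=True)
    keep = sorted(ranked[:k])
    processed = block([tokens[i] for i in keep])
    out = [list(t) for t in tokens]  # skipped tokens pass through unchanged
    for i, delta in zip(keep, processed):
        # Residual update scaled by the router score (assumed formulation).
        out[i] = [x + scores[i] * d for x, d in zip(tokens[i], delta)]
    return out


# Hypothetical per-task routers: each task has its own weights and capacity.
ROUTERS: Dict[str, Dict] = {
    "mmu": {"weights": [1.0, 0.0], "capacity": 0.5},
    "t2i": {"weights": [0.0, 1.0], "capacity": 0.75},
}


def task_aware_mod_layer(tokens, task, block):
    router = ROUTERS[task]
    return mod_layer(tokens, router["weights"], router["capacity"], block)


def toy_block(xs: List[Vector]) -> List[Vector]:
    # Stand-in for attention + MLP: returns a constant update per token.
    return [[1.0, 1.0] for _ in xs]


tokens = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 3.0]]
mmu_out = task_aware_mod_layer(tokens, "mmu", toy_block)
t2i_out = task_aware_mod_layer(tokens, "t2i", toy_block)
```

In this toy setup the understanding router keeps half of the tokens and the generation router keeps three quarters, so the same input sequence is pruned differently depending on the task.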

Paper Structure

This paper contains 22 sections, 8 equations, 5 figures, 7 tables.

Figures (5)

  • Figure 1: (a) Pipeline and challenges of applying Mixture of Depths (MoD) to unified transformers. A single router prunes tokens across tasks and layers, leading to suboptimal performance due to inconsistent token redundancy. (b) Two key observations from our experiments on unified transformers, providing critical insights for our proposed method.
  • Figure 2: Attention weight for text and image tokens across different transformer layers for two tasks: Multi-Modal Understanding (MMU, top row) and Text-to-Image generation (T2I, bottom row). Each curve represents one token type, showing how attention allocation changes with model, layer index, and task.
  • Figure 3: ARank variations across different layers for four unified transformers: Show-o, JanusFlow, Emu3, and Lumina-mgpt. ARank, defined as the rank of the attention map, measures sequence redundancy within each layer: a higher ARank indicates lower redundancy.
  • Figure 4: Pipeline of UniMoD. The Layer Switch Module transforms dense transformer layers into three specialized types: T2I MoD layers for Text-to-Image (T2I) generation, MMU MoD layers for Multi-Modal Understanding (MMU), and Shared MoD layers for both tasks. For each task, task-aware routers with distinct capacities prune tokens of different modalities, thereby enhancing computational efficiency and maintaining performance across tasks.
  • Figure 5: Token weight assignment using Gumbel Softmax. A higher number of tokens assigned a weight of 1 across layers indicates greater importance of generation task tokens compared to understanding task tokens.
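
The Figure 3 caption defines ARank as the rank of the attention map. A minimal sketch of that quantity for a single attention matrix follows. The helper name `matrix_rank`, the Gaussian-elimination approach, and the tolerance `tol` are assumptions for illustration; the summary above does not say how the paper computes or thresholds the rank.

```python
from typing import List

Matrix = List[List[float]]


def matrix_rank(m: Matrix, tol: float = 1e-6) -> int:
    """Numerical rank via Gaussian elimination with partial pivoting."""
    rows = [list(r) for r in m]  # work on a copy
    n_rows = len(rows)
    n_cols = len(rows[0]) if rows else 0
    rank = 0
    for col in range(n_cols):
        # Pick the remaining row with the largest entry in this column.
        pivot = max(
            range(rank, n_rows), key=lambda r: abs(rows[r][col]), default=None
        )
        if pivot is None or abs(rows[pivot][col]) <= tol:
            continue  # no usable pivot: this column adds no rank
        rows[rank], rows[pivot] = rows[pivot], rows[rank]
        for r in range(rank + 1, n_rows):
            factor = rows[r][col] / rows[rank][col]
            for c in range(col, n_cols):
                rows[r][c] -= factor * rows[rank][c]
        rank += 1
        if rank == n_rows:
            break
    return rank


# Uniform attention: every query attends identically, so rows are redundant.
uniform = [[1 / 3, 1 / 3, 1 / 3] for _ in range(3)]
# One-hot attention: every query attends to a distinct token.
distinct = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
# Two queries share a pattern, one differs.
mixed = [[0.5, 0.5, 0.0], [0.5, 0.5, 0.0], [0.0, 0.0, 1.0]]

arank_uniform = matrix_rank(uniform)
arank_distinct = matrix_rank(distinct)
arank_mixed = matrix_rank(mixed)
```

The three toy matrices match the caption's reading: identical attention rows give the lowest rank (high redundancy), while fully distinct rows give full rank (low redundancy).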