
MPDiT: Multi-Patch Global-to-Local Transformer Architecture For Efficient Flow Matching and Diffusion Model

Quan Dao, Dimitris Metaxas

Abstract

Transformer architectures, particularly Diffusion Transformers (DiTs), have become widely used in diffusion and flow-matching models due to their strong performance compared to convolutional UNets. However, the isotropic design of DiTs processes the same number of patchified tokens in every block, leading to relatively heavy computation during training. In this work, we introduce a multi-patch transformer design in which early blocks operate on larger patches to capture coarse global context, while later blocks use smaller patches to refine local details. This hierarchical design reduces computational cost by up to 50% in GFLOPs while maintaining strong generative performance. In addition, we propose improved designs for time and class embeddings that accelerate training convergence. Extensive experiments on the ImageNet dataset demonstrate the effectiveness of our architectural choices. Code is released at https://github.com/quandao10/MPDiT
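To make the global-to-local schedule concrete, here is a minimal PyTorch sketch of the idea the abstract describes: early blocks attend over a coarse grid of large-patch tokens, a learned upsampler then re-expresses the sequence at a finer patch size, and later blocks refine local detail. Everything below (module names, patch sizes 4 and 2, the linear token upsampler) is an illustrative assumption, not the released MPDiT code, and the time/class conditioning is omitted for brevity.

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Plain pre-norm transformer block (stand-in for a DiT block;
    time/class conditioning omitted in this sketch)."""
    def __init__(self, dim, heads=6, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim),
        )

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        return x + self.mlp(self.norm2(x))

class GlobalToLocalSketch(nn.Module):
    """Early blocks: few large-patch tokens (global context).
    Later blocks: 4x as many small-patch tokens (local detail)."""
    def __init__(self, in_ch=4, dim=384, coarse_patch=4, fine_patch=2,
                 coarse_depth=6, fine_depth=6):
        super().__init__()
        self.r = (coarse_patch // fine_patch) ** 2   # fine tokens per coarse token
        self.embed = nn.Conv2d(in_ch, dim, coarse_patch, stride=coarse_patch)
        self.coarse = nn.ModuleList(Block(dim) for _ in range(coarse_depth))
        self.up = nn.Linear(dim, self.r * dim)       # one coarse -> r fine tokens
        self.fine = nn.ModuleList(Block(dim) for _ in range(fine_depth))
        self.head = nn.Linear(dim, in_ch * fine_patch ** 2)

    def forward(self, x):                            # x: (B, C, H, W) latents
        tok = self.embed(x).flatten(2).transpose(1, 2)   # (B, N, D)
        for blk in self.coarse:
            tok = blk(tok)                           # attention over N tokens
        b, n, d = tok.shape
        tok = self.up(tok).view(b, n * self.r, d)    # (B, rN, D); a real model
        for blk in self.fine:                        # would also restore raster order
            tok = blk(tok)                           # attention over rN tokens
        return self.head(tok)                        # per-token velocity/noise output

model = GlobalToLocalSketch()
out = model(torch.randn(2, 4, 32, 32))
print(out.shape)                                     # (2, 256, 16): 16x16 fine tokens
```

The key property is that the quadratic attention cost is only paid at full token resolution in the later blocks: with coarse_patch=4 on a 32x32 latent, the first half of the network attends over 64 tokens instead of 256.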


Paper Structure

This paper contains 33 sections, 2 equations, 14 figures, 9 tables.

Figures (14)

  • Figure 1: Generated samples from MPDiT-XL with cfg-scale $w=3$ at epoch 160.
  • Figure 2: Architecture of MPDiT, which consists of (a) the Global-Local Multi-Patch Diffusion Transformer, (b) the DiT block with shared time embedding, (c) the Upsample Module, and (d) the FNO Time Embedding
  • Figure 3: Qualitative results on ImageNet 512 with cfg=4
  • Figure 4: Qualitative images of class 113 "snail"
  • Figure 5: Qualitative images of class 33 "loggerhead, loggerhead turtle, Caretta caretta"
  • ...and 9 more figures
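For intuition about where the abstract's "up to 50% in GFLOPs" figure can come from with the global-to-local layout of Figure 2, the snippet below does a back-of-envelope FLOP count for a 12-block model that spends half its depth at a quarter of the token count. The dimensions and the 6/6 depth split are assumptions chosen for illustration, not the paper's measured configurations, and the count ignores embeddings, conditioning, and normalization.

```python
def block_flops(n, d, mlp_ratio=4):
    """Rough per-block FLOPs: QKV/output projections, attention scores,
    and a two-layer MLP (one multiply-accumulate counted as 2 FLOPs)."""
    attn = 4 * (2 * n * d * d) + 2 * (2 * n * n * d)
    mlp = 2 * (2 * n * d * mlp_ratio * d)
    return attn + mlp

depth, dim, n_fine = 12, 384, 256       # e.g. a 32x32 latent at patch size 2
isotropic = depth * block_flops(n_fine, dim)
global_to_local = (depth // 2) * block_flops(n_fine // 4, dim) \
                + (depth // 2) * block_flops(n_fine, dim)
print(f"ratio: {global_to_local / isotropic:.2f}")  # ~0.62 with these numbers
```

Spending more depth in the coarse stage, or using a larger coarse patch, pushes this ratio further down toward the up-to-50% reduction the abstract reports; the exact saving depends on the chosen split and patch sizes.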