Accelerating Masked Image Generation by Learning Latent Controlled Dynamics

Kaiwen Zhu, Quansheng Zeng, Yuandong Pu, Shuo Cao, Xiaohui Li, Yi Xin, Qi Qin, Jiayang Li, Yu Qiao, Jinjin Gu, Yihao Liu

TL;DR

This work proposes MIGM-Shortcut, which learns a lightweight model that takes both previous features and sampled tokens as input and regresses the average velocity field of feature evolution, and applies it to two representative MIGM architectures and tasks.

Abstract

Masked Image Generation Models (MIGMs) have achieved great success, yet their efficiency is hampered by the many steps of bidirectional attention they require. In fact, their computation contains notable redundancy: when discrete tokens are sampled, the rich semantics contained in the continuous features are lost. Some existing works cache features to approximate those of future steps. However, they exhibit considerable approximation error under aggressive acceleration rates. We attribute this to their limited expressivity and their failure to account for sampling information. To fill this gap, we propose to learn a lightweight model that takes both previous features and sampled tokens as input and regresses the average velocity field of feature evolution. The model has moderate complexity, sufficient to capture the subtle dynamics while remaining lightweight compared to the base model. We apply our method, MIGM-Shortcut, to two representative MIGM architectures and tasks. In particular, on the state-of-the-art Lumina-DiMOO, it achieves over 4x acceleration of text-to-image generation while maintaining quality, significantly pushing forward the Pareto frontier of masked image generation. The code and model weights are available at https://github.com/Kaiwen-Zhu/MIGM-Shortcut.
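As a rough illustration of the idea, the following is a minimal sketch that assumes the regression target is the average velocity of the features over one step interval, $(\boldsymbol{f}_{t_i} - \boldsymbol{f}_{t_{i-1}})/(t_i - t_{i-1})$, so that a shortcut step updates features as $\boldsymbol{f}_{t_i} \approx \boldsymbol{f}_{t_{i-1}} + (t_i - t_{i-1})\,\hat{\boldsymbol{u}}$. This exact form is an assumption, as are all names here: `LinearShortcut` is a hypothetical linear least-squares stand-in for the learned lightweight network, and `generate`, `embed`, and `sample` are hypothetical helpers. The loop follows the full-step/shortcut-step schedule described for Figure 4.

```python
import numpy as np


def average_velocity(f_prev, f_next, t_prev, t_next):
    # ASSUMED regression target: mean rate of change of the features
    # over the interval [t_prev, t_next].
    return (f_next - f_prev) / (t_next - t_prev)


class LinearShortcut:
    """Hypothetical stand-in for the lightweight shortcut model.

    Maps (previous features, embeddings of the sampled tokens) to a predicted
    average velocity. A linear least-squares fit replaces the learned network
    to keep the sketch small; timestep conditioning is omitted.
    """

    def fit(self, f_prev, tok_emb, target):
        # Concatenate both inputs so the model sees features AND sampling info.
        inputs = np.concatenate([f_prev, tok_emb], axis=-1)
        self.W, *_ = np.linalg.lstsq(inputs, target, rcond=None)
        return self

    def __call__(self, f_prev, tok_emb):
        return np.concatenate([f_prev, tok_emb], axis=-1) @ self.W


def generate(base_model, shortcut, embed, sample, x, timesteps, full_steps):
    """Run the sampling loop, mixing full steps and shortcut steps.

    full_steps: 1-based step indices that run the base model. Step 1 always
    does, because no previous features exist yet. Returns the final tokens
    and the number of base-model calls.
    """
    f, base_calls = None, 0
    for i in range(1, len(timesteps)):
        t_prev, t_next = timesteps[i - 1], timesteps[i]
        if f is None or i in full_steps:
            # Full step: base model maps tokens x_{t_{i-1}} to features f_{t_i}.
            f = base_model(x)
            base_calls += 1
        else:
            # Shortcut step: (f_{t_{i-1}}, x_{t_{i-1}}) -> f_{t_i} via the
            # predicted average velocity; the base model is skipped.
            f = f + (t_next - t_prev) * shortcut(f, embed(x))
        # Sample tokens x_{t_i} from the features f_{t_i}.
        x = sample(f)
    return x, base_calls
```

With 9 timesteps (8 steps) and `full_steps={1, 5}`, the base model runs twice and the shortcut covers the remaining six steps. Concatenating features with token embeddings is one simple way to expose sampling information to the shortcut; a real model would likely also condition on $t_{i-1}$ and $t_i$.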

Paper Structure (28 sections, 6 equations, 22 figures, 3 tables)

Figures (22)

  • Figure 1: Visualization of trajectory smoothness. Each point in a trajectory is the feature averaged over all tokens at one step. Left: heatmap of pairwise cosine similarity; right: t-SNE visualization.
  • Figure 2: PCA visualization of feature trajectories generated with the same prompt and initial random seed. (a) Using a MIGM, we first generate a trajectory (the dark one) and then change the random seed at intermediate steps to generate more samples (the light ones). The randomness in sampling tokens greatly affects the generation process. (b) In contrast, for continuous diffusion with ODE sampling, trajectories generated from the same starting point are always the same, without randomness at intermediate steps.
  • Figure 3: Observation of local Lipschitz behavior. Consider the map of Eq. (eq:map). For neighboring points along trajectories, we compute the norms of the target and input differences. The ratio of the two norms concentrates around a moderate constant.
  • Figure 4: Inference workflow of MIGM-Shortcut. Colored blocks and solid lines represent activated computation, while gray blocks and dashed lines represent suppressed computation. At a full step, inference is the same as the vanilla procedure: the base model takes $\boldsymbol{x}_{t_{i-1}}$ as input to compute $\boldsymbol{f}_{t_i}$. At a shortcut step, the shortcut model takes $\boldsymbol{x}_{t_{i-1}}$ and $\boldsymbol{f}_{t_{i-1}}$ as input to compute $\boldsymbol{f}_{t_i}$, skipping the heavy base model entirely.
  • Figure 5: Quality-speed trade-off of MaskGIT with and without the shortcut. MaskGIT-Shortcut achieves lower FID and faster sampling. The number of steps $N$ and the budget $B$ are marked next to the points.
  • ...and 17 more figures
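The trajectory diagnostics described for Figures 1 and 3 can be approximated with a few array operations. The following is a minimal sketch, assuming features are stored per step as an array of shape (steps, tokens, dim), that "neighboring points" means consecutive steps, and that the inputs and targets of the map referenced in Figure 3 have already been assembled as per-step arrays; the function names are hypothetical.

```python
import numpy as np


def step_means(traj):
    # traj: (num_steps, num_tokens, dim). Each trajectory point is the
    # feature averaged over all tokens at one step -> (num_steps, dim).
    return traj.mean(axis=1)


def pairwise_cosine(points):
    # Cosine similarity between every pair of step points
    # -> (num_steps, num_steps), the kind of matrix shown as a heatmap.
    unit = points / np.linalg.norm(points, axis=1, keepdims=True)
    return unit @ unit.T


def lipschitz_ratios(inputs, targets):
    # inputs/targets: per-step arrays with a leading step axis (any trailing
    # shape). ASSUMPTION: "neighboring points" = consecutive steps.
    # Returns ||target difference|| / ||input difference|| for each pair.
    n = len(inputs)
    d_in = np.linalg.norm(np.diff(inputs, axis=0).reshape(n - 1, -1), axis=1)
    d_out = np.linalg.norm(np.diff(targets, axis=0).reshape(n - 1, -1), axis=1)
    return d_out / d_in
```

For the t-SNE panel, the per-step points could be passed to `sklearn.manifold.TSNE(n_components=2).fit_transform`, with `perplexity` set below the number of steps. A ratio distribution concentrated around one value is the local Lipschitz behavior that Figure 3 describes.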