Prompt Reinjection: Alleviating Prompt Forgetting in Multimodal Diffusion Transformers

Yuxuan Yao, Yuxuan Chen, Hui Li, Kaihui Cheng, Qipeng Guo, Yuwei Sun, Zilong Dong, Jingdong Wang, Siyu Zhu

TL;DR

Prompt Reinjection addresses prompt forgetting in Multimodal Diffusion Transformers by reinjecting aligned shallow text features into deeper layers during inference. This training-free intervention uses distribution anchoring and orthogonal Procrustes alignment to stabilize cross-layer semantic transfer, substantially improving instruction following and spatial-numerical reasoning across multiple MMDiT backbones without sacrificing image quality. Layer-wise analyses (CKNNA and probes) reveal robust mitigation of semantic drift, while ablations underscore the importance of origin layer choice, full-layer reinjection, and alignment components. The approach offers a practical, scalable boost to prompt adherence in complex text–image generation tasks with broad applicability to real-world prompts and diverse prompt styles.
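The alignment step named above can be illustrated with a small NumPy sketch. This is not the authors' implementation: the summary does not define distribution anchoring, so per-feature standardization is used here as an assumed stand-in, and the additive mixing rule with a `weight` parameter is likewise hypothetical. Only the orthogonal Procrustes solution itself (the SVD-based closed form) is standard.

```python
import numpy as np

def standardize(x, eps=1e-6):
    """Per-feature zero mean and unit variance.

    Assumed stand-in for "distribution anchoring"; the source does not
    spell out the exact procedure.
    """
    mu = x.mean(axis=0, keepdims=True)
    sigma = x.std(axis=0, keepdims=True)
    return (x - mu) / (sigma + eps), mu, sigma

def orthogonal_procrustes(src, tgt):
    """Orthogonal matrix R minimizing ||src @ R - tgt||_F.

    Closed form: with src.T @ tgt = U S V^T, the minimizer is R = U V^T.
    """
    u, _, vt = np.linalg.svd(src.T @ tgt)
    return u @ vt

def reinject(shallow, deep, weight=0.1):
    """Add aligned shallow text features to deep ones (hypothetical rule).

    shallow, deep: (num_tokens, dim) text-token features from an early
    and a later layer. `weight` is an illustrative mixing strength.
    """
    s_norm, _, _ = standardize(shallow)
    d_norm, d_mu, d_sigma = standardize(deep)
    rot = orthogonal_procrustes(s_norm, d_norm)
    # Rotate shallow features into the deep layer's frame, then map them
    # back to the deep layer's per-feature statistics.
    aligned = (s_norm @ rot) * d_sigma + d_mu
    return deep + weight * aligned
```

Because the rotation is orthogonal, it preserves norms and pairwise angles of the shallow features, which is one plausible reason to prefer it over an unconstrained linear map when transferring features across layers.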

Abstract

Multimodal Diffusion Transformers (MMDiTs) for text-to-image generation maintain separate text and image branches, with bidirectional information flow between text tokens and visual latents throughout denoising. In this setting, we observe a prompt forgetting phenomenon: the semantics of the prompt representation in the text branch is progressively forgotten as depth increases. We further verify this effect on three representative MMDiTs (SD3, SD3.5, and FLUX.1) by probing linguistic attributes of the representations over the layers in the text branch. Motivated by these findings, we introduce a training-free approach, prompt reinjection, which reinjects prompt representations from early layers into later layers to alleviate this forgetting. Experiments on GenEval, DPG, and T2I-CompBench++ show consistent gains in instruction-following capability, along with improvements on metrics capturing preference, aesthetics, and overall text–image generation quality.
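The layer-wise probing described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the paper's probe: it fits one closed-form ridge-regression classifier per layer on token features and reports held-out accuracy. The function names, the ridge strength, and the choice of a linear probe are all hypothetical.

```python
import numpy as np

def fit_linear_probe(feats, labels, num_classes, ridge=1e-3):
    """Ridge regression from features (plus a bias column) to one-hot labels."""
    x = np.hstack([feats, np.ones((feats.shape[0], 1))])
    y = np.eye(num_classes)[labels]
    gram = x.T @ x + ridge * np.eye(x.shape[1])
    return np.linalg.solve(gram, x.T @ y)

def probe_accuracy(weights, feats, labels):
    """Fraction of tokens whose highest-scoring class matches the label."""
    x = np.hstack([feats, np.ones((feats.shape[0], 1))])
    return float((np.argmax(x @ weights, axis=1) == labels).mean())

def layerwise_probe(train_feats, train_labels, test_feats, test_labels,
                    num_classes):
    """Train an independent probe per layer and return test accuracies.

    train_feats / test_feats: lists with one (num_tokens, dim) array per
    layer; labels are integer attribute classes shared across layers.
    """
    return [
        probe_accuracy(fit_linear_probe(tr, train_labels, num_classes),
                       te, test_labels)
        for tr, te in zip(train_feats, test_feats)
    ]
```

Under this setup, a drop in probe accuracy at deeper layers would indicate that the probed attribute is less linearly recoverable there, which is the kind of evidence the abstract refers to as prompt forgetting.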

Paper Structure (25 sections, 14 equations, 11 figures, 12 tables)

Figures (11)

  • Figure 1: Prompt forgetting in MMDiTs and Prompt Reinjection. (a) We quantify prompt forgetting by probing token-level attribute recoverability. Accuracy drops monotonically with depth in SD3, SD3.5, and FLUX, indicating progressive loss of fine-grained prompt information in deeper text features. With Prompt Reinjection enabled, probing accuracy remains stable across depth, showing effective mitigation of forgetting. (b) We propose Prompt Reinjection: reinjecting aligned shallow-layer text features into later blocks during inference. (c) Prompt Reinjection improves instruction following across multiple MMDiT variants, more consistently satisfying prompt constraints under diverse prompt styles.
  • Figure 2: Overall observation of per-layer text-token representations in SD3-medium and FLUX.1-Dev.
  • Figure 3: Probe accuracy reveals prompt forgetting in MMDiT text features. Each subplot reports per-category test accuracy when decoding token-level attributes from intermediate text representations at each layer for SD3-medium (left), SD3.5-large (middle), and FLUX.1-Dev (right).
  • Figure 4: Residual attribute injection results. During generation with prompt $A$, injecting shallow text features from prompt $B$ steers outputs toward the injected attribute, indicating that shallow residuals carry transferable semantics.
  • Figure 5: Qualitative comparison between each base model (SD3-medium, SD3.5-large, FLUX.1-Dev, and Qwen-Image) and its counterpart with Prompt Reinjection enabled. Bold text in the prompts highlights the constraints where our method improves text–image consistency over the base models.
  • ...and 6 more figures