Partition Generative Modeling: Masked Modeling Without Masks

Justin Deschenaux, Lan Tran, Caglar Gulcehre

TL;DR

Partition Generative Modeling (PGM) replaces the conventional masking in Masked Generative Models with a two-group partition and a GroupSwap mechanism, enabling fast, parallel sampling by restricting cross-group information flow. The Partition Transformer architecture supports partition-wise self-attention and cross-partition conditioning without full self-attention, allowing inference to focus on the unmasked (clean) tokens while still leveraging full-token supervision during training. Empirically, PGMs achieve 5–5.5x throughput gains on language tasks and up to 7.5x gains on ImageNet compared with strong MGM baselines, with only modest drops in sample quality and notable improvements when distillation is used. The approach remains compatible with distillation techniques (SDTT) and CFG, offering a scalable, flexible alternative to MGMs for high-speed generation and potential multimodal extensions.
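The partition-wise self-attention described above can be pictured as a block-sparse attention mask. A position may attend only to positions in its own group, so neither group can see the other. The NumPy sketch below illustrates that property under simplifying assumptions: a single head, no learned projections, and a hand-picked partition. The helper names `groupwise_attention_mask` and `masked_attention` are hypothetical and are not taken from the paper's code.

```python
import numpy as np


def groupwise_attention_mask(group: np.ndarray) -> np.ndarray:
    """Boolean (L, L) mask: True where query i may attend to key j (same group only)."""
    group = np.asarray(group)
    return group[:, None] == group[None, :]


def masked_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Single-head scaled dot-product attention with a boolean mask (toy, no learned weights)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)               # block cross-group pairs
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


# Hypothetical example: 6 positions split into group 0 and group 1.
group = np.array([0, 1, 1, 0, 1, 0])
mask = groupwise_attention_mask(group)

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 8))
out = masked_attention(x, x, x, mask)

# Perturb only the group-1 tokens: group-0 outputs must not change.
x_perturbed = x.copy()
x_perturbed[group == 1] += 5.0
out_perturbed = masked_attention(x_perturbed, x_perturbed, x_perturbed, mask)
print(np.allclose(out[group == 0], out_perturbed[group == 0]))  # True
```

Because every position can at least attend to itself, no row of the mask is empty, so the softmax stays well defined.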

Abstract

Masked generative models (MGMs) are widely used to capture complex data and enable faster generation than autoregressive (AR) models through parallel decoding. However, MGMs typically operate on fixed-length inputs, which can be inefficient: early in sampling, most tokens are masked and carry no information, leading to wasted computation. In contrast, AR models process only previously generated tokens, making early iterations faster. In this work, we introduce the Partition Generative Model (PGM), a novel approach that combines the strengths of AR models and MGMs. Rather than masking, PGM partitions tokens into two groups and employs sparse attention to block information flow between them. Because no information flows between partitions, the model needs to process only the previously generated tokens during sampling, while retaining the ability to generate tokens in parallel and in any order. On OpenWebText, PGMs offer at least $5\times$ improvements in sampling latency and throughput, while producing samples with better Generative Perplexity than Masked Diffusion Language Models. On ImageNet, PGMs achieve $7.5\times$ higher throughput than MaskGIT, with only a slight increase in FID (5.54 vs. 5.35). With twice as many sampling steps, the FID drops to 4.56 while remaining $3.9\times$ faster than MaskGIT. Finally, PGMs integrate seamlessly with MGM distillation, providing further inference speedups.
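To see why processing only clean tokens saves work, consider a crude count of the token positions a model reads and the logits it emits over one full sampling run. This sketch assumes a hypothetical uniform schedule that reveals `seq_len // steps` tokens per step, and it ignores the quadratic cost of attention, model width, and any special tokens. It is bookkeeping intuition only, not how the reported speedups were measured. The sequence length of 1024 matches the OpenWebText setting in Figure 1; the 8 steps are arbitrary, and `positions_processed` is an illustrative name.

```python
def positions_processed(seq_len: int, steps: int) -> dict[str, int]:
    """Crude count of token positions read and logits emitted over a full sampling run.

    Hypothetical uniform schedule: exactly seq_len // steps tokens are revealed per step.
    """
    if seq_len % steps != 0:
        raise ValueError("seq_len must be divisible by steps in this sketch")
    per_step = seq_len // steps
    # MGM: every step reads the full-length sequence and emits logits at every position.
    mgm_inputs = seq_len * steps
    mgm_logits = seq_len * steps
    # PGM: step t reads only the t * per_step clean tokens generated so far
    # and emits logits only for the per_step positions it is about to fill.
    pgm_inputs = sum(t * per_step for t in range(steps))
    pgm_logits = per_step * steps
    return {"mgm_inputs": mgm_inputs, "mgm_logits": mgm_logits,
            "pgm_inputs": pgm_inputs, "pgm_logits": pgm_logits}


counts = positions_processed(seq_len=1024, steps=8)
print(counts)  # {'mgm_inputs': 8192, 'mgm_logits': 8192, 'pgm_inputs': 3584, 'pgm_logits': 1024}
```

Under these assumptions, PGM reads fewer than half as many input positions and computes one eighth as many logits. Actual gains depend on the architecture, the sampling schedule, and the hardware.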

Paper Structure

This paper contains 59 sections, 12 equations, 5 figures, 10 tables, 4 algorithms.

Figures (5)

  • Figure 1: (Left): On ImageNet, using the Halton sampler, PGM (ours) reaches an FID similar to MaskGIT's with a $7.5\times$ speedup. By sampling with twice as many steps, PGM reaches an FID of $4.56$ while remaining $3.9\times$ faster. (Right): On OpenWebText, PGM achieves better Generative Perplexity than MDLM with $5.3\times$ higher sampling throughput, at a context length of 1024. These improvements come from our proposed neural network architecture.
  • Figure 2: Masked Generative Modeling (MGM) vs. Partition Generative Modeling (PGM). Training: PGMs receive feedback at every position, while MGMs usually apply the loss only to masked tokens. Inference: PGMs process only unmasked tokens, working with shorter sequences and predicting logits only for the tokens to denoise. MGMs must process full-length sequences and compute logits at all positions. Important note: PGMs use a specialized architecture that ensures predictions for position $i$ never depend on the token at position $i$.
  • Figure 3: PGM-compatible transformer architecture. RoPE (su2023roformerenhancedtransformerrotary) is applied before every attention layer but is not shown for clarity. (A) Decoder layer with cross-attention to the encoder output and no self-attention between tokens. (B) GroupSwap layer that exchanges information between positions in group 0 and group 1, enabling each group to make predictions based on tokens from the other group. (C) Encoder layer with sparse, group-wise self-attention.
  • Figure 4: After distillation, PGM (6 / 6, dim. 1024) with nucleus sampling remains significantly faster than MDLM, at matching entropy and Gen. PPL.
  • Figure 5: Training loss of MDLM, MDLM with Complementary Masking (see the experiment isolating complementary masking), and PGM. Complementary masking appears to introduce loss spikes, although it did not cause the models to diverge.
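Figure 2 notes that a PGM's prediction for position $i$ never depends on the token at position $i$, and Figure 3(B) attributes the exchange between groups to the GroupSwap layer. One toy way to obtain that property is cross-group attention whose queries come only from positional embeddings, so each position reads content exclusively from the other group. The sketch below is a hypothetical single-head stand-in with no learned weights (`toy_group_swap` is an illustrative name), not the paper's layer.

```python
import numpy as np


def toy_group_swap(token_emb: np.ndarray, pos_emb: np.ndarray, group: np.ndarray) -> np.ndarray:
    """Hypothetical toy, not the paper's GroupSwap: row i mixes only other-group tokens.

    Queries come from positional embeddings alone, so the output at position i
    never reads the content of token i.
    """
    group = np.asarray(group)
    if group.min() == group.max():
        raise ValueError("both groups must be non-empty")
    d = token_emb.shape[1]
    keys = token_emb + pos_emb                      # keys carry content and position
    scores = pos_emb @ keys.T / np.sqrt(d)          # queries: position only
    cross_group = group[:, None] != group[None, :]  # attend to the other group only
    scores = np.where(cross_group, scores, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ token_emb


rng = np.random.default_rng(0)
L, d = 8, 16
group = np.array([0, 1, 0, 1, 1, 0, 0, 1])
tok = rng.normal(size=(L, d))
pos = rng.normal(size=(L, d))
out = toy_group_swap(tok, pos, group)

# Change only token 3 (group 1) and see which outputs move.
tok_changed = tok.copy()
tok_changed[3] += 10.0
out_changed = toy_group_swap(tok_changed, pos, group)
same = np.all(np.isclose(out, out_changed), axis=1)
print(same[group == 1].all())  # True: group-1 outputs, including position 3, ignore token 3
```

Changing token 3 leaves every group-1 output untouched, including the output at position 3 itself, while group-0 outputs, which read group-1 content, can change.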