
[MASK] is All You Need

Vincent Tao Hu, Björn Ommer

TL;DR

This work introduces Discrete Interpolants, a unified discrete-state framework that bridges Masked Generative Models (MGMs) and diffusion-style models for vision. By formulating a stochastic interpolant over discrete tokens and allowing both explicit and implicit timesteps, the approach needs only a single training run to model the joint distribution and supports versatile sampling, including MGM-style and conditional sampling. The authors demonstrate state-of-the-art or competitive results on MS COCO, ImageNet 256, and video data (FaceForensics), and show that segmentation can be reframed as unmasking within the same framework. They also provide extensive ablations and visualizations that validate the design choices around masking, weighting, scheduler, and guidance, with strong evidence of scalability to multi-modal and multi-task settings. Overall, Discrete Interpolants offer a principled way to unify generative and discriminative tasks on discrete vision representations, with practical gains in fidelity and flexibility.
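
The training side of the discrete interpolant (as described in the Figure 1 caption) can be sketched in a few lines of Python. This is a minimal sketch under stated assumptions, not the authors' code: x_0 is taken to be all [MASK] and x_1 the data tokens, each token of x_t independently keeps its data value with probability κ_t (linear κ_t = t), the cross-entropy is averaged over masked positions only, and `MASK`, `kappa_linear`, `sample_interpolant`, and `masked_cross_entropy` are hypothetical names.

```python
import math
import random

MASK = -1  # hypothetical id for the [MASK] token (x_0 is all [MASK])


def kappa_linear(t: float) -> float:
    """Linear scheduler: kappa_0 = 0 (fully masked), kappa_1 = 1 (clean data)."""
    return t


def sample_interpolant(x1, t, kappa=kappa_linear, rng=random):
    """Draw x_t: each token keeps its data value with probability kappa(t), else [MASK]."""
    k = kappa(t)
    return [tok if rng.random() < k else MASK for tok in x1]


def masked_cross_entropy(logits, x1, xt):
    """Mean cross-entropy of recovering x1, averaged over positions masked in x_t.

    `logits` holds one list of unnormalized scores per position, standing in for
    the model output p_{1|t}(x_t, t; theta). The loss itself never reads t.
    """
    losses = []
    for scores, target, cur in zip(logits, x1, xt):
        if cur != MASK:
            continue  # assumption: visible tokens do not contribute
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))  # log-sum-exp
        losses.append(log_z - scores[target])
    return sum(losses) / len(losses) if losses else 0.0


# Usage with a stand-in model output: flat logits over a vocabulary of 8 tokens.
x1 = [3, 1, 4, 1, 5]
xt = sample_interpolant(x1, 0.4, rng=random.Random(0))
flat_logits = [[0.0] * 8 for _ in x1]
loss = masked_cross_entropy(flat_logits, x1, xt)
```

Because the loss does not take t as an input, the same objective applies whether the network is given the timestep (explicit) or not (implicit); only the model's inputs differ.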

Abstract

In generative models, two paradigms have gained traction in various applications: next-set prediction-based Masked Generative Models and next-noise prediction-based Non-Autoregressive Models, e.g., Diffusion Models. In this work, we propose using discrete-state models to connect them and explore their scalability in the vision domain. First, we conduct a step-by-step analysis in a unified design space across the two types of models, covering timestep-independence, noise schedule, temperature, guidance strength, etc., in a scalable manner. Second, we recast typical discriminative tasks, e.g., image segmentation, as an unmasking process from [MASK] tokens on a discrete-state model. This enables us to perform various sampling processes, including flexible conditional sampling, while training only once to model the joint distribution. All of the aforementioned explorations lead to our framework, named Discrete Interpolants, which achieves state-of-the-art or competitive performance compared to previous discrete-state methods on various benchmarks, such as ImageNet 256, MS COCO, and the video dataset FaceForensics. In summary, by leveraging [MASK] in discrete-state models, we can bridge Masked Generative and Non-Autoregressive Diffusion Models, as well as generative and discriminative tasks.
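
Recasting segmentation as unmasking can be illustrated with a small sketch of conditional sampling from one joint model: image tokens are clamped as observed, while segmentation tokens start as [MASK] and are progressively unmasked. This is a hypothetical illustration, not the authors' implementation. `toy_model` is a stand-in for the trained network, and the sequence layout (image tokens followed by one label token per image token), the linear reveal schedule κ_t = t, the random choice of which positions to reveal, and argmax decoding are all assumptions.

```python
import random

MASK = -1   # hypothetical [MASK] token id
VOCAB = 8   # hypothetical shared vocabulary size


def toy_model(seq):
    """Hypothetical stand-in for the trained joint model.

    Returns one list of logits per position. At segmentation position n + j it
    puts its highest score on label (image_token[j] % 2); other positions are flat.
    """
    n = len(seq) // 2
    logits = []
    for p in range(len(seq)):
        scores = [0.0] * VOCAB
        if p >= n:
            scores[seq[p - n] % 2] = 1.0
        logits.append(scores)
    return logits


def conditional_unmask(image_tokens, model, steps=4, rng=None):
    """Predict segmentation tokens by unmasking while the image tokens stay fixed."""
    rng = rng or random.Random(0)
    n = len(image_tokens)
    seq = list(image_tokens) + [MASK] * n        # observed condition + masked target
    for i in range(1, steps + 1):
        target = (n * i) // steps                # linear kappa_t = t, with t = i / steps
        masked = [p for p in range(n, 2 * n) if seq[p] == MASK]
        n_new = target - (n - len(masked))       # tokens to reveal at this step
        logits = model(seq)                      # one forward pass on the joint sequence
        for p in rng.sample(masked, n_new):
            scores = logits[p]
            seq[p] = max(range(len(scores)), key=scores.__getitem__)  # argmax
    return seq[n:]                               # only the segmentation half


labels = conditional_unmask([5, 2, 7, 4, 0, 3], toy_model)
```

Since the image half is never overwritten, the same joint model could, under these assumptions, be conditioned in the other direction by clamping the label tokens and masking the image tokens instead.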

Paper Structure

This paper contains 53 sections, 10 equations, 19 figures, 5 tables, and 1 algorithm.

Figures (19)

  • Figure 1: Discrete Interpolants for training and sampling. During training, we first obtain discrete interpolants $x_t$ from $x_0$ and $x_1$ following a scheduler $\kappa_t$. We then train a model $\tilde{p}_{1|t}(\textbf{x}_t,\mathbbm{t};\theta)$ with a cross-entropy loss to predict the original data, where $\mathbbm{t}$ indicates that the timestep $t$ is optional; this yields both Explicit Timestep and Implicit Timestep Models. For sampling, we begin with a fully masked $x_0$ and progressively unmask it to reach the fully unmasked $x_1$. Finally, we decode the token indices back to pixel space.
  • Figure 2: Churning sampling with argmax can (1) alleviate the misalignment between the training and sampling schedulers and (2) boost sampling performance at low NFE. First, we visualize the progressive chain of changes when sampling with a scheduler $\kappa_t$ that differs from the linear scheduler used during training; sampling uses 50 steps and a CFG scale of 3. Second, we show that applying the argmax operation to the logits significantly reduces the number of [MASK] tokens remaining after sampling.
  • Figure 3: Ablation of the Explicit Timestep Model (ETM), the Implicit Timestep Model (ITM), and Masked Generative Model (MGM)-style sampling on ImageNet 256, measured by FID-5k. All models are trained with the linear scheduler by default.
  • Figure 4: Sampling-chain visualization on ImageNet 256 with 100 timesteps and argmax applied.
  • Figure 5: Non-cherry-picked samples on MS COCO (CFG = 4.5, FID-50k = 5.8).
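
The figure captions refer to several sampling-time knobs: a sampling scheduler κ_t that may differ from the linear training scheduler (Figure 2), a CFG scale (3 in Figure 2, 4.5 in Figure 5), and argmax decoding of the logits (Figures 2 and 4). The sketch below illustrates these knobs in isolation; it is not the authors' churning sampler. The cosine scheduler, the CFG formula on logits, the temperature parameter, and all function names are assumptions chosen for illustration.

```python
import math
import random


def kappa_linear(t):
    """Linear scheduler (the training default stated for Figure 3)."""
    return t


def kappa_cosine(t):
    """Hypothetical alternative sampling scheduler with kappa_0 = 0, kappa_1 = 1."""
    return 1.0 - math.cos(0.5 * math.pi * t)


def reveal_counts(n_tokens, steps, kappa):
    """Tokens newly unmasked per step, so a fraction kappa(i / steps) is unmasked after step i."""
    done = [math.floor(n_tokens * kappa(i / steps)) for i in range(steps + 1)]
    done[-1] = n_tokens  # guard against float round-off at t = 1
    return [b - a for a, b in zip(done, done[1:])]


def cfg_logits(cond, uncond, scale):
    """Classifier-free guidance on logits, assumed form: uncond + scale * (cond - uncond)."""
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


def pick_token(logits, use_argmax, rng, temperature=1.0):
    """Pick one token id: argmax, or a draw from softmax(logits / temperature)."""
    if use_argmax:
        return max(range(len(logits)), key=logits.__getitem__)
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]  # unnormalized softmax weights
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


# Usage: compare schedules for 256 tokens in 8 steps, then decode one guided position.
linear_plan = reveal_counts(256, 8, kappa_linear)
cosine_plan = reveal_counts(256, 8, kappa_cosine)
guided = cfg_logits([2.0, 0.5, -1.0], [1.0, 1.0, 0.0], scale=3.0)
token = pick_token(guided, use_argmax=True, rng=random.Random(0))
```

With 256 tokens and 8 steps, the linear plan reveals 32 tokens per step, whereas this cosine plan reveals few tokens early and many late, which is one concrete way a sampling scheduler can diverge from the one used in training.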