Masked Diffusion as Self-supervised Representation Learner

Zixuan Pan, Jianxu Chen, Yiyu Shi

TL;DR

This work introduces the Masked Diffusion Model (MDM), a self-supervised pre-training paradigm for semantic segmentation. MDM replaces the additive Gaussian noise of diffusion models with a masking corruption and optimizes an SSIM-based reconstruction loss, which aligns pre-training more closely with the downstream task. The pre-trained MDM is then frozen as a representation generator, and only a lightweight segmentation head is trained on its features. This setup achieves state-of-the-art results on medical and natural image segmentation benchmarks, particularly in few-shot settings. The key findings are that diffusion denoising is not strictly necessary for learning high-quality semantic representations, that masking-based pre-training can outperform MAE and DDPM, and that SSIM is a crucial loss for bridging reconstruction and segmentation. The approach is promising for label-efficient dense prediction and could be extended to other architectures and data domains.
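The TL;DR credits an SSIM loss with bridging reconstruction and segmentation. As a minimal sketch, assuming single-channel images scaled to [0, 1] and a uniform 7×7 window (the paper's exact SSIM settings, such as Gaussian weighting or per-channel handling, are not given here), a loss of the form 1 - SSIM can be computed in NumPy as follows. K1 = 0.01 and K2 = 0.03 are the usual SSIM constants; `sliding_window_view` requires NumPy 1.20 or later. The function names are hypothetical, and real training would compute the same quantity inside an autodiff framework.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def ssim(x: np.ndarray, y: np.ndarray, win: int = 7, data_range: float = 1.0) -> float:
    """Mean SSIM of two 2-D grayscale images over all win x win windows (uniform weights)."""
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    # Local patches: shape (H - win + 1, W - win + 1, win, win)
    px = sliding_window_view(x, (win, win))
    py = sliding_window_view(y, (win, win))
    mu_x = px.mean(axis=(-2, -1))
    mu_y = py.mean(axis=(-2, -1))
    dx = px - mu_x[..., None, None]
    dy = py - mu_y[..., None, None]
    var_x = (dx * dx).mean(axis=(-2, -1))
    var_y = (dy * dy).mean(axis=(-2, -1))
    cov_xy = (dx * dy).mean(axis=(-2, -1))
    # Per-window SSIM: luminance term times contrast-structure term
    ssim_map = ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x**2 + mu_y**2 + c1) * (var_x + var_y + c2)
    )
    return float(ssim_map.mean())


def ssim_loss(recon: np.ndarray, target: np.ndarray, **kwargs) -> float:
    """Reconstruction loss that is 0 for a perfect reconstruction."""
    return 1.0 - ssim(recon, target, **kwargs)
```

Unlike a pixel-wise MSE, this loss compares local means, variances, and covariances, so it rewards reconstructions that preserve local structure rather than only per-pixel intensity.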

Abstract

Denoising diffusion probabilistic models have recently demonstrated state-of-the-art generative performance and have been used as strong pixel-level representation learners. This paper disentangles the generative capability of diffusion models from their representation learning ability. We present the masked diffusion model (MDM), a scalable self-supervised representation learner for semantic segmentation that replaces the additive Gaussian noise of conventional diffusion with a masking mechanism. Our approach surpasses prior methods on both medical and natural image semantic segmentation tasks, particularly in few-shot scenarios.
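The central change described in the abstract is the corruption process. A DDPM corrupts an image with Gaussian noise whose strength grows with the timestep t, whereas MDM masks a portion of the image according to t (Figure 1). The NumPy sketch below contrasts the two. The Gaussian branch is the standard DDPM forward process with the linear beta schedule of Ho et al. (T = 1000, betas from 1e-4 to 0.02). The masking branch is hypothetical: the 8×8 patch size, the masked fraction t/T, and zero as the fill value are illustrative assumptions, not the paper's documented settings.

```python
import numpy as np


def gaussian_corrupt(x0, t, T=1000, rng=None):
    """Standard DDPM forward process q(x_t | x_0) with a linear beta schedule; t in 1..T."""
    assert 1 <= t <= T
    rng = np.random.default_rng() if rng is None else rng
    betas = np.linspace(1e-4, 0.02, T)
    alpha_bar = np.cumprod(1.0 - betas)[t - 1]
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * eps


def mask_corrupt(x0, t, T=1000, patch=8, rng=None):
    """Hypothetical masking corruption: zero out round(n * t / T) of the n patches; t in 0..T."""
    assert x0.ndim == 2 and 0 <= t <= T
    h, w = x0.shape
    assert h % patch == 0 and w % patch == 0
    rng = np.random.default_rng() if rng is None else rng
    gh, gw = h // patch, w // patch
    n = gh * gw
    k = int(round(n * t / T))  # the number of hidden patches grows with the timestep
    visible = np.ones(n, dtype=bool)
    visible[rng.choice(n, size=k, replace=False)] = False
    # Upsample the patch-level mask to pixel resolution
    mask = visible.reshape(gh, gw).repeat(patch, axis=0).repeat(patch, axis=1)
    return x0 * mask, mask
```

In both cases t controls how much information is destroyed; the difference is that masking removes regions entirely instead of perturbing every pixel, so the model must infer the missing content from the visible context.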

Paper Structure (20 sections, 8 equations, 9 figures, 7 tables)

Figures (9)

  • Figure 1: Dynamic masking process. Portions of data are probabilistically masked according to the timestep $t$ and subsequently reconstructed via a time-aware U-Net.
  • Figure 2: Overview of our proposed method. During pre-training, only the masked diffusion model (Encoder and Decoder) in Step 1 is trained. We partially mask the clean image based on a randomly sampled timestep $t$. The masked diffusion model then takes the masked image and reconstructs it. For the downstream segmentation task, the pre-trained model in Step 1 is frozen as a representation generator, and the segmentation network in Step 2 is trained with the representations extracted from Step 1.
  • Figure 3: Qualitative visualization on the GlaS test sets with full training labels (first 2 rows) and with 10% of the training labels (last 2 rows), showing the Dice score (%) of each prediction.
  • Figure 4: Qualitative visualization on the MoNuSeg test set with full training labels (first 2 rows) and with 10% of the training labels (last 2 rows), showing the Dice score (%) of each prediction.
  • Figure 5: Qualitative visualization on FFHQ-34 (first 2 rows) and CelebA-19 (last 2 rows).
  • Figures 6 to 9: captions not listed.
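Figure 2 describes Step 2, in which the pre-trained model is frozen and only the segmentation network is trained on its representations, and Figures 3 and 4 report Dice scores. The sketch below illustrates both ideas under simplifying assumptions: the "frozen representations" are a synthetic (C, H, W) array standing in for MDM features, and the head is a per-pixel linear softmax classifier trained by plain gradient descent. Neither choice is taken from the paper, and the helper names (`train_pixel_head`, `predict`, `dice_score`) are hypothetical.

```python
import numpy as np


def dice_score(pred, target, eps=1e-8):
    """Dice = 2|P & G| / (|P| + |G|) for binary masks; eps makes two empty masks score 1."""
    pred, target = pred.astype(bool), target.astype(bool)
    inter = np.logical_and(pred, target).sum()
    return float((2.0 * inter + eps) / (pred.sum() + target.sum() + eps))


def train_pixel_head(feats, labels, n_classes, lr=0.5, steps=200):
    """Fit a per-pixel softmax classifier on frozen features.

    feats: (C, H, W) representations, read but never updated.
    labels: (H, W) integer class ids.
    """
    c = feats.shape[0]
    x = feats.reshape(c, -1).T  # (N, C): one row per pixel
    y_onehot = np.eye(n_classes)[labels.reshape(-1)]
    weight = np.zeros((c, n_classes))
    bias = np.zeros(n_classes)
    for _ in range(steps):
        logits = x @ weight + bias
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        prob = np.exp(logits)
        prob /= prob.sum(axis=1, keepdims=True)
        grad = (prob - y_onehot) / len(x)  # gradient of mean cross-entropy w.r.t. logits
        weight -= lr * (x.T @ grad)
        bias -= lr * grad.sum(axis=0)
    return weight, bias


def predict(feats, weight, bias):
    """Per-pixel argmax over class scores, reshaped back to (H, W)."""
    c, h, w = feats.shape
    return (feats.reshape(c, -1).T @ weight + bias).argmax(axis=1).reshape(h, w)


# Toy example: a disk-shaped ground truth and synthetic stand-in "frozen" features
rng = np.random.default_rng(0)
yy, xx = np.mgrid[:32, :32]
gt = ((yy - 16) ** 2 + (xx - 16) ** 2 < 100).astype(int)
feats = np.stack([
    gt + 0.1 * rng.standard_normal(gt.shape),  # informative channel
    rng.standard_normal(gt.shape),             # uninformative channel
])
weight, bias = train_pixel_head(feats, gt, n_classes=2)
pred = predict(feats, weight, bias)
print(f"Dice: {dice_score(pred == 1, gt == 1):.3f}")
```

Only the head's weight matrix and bias are learned; the feature array is never modified, mirroring the frozen Step 1 model, so the quality of the segmentation depends mostly on how informative the frozen features are.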