
A Pytorch Reproduction of Masked Generative Image Transformer

Victor Besnier, Mickael Chen

TL;DR

Problem: reproducing masked generative image models such as MaskGIT has been hindered by limited public code and weights. Approach: this work provides a PyTorch reproduction that uses VQGAN tokenization and a bidirectional transformer that unmasks tokens, trained on ImageNet at 256x256 and 512x512. Findings: the replication achieves FIDs close to the original (≈7.3 at 512x512, or 7.26 with minor tweaks; 6.8 at 256x256) and confirms ~64x faster sampling than autoregressive methods; ablations show benefits from Gumbel noise and arccos scheduling. Significance: releasing code and pretrained weights accelerates research on, and the reproducibility of, Masked Generative Models.
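The TL;DR credits part of the gain to arccos scheduling. The sketch below is one way a masking schedule can be turned into per-step token counts, assuming the schedule gives the fraction of tokens still masked at decoding progress r. The exact curve shapes, in particular the arccos form (2/π)·arccos(r), and the helper names `masked_fraction` and `tokens_per_step` are assumptions for illustration, not the report's implementation.

```python
import math


def masked_fraction(r: float, mode: str = "arccos") -> float:
    """Fraction of tokens still masked at decoding progress r in [0, 1].

    Assumed schedule shapes (illustrative, not taken from the report):
      linear: 1 - r
      cosine: cos(pi * r / 2)   (the schedule used in the original MaskGIT paper)
      arccos: (2 / pi) * arccos(r)
    Each equals 1 at r = 0 (all masked) and 0 at r = 1 (all revealed).
    """
    if mode == "linear":
        return 1.0 - r
    if mode == "cosine":
        return math.cos(math.pi * r / 2)
    if mode == "arccos":
        return (2 / math.pi) * math.acos(r)
    raise ValueError(f"unknown schedule: {mode}")


def tokens_per_step(num_tokens: int, num_steps: int, mode: str = "arccos") -> list[int]:
    """Hypothetical helper: how many tokens to unmask at each decoding step."""
    # Tokens still masked after step t (t = 0 means before any step).
    remaining = [round(num_tokens * masked_fraction(t / num_steps, mode))
                 for t in range(num_steps + 1)]
    remaining[0] = num_tokens  # guard against floating-point rounding at the ends
    remaining[-1] = 0
    # Differences between consecutive counts give the tokens revealed per step.
    return [remaining[t] - remaining[t + 1] for t in range(num_steps)]


print(tokens_per_step(1024, 16, "arccos"))
print(tokens_per_step(1024, 16, "cosine"))
```

Because the counts are differences of a monotone sequence, they always sum to `num_tokens`. Under these assumed shapes, both schedules reveal few tokens early and many late, and the arccos curve commits more tokens in the first step than the cosine one.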

Abstract

In this technical report, we present a reproduction of MaskGIT: Masked Generative Image Transformer, using PyTorch. The approach leverages a masked bidirectional transformer architecture, enabling generation of 512 x 512 images in only a few steps (8 to 16), i.e., ~64x faster than an auto-regressive approach. Through rigorous experimentation and optimization, we achieved results that closely align with the findings presented in the original paper. We match the reported FID of 7.32 with our replication and obtain 7.59 with similar hyperparameters on ImageNet at resolution 512 x 512. Moreover, we improve over the official implementation with some minor hyperparameter tweaking, achieving an FID of 7.26. At the lower resolution of 256 x 256 pixels, our reimplementation scores 6.80, compared with the original paper's 6.18. To promote further research on Masked Generative Models and facilitate their reproducibility, we openly release our code and pre-trained weights at https://github.com/valeoai/MaskGIT-pytorch/
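The abstract attributes the speed-up to decoding all tokens in parallel over a few steps. A minimal sketch of such a loop follows, in the spirit of MaskGIT: every masked position is sampled each step, the samples are ranked by a Gumbel-perturbed confidence, and only the top `k` are committed. The toy model, the `MASK` sentinel, the linear annealing of the noise, and all function names are assumptions made for illustration; this is not the report's code.

```python
import math
import random

MASK = -1  # hypothetical sentinel id for a masked token


def gumbel(rng: random.Random) -> float:
    """Standard Gumbel sample: -log(-log(U)) with U ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() can return 0.0, which would make log() fail
        u = rng.random()
    return -math.log(-math.log(u))


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode(predict_logits, num_tokens: int, counts: list[int],
           gumbel_temp: float = 1.0, seed: int = 0) -> list[int]:
    """Iterative masked decoding (illustrative sketch).

    predict_logits(tokens) returns one list of logits per position; it is a
    hypothetical stand-in for the bidirectional transformer.
    counts[i] is the number of tokens committed at step i.
    """
    if sum(counts) != num_tokens:
        raise ValueError("counts must sum to num_tokens")
    rng = random.Random(seed)
    tokens = [MASK] * num_tokens
    for step, k in enumerate(counts):
        logits = predict_logits(tokens)  # one forward pass per step
        # Assumption: the confidence noise decays linearly to 0 at the last step.
        anneal = 1.0 - (step + 1) / len(counts)
        candidates = []
        for pos, tok in enumerate(tokens):
            if tok != MASK:
                continue  # committed tokens are never re-masked
            probs = softmax(logits[pos])
            sampled = rng.choices(range(len(probs)), weights=probs)[0]
            conf = math.log(probs[sampled]) + gumbel_temp * anneal * gumbel(rng)
            candidates.append((conf, pos, sampled))
        # Keep the k most confident samples; every other position stays masked.
        candidates.sort(reverse=True)
        for _, pos, sampled in candidates[:k]:
            tokens[pos] = sampled
    return tokens


def toy_model(tokens: list[int], vocab_size: int = 8) -> list[list[float]]:
    """Hypothetical context-free 'model': fixed logits per position."""
    return [[math.sin(pos + v) for v in range(vocab_size)]
            for pos in range(len(tokens))]


image_tokens = decode(toy_model, num_tokens=16, counts=[1, 3, 5, 7], gumbel_temp=4.5)
print(image_tokens)
```

Each step costs one forward pass, so sampling cost scales with the number of steps rather than the number of tokens. As a rough check of the ~64x figure, assuming a 512 x 512 image is encoded as a 32 x 32 grid of tokens, an autoregressive decoder would need 1024 sequential passes versus 16 here, and 1024 / 16 = 64.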

Paper Structure

This paper contains 6 sections, 5 figures, 4 tables.

Figures (5)

  • Figure 1: Examples generated on ImageNet at $512\times 512$ demonstrate the effectiveness of our reproduction. Hyperparameters for these examples: Gumbel temperature 7, softmax temperature 1.3, CFG weight 9, arccos scheduler, and 32 scheduler steps.
  • Figure 2: Hyperparameter search: ablation of the crucial parameters that control sampling quality and diversity (number of training epochs, Gumbel noise, number of steps, and CFG weight).
  • Figure 3: Diversity comparison between the official paper (top row) and our reproduction (bottom row). Without cherry-picking on our side, our method exhibits slightly less diversity but higher quality.
  • Figure 4: Intermediate images generated at $256\times 256$ resolution. The first row shows the binary mask, the second row the dog generated under that mask, and the last row another example, a sailboat.
  • Figure 5: Image inpainting: a rooster (ImageNet class 007) and a zebra (ImageNet class 340), generated and inpainted into a Cityscapes image.
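Figure 1 lists the sampling hyperparameters behind its examples. The sketch below shows how two of them, the CFG weight and the softmax temperature, could reshape the token distribution at a single position. The guidance form (1 + w) · cond − w · uncond is a common classifier-free guidance formulation assumed here, and `SamplingConfig` and `guided_probs` are hypothetical names.

```python
import math
from dataclasses import dataclass


@dataclass
class SamplingConfig:
    """Figure 1 hyperparameters; the field names are hypothetical."""
    gumbel_temperature: float = 7.0   # scales the confidence noise (used by the decoding loop, not here)
    softmax_temperature: float = 1.3  # values > 1 flatten the token distribution
    cfg_weight: float = 9.0           # classifier-free guidance strength
    scheduler: str = "arccos"         # masking schedule (used by the decoding loop, not here)
    steps: int = 32                   # number of decoding steps (not used here)


def guided_probs(cond_logits: list[float], uncond_logits: list[float],
                 cfg: SamplingConfig) -> list[float]:
    """Token distribution for one position after CFG and temperature.

    Assumption: guidance is applied to logits as (1 + w) * cond - w * uncond.
    """
    w = cfg.cfg_weight
    guided = [(1 + w) * c - w * u for c, u in zip(cond_logits, uncond_logits)]
    scaled = [g / cfg.softmax_temperature for g in guided]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


probs = guided_probs([2.0, 1.0, 0.0], [0.5, 0.5, 0.5], SamplingConfig())
print(probs)
```

With `cfg_weight=0` the function reduces to a temperature-scaled softmax of the conditional logits, and when the conditional and unconditional logits coincide, guidance leaves them unchanged; a larger weight pushes the distribution away from what the unconditional prediction favors.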