Representation Alignment for Just Image Transformers is not Easier than You Think

Jaeyo Shin, Jiwook Kim, Hyunjung Shim

Abstract

Representation Alignment (REPA) has emerged as a simple way to accelerate Diffusion Transformer training in latent space. At the same time, pixel-space diffusion transformers such as Just image Transformers (JiT) have attracted growing attention because they remove the dependency on a pretrained tokenizer and thus avoid the reconstruction bottleneck of latent diffusion. This paper shows that REPA can fail for JiT: as training proceeds, REPA yields worse FID for JiT, and it collapses diversity on ImageNet image subsets that are tightly clustered in the representation space of a pretrained semantic encoder. We trace the failure to an information asymmetry: denoising occurs in the high-dimensional image space, while the semantic target is strongly compressed, which makes direct regression onto the target a shortcut objective. We propose PixelREPA, which transforms the alignment target and constrains alignment with a Masked Transformer Adapter that combines a shallow transformer adapter with partial token masking. PixelREPA improves both training convergence and final quality: on ImageNet $256 \times 256$, it reduces FID from 3.66 to 3.17 for JiT-B$/16$ and improves Inception Score (IS) from 275.1 to 284.6, while achieving $> 2\times$ faster convergence. Finally, PixelREPA-H$/16$ achieves FID$=1.81$ and IS$=317.2$. Our code is available at https://github.com/kaist-cvml/PixelREPA.
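The abstract contrasts direct regression onto encoder features (REPA) with PixelREPA's Masked Transformer Adapter. The NumPy sketch below shows only the forward computation of such an alignment loss, under assumptions that are not stated here: the REPA-style objective is taken to be the token-averaged negative cosine similarity between projected diffusion features and frozen encoder features; masked tokens are replaced by a single mask embedding; the "shallow transformer adapter" is modeled as one pre-norm single-head self-attention plus MLP block with random weights; and the loss is computed over all tokens. The names `masked_adapter_align_loss`, `init_adapter`, `mask_ratio`, and `mask_token` are hypothetical, and no training loop is shown.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each token vector to zero mean and unit variance (no learned scale/shift).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def neg_cosine_loss(pred, target, eps=1e-8):
    # Assumed REPA-style objective: token-wise negative cosine similarity, averaged over tokens.
    pred_n = pred / (np.linalg.norm(pred, axis=-1, keepdims=True) + eps)
    tgt_n = target / (np.linalg.norm(target, axis=-1, keepdims=True) + eps)
    return float(-np.mean(np.sum(pred_n * tgt_n, axis=-1)))

def init_adapter(rng, d_model, d_target, d_hidden):
    # Hypothetical random weights: one attention + MLP block, an output projection, a mask token.
    s = 1.0 / np.sqrt(d_model)
    return {
        "wq": rng.normal(0, s, (d_model, d_model)),
        "wk": rng.normal(0, s, (d_model, d_model)),
        "wv": rng.normal(0, s, (d_model, d_model)),
        "w1": rng.normal(0, s, (d_model, d_hidden)),
        "w2": rng.normal(0, 1.0 / np.sqrt(d_hidden), (d_hidden, d_model)),
        "proj": rng.normal(0, s, (d_model, d_target)),
        "mask_token": rng.normal(0, 0.02, (d_model,)),
    }

def shallow_adapter(x, p):
    # Pre-norm single-head self-attention with a residual, then a ReLU MLP with a residual.
    h = layer_norm(x)
    q, k, v = h @ p["wq"], h @ p["wk"], h @ p["wv"]
    scores = q @ k.T / np.sqrt(x.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability for softmax
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)
    x = x + attn @ v
    x = x + np.maximum(layer_norm(x) @ p["w1"], 0.0) @ p["w2"]
    return x @ p["proj"]  # project to the encoder feature width

def masked_adapter_align_loss(feats, target, p, mask_ratio, rng):
    # feats: (N, d_model) intermediate diffusion tokens; target: (N, d_target) frozen encoder features.
    n = feats.shape[0]
    num_mask = int(round(mask_ratio * n))
    masked_idx = rng.choice(n, size=num_mask, replace=False)
    x = feats.copy()
    x[masked_idx] = p["mask_token"]  # mask only a subset; the full sequence still enters the adapter
    return neg_cosine_loss(shallow_adapter(x, p), target), masked_idx

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    feats = rng.normal(size=(16, 32))   # 16 tokens of hypothetical width 32
    target = rng.normal(size=(16, 24))  # frozen encoder features of hypothetical width 24
    params = init_adapter(rng, d_model=32, d_target=24, d_hidden=64)
    loss, masked = masked_adapter_align_loss(feats, target, params, mask_ratio=0.25, rng=rng)
    print(f"masked {len(masked)} of 16 tokens, alignment loss {loss:.4f}")
```

Setting `mask_ratio=0` and swapping the attention block for a per-token MLP gives a plain REPA-style projection. In this sketch, masked positions can only be predicted from the other tokens through attention, which is one plausible way to read how partial masking constrains a per-token regression shortcut.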

Paper Structure (25 sections, 6 equations, 12 figures, 4 tables)

Figures (12)

  • Figure 1: REPA degrades JiT performance. As training progresses, JiT$+$REPA yields higher FID heusel2017gans ($\downarrow$) than vanilla JiT, indicating that REPA hinders pixel-space diffusion training. PixelREPA prevents overfitting to the external semantic feature target, which accelerates convergence in JiT training. Remarkably, PixelREPA achieves $>2 \times$ faster convergence than vanilla JiT. All evaluated models use JiT-B$/16$.
  • Figure 2: Overall framework of PixelREPA. PixelREPA masks a subset of tokens in an intermediate diffusion feature map. The full token sequence, with only that subset masked, is then transformed by a shallow Transformer adapter and aligned to features from a frozen pretrained semantic encoder. This transforms the alignment target and reduces overfitting to the external semantic representation.
  • Figure 3: REPA accelerates pixel diffusion training at low resolution but degrades it at high resolution. The figure compares the FID of JiT and JiT$+$REPA over training epochs on (a) ImageNet $32 \times 32$ and (b) ImageNet $256 \times 256$.
  • Figure 4: Visualization of the class-wise semantic representation distribution with t-SNE van2008visualizing. For each class, we compute a centroid in the semantic feature space based on feature similarity, then mark the 100 samples most similar to the centroid as red dots and the 100 samples least similar to the centroid as blue dots.
  • Figure 5: REPA degrades generation diversity compared to vanilla JiT on the 100 most similar samples of each class. The figure shows FID under different training-data selection strategies, computed over 100 randomly selected classes with 100 samples per class. Vanilla JiT achieves lower FID on the Most Similar 100 subset, whereas JiT$+$REPA achieves lower FID on the Least Similar 100 subset. PixelREPA achieves the best FID in both settings.
  • ...and 7 more figures
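Figures 4 and 5 rank each class's samples by similarity to a class centroid. The captions do not say which encoder, which similarity, or exactly how the centroid is formed, so the NumPy sketch below assumes L2-normalized features, a centroid equal to the re-normalized mean feature, and cosine similarity. The helpers `select_most_least_similar` and `split_by_class` and the parameter `k` are hypothetical names for illustration.

```python
import numpy as np

def select_most_least_similar(features, k):
    # features: (num_samples, dim) semantic features of one class (assumed encoder output).
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    centroid = f.mean(axis=0)
    centroid /= np.linalg.norm(centroid)
    sims = f @ centroid          # cosine similarity of each sample to the class centroid
    order = np.argsort(-sims)    # indices sorted by descending similarity
    return order[:k], order[-k:], sims

def split_by_class(features, labels, k):
    # Collect global indices of the k most and k least centroid-similar samples per class.
    most, least = [], []
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        m, l, _ = select_most_least_similar(features[idx], k)
        most.extend(idx[m].tolist())
        least.extend(idx[l].tolist())
    return most, least

if __name__ == "__main__":
    feats = np.array([[1.0, 0.0], [1.0, 0.1], [1.0, 0.2], [0.0, 1.0]])
    most, least, sims = select_most_least_similar(feats, k=1)
    print("most similar:", most, "least similar:", least)
```

With $k=100$ per class, `split_by_class` would produce index lists in the spirit of the Most Similar 100 and Least Similar 100 subsets; the actual selection used for the figures may differ.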
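The captions report FID heusel2017gans, the Fréchet distance between Gaussians fitted to Inception features of real and generated images, $\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \mathrm{Tr}\big(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\big)$. The NumPy sketch below computes that distance from two feature matrices. It assumes the features are already extracted (no Inception network is included) and evaluates the trace of the matrix square root through the symmetric matrix $\Sigma_r^{1/2} \Sigma_g \Sigma_r^{1/2}$, which has the same eigenvalues as $\Sigma_r \Sigma_g$. It is an illustration, not a reference FID implementation, and `frechet_distance` is a hypothetical name.

```python
import numpy as np

def _sqrtm_psd(a):
    # Square root of a symmetric positive semidefinite matrix via eigendecomposition.
    w, v = np.linalg.eigh(a)
    w = np.clip(w, 0.0, None)           # clip tiny negative eigenvalues from round-off
    return (v * np.sqrt(w)) @ v.T       # V diag(sqrt(w)) V^T

def frechet_distance(feats_real, feats_gen):
    # feats_*: (num_samples, dim) feature matrices of real and generated images.
    mu_r, mu_g = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_g = np.cov(feats_gen, rowvar=False)
    sqrt_r = _sqrtm_psd(cov_r)
    # Tr((cov_r cov_g)^{1/2}) == Tr((sqrt_r cov_g sqrt_r)^{1/2}); the latter argument is symmetric PSD.
    middle = sqrt_r @ cov_g @ sqrt_r
    middle = (middle + middle.T) / 2.0  # remove floating-point asymmetry
    eig = np.clip(np.linalg.eigvalsh(middle), 0.0, None)
    tr_covmean = float(np.sum(np.sqrt(eig)))
    diff = mu_r - mu_g
    return float(diff @ diff + np.trace(cov_r) + np.trace(cov_g) - 2.0 * tr_covmean)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    real = rng.normal(size=(500, 8))
    gen = rng.normal(loc=0.5, size=(500, 8))
    print(f"FID-style distance: {frechet_distance(real, gen):.3f}")
```

Lower values mean the two feature distributions are closer, which is why the captions mark FID with $\downarrow$.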