What Drives Compositional Generalization in Visual Generative Models?

Karim Farid, Rajat Sahay, Yumna Ali Alnaggar, Simon Schrodi, Volker Fischer, Cordelia Schmid, Thomas Brox

TL;DR

This work investigates what drives compositional generalization in visual generative models by conducting controlled, cross-architecture experiments that compare continuous versus discrete output spaces and varying conditioning completeness. The authors show that models trained to represent continuous distributions exhibit stronger level-2 compositional generalization, while full, non-quantized conditioning is crucial for reliable recombination of factors. They further introduce a JEPA-based auxiliary objective to improve discrete models (e.g., MaskGIT), which yields more disentangled representations and stronger compositionality, supported by mechanistic analyses of attention heads and circuits. The results extend to real-world video domains and suggest that continuous latent or intermediate representations, together with complete conditioning, are key for developing generative models with reliable, systematic generalization across modalities and tasks.

Abstract

Compositional generalization, the ability to generate novel combinations of known concepts, is a key ingredient for visual generative models. Yet, not all mechanisms that enable or inhibit it are fully understood. In this work, we conduct a systematic study of how various design choices influence compositional generalization in image and video generation in a positive or negative way. Through controlled experiments, we identify two key factors: (i) whether the training objective operates on a discrete or continuous distribution, and (ii) to what extent conditioning provides information about the constituent concepts during training. Building on these insights, we show that relaxing the MaskGIT discrete loss with an auxiliary continuous JEPA-based objective can improve compositional performance in discrete models like MaskGIT.
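As an illustration of "novel combinations of known concepts", the sketch below enumerates a held-out split for three binary factors, mirroring the CelebA setup described in the Figure 1 caption. It rests on two assumptions not stated on this page: the four training combinations chosen here are hypothetical, and the level of a novel composition is taken to be the minimum number of factors in which it differs from any training combination.

```python
from itertools import product

# The three binary factors named in the Figure 1 caption.
FACTORS = ("gender", "hair_color", "smile")

# Hypothetical choice of four training combinations (not taken from the paper).
TRAIN = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)]


def hamming(a, b):
    """Number of factors in which two combinations differ."""
    return sum(x != y for x, y in zip(a, b))


def level(combo, train):
    """Assumed definition: distance to the closest training combination."""
    return min(hamming(combo, t) for t in train)


def split_by_level(train):
    """Group all 2^3 factor combinations by their level (0 = seen in training)."""
    groups = {}
    for combo in product((0, 1), repeat=len(FACTORS)):
        groups.setdefault(level(combo, train), []).append(combo)
    return groups


if __name__ == "__main__":
    for lvl, combos in sorted(split_by_level(TRAIN).items()):
        print(f"level {lvl}: {combos}")
```

Under these assumptions, level 0 contains the four training combinations, level 1 the compositions that change one factor of some training combination, and level 2 the compositions that differ from every training combination in at least two factors.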

Paper Structure

This paper contains 56 sections, 9 equations, 25 figures, 6 tables.

Figures (25)

  • Figure 1: Compositional Generalization Analysis. We evaluate how generative models (MaskGIT, DiT) generalize to novel compositions of three binary factors on CelebA: gender, hair color, and smile. Models are trained on four combinations (blue) and evaluated on two sets of novel compositions (pink: Level-1 (one-factor change), red: Level-2 (two-factor change)). While MaskGIT (2nd row) shows poor compositional generalization, DiT (1st row) exhibits better compositional generalization. We also show that we can improve MaskGIT's compositional generalization abilities by augmenting its training objective with a JEPA-based training objective (3rd row).
  • Figure 2: DiT exhibits compositional generalization regardless of the type of tokenizer used. While the training dynamics differ, DiT shows compositional generalization at the end of training. The blue, pink, and red curves show linear probe accuracies for the training data, level-1 compositions, and level-2 compositions, respectively. Consistent results are observed for MAR (\ref{fig:mar_tokenizer_results}) and across video datasets (see \ref{fig:dit_clevrer_tokenizer_results}).
  • Figure 3: Compositional generalization performance on Shapes2D across different model architectures. Models that learn continuous distributions (DiT, MAR, and GIVT) consistently perform better on level-2 compositions than MaskGIT, with the decisive shift in performance occurring at the categorical-to-continuous intervention. The blue, pink, and red curves denote training, level-1, and level-2 compositions, respectively. Consistent results are observed for CLEVRER-Kubric (\ref{fig:latent_space_results_clevrer}).
  • Figure 4: Comparison of conditioning information levels and their impact on compositional generalization in DiT on Shapes2D. (a) Continuous (full-information) conditioning leads to uniform convergence across all compositions. (b) Label dropout conditioning leads to inconsistent generalization; several unseen compositions fail completely. (c) Discrete (quantized) conditioning leads to partial generalization, with some failing samples. (d) Discrete (quantized) conditioning with dropout, the most severe loss of information, leads to the most failures. Shaded areas indicate standard deviation across three different seeds. We provide additional results in \ref{appendix:conditioning_results}. The blue curves show performance on training data, the pink curves depict level-1 compositions, and the red curves denote level-2 compositions.
  • Figure 5: An overview of MaskGIT combined with the JEPA-based training objective. We apply the JEPA loss at specific layers ($l$) on an intermediate masked token representation in the transformer $(H_{C}^{(l)})$ and train a lightweight predictor to reconstruct target states $(H_{T}^{(l)})$ using MSE as an error metric and a stop-gradient signal to avoid representation collapse.
  • ...and 20 more figures
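
The auxiliary objective described in the Figure 5 caption (a lightweight predictor that maps $H_{C}^{(l)}$ to $H_{T}^{(l)}$ under an MSE loss with a stop-gradient on the targets) can be sketched in PyTorch as follows. This is a minimal sketch and not the authors' implementation: the class name, the two-layer MLP predictor, the hidden width, and the loss weight are all hypothetical, and how the context and target states are produced is not specified here, so both are taken as inputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class JepaAuxiliaryLoss(nn.Module):
    """Hypothetical JEPA-style auxiliary loss for one intermediate layer l."""

    def __init__(self, dim: int, hidden: int = 128):
        super().__init__()
        # "Lightweight predictor": a two-layer MLP is an assumption of this sketch.
        self.predictor = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, h_context: torch.Tensor, h_target: torch.Tensor) -> torch.Tensor:
        # h_context: layer-l states of the masked tokens, shape (batch, tokens, dim)
        # h_target:  layer-l target states, same shape
        prediction = self.predictor(h_context)
        # Stop-gradient: detach() treats the targets as constants, so the loss
        # cannot be minimized by collapsing the target representation.
        return F.mse_loss(prediction, h_target.detach())


def total_loss(token_loss: torch.Tensor, jepa_loss: torch.Tensor,
               weight: float = 0.1) -> torch.Tensor:
    """Discrete token loss plus weighted auxiliary loss (weight is hypothetical)."""
    return token_loss + weight * jepa_loss
```

In a training step, `token_loss` would be the usual MaskGIT cross-entropy over masked tokens, and `JepaAuxiliaryLoss` would be evaluated at each selected layer $l$; gradients then reach the transformer only through the context branch and the predictor.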