Diffusion Model with Cross Attention as an Inductive Bias for Disentanglement

Tao Yang, Cuiling Lan, Yan Lu, Nanning Zheng

TL;DR

This work investigates whether diffusion models can learn disentangled representations without explicit regularization by introducing EncDiff, a framework that encodes an image into concept tokens and conditions a latent diffusion model through cross-attention. It identifies two inductive biases—the information bottleneck in diffusion and cross-attention–driven interaction—that jointly promote factor disentanglement. Empirical results on Shapes3D, MPI3D, Cars3D, and CelebA show state-of-the-art disentanglement with high reconstruction quality, and ablations confirm the critical role of diffusion structure and cross-attention. The findings suggest diffusion-based inductive biases can power robust disentangled representations and motivate further exploration into diffusion-driven analysis and generation.

Abstract

Disentangled representation learning strives to extract the intrinsic factors within observed data. Factorizing these representations in an unsupervised manner is notably challenging and usually requires tailored loss functions or specific structural designs. In this paper, we introduce a new perspective and framework, demonstrating that diffusion models with cross-attention can serve as a powerful inductive bias to facilitate the learning of disentangled representations. We propose to encode an image to a set of concept tokens and treat them as the condition of the latent diffusion for image reconstruction, where cross-attention over the concept tokens is used to bridge the interaction between the encoder and diffusion. Without any additional regularization, this framework achieves superior disentanglement performance on the benchmark datasets, surpassing all previous methods with intricate designs. We have conducted comprehensive ablation studies and visualization analysis, shedding light on the functioning of this model. This is the first work to reveal the potent disentanglement capability of diffusion models with cross-attention, requiring no complex designs. We anticipate that our findings will inspire more investigation on exploring diffusion for disentangled representation learning towards more sophisticated data analysis and understanding.
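
The conditioning path described in the abstract can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the sizes (`N`, `hidden`, `token_dim`, `attn_dim`), the random weights, the ReLU activations, and the helper names (`mlp3`, `init_mlp`, `concept_tokens`, `cross_attention`) are all hypothetical, and a random vector stands in for the image encoder output. Following the paper's description of the encoder, each scalar factor is mapped to a concept token by its own (non-shared) three-layer MLP, and spatial features of the denoising network attend over those tokens.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_mlp(hidden, token_dim):
    """Random weights for one three-layer MLP (scalar -> hidden -> hidden -> token)."""
    return [
        (rng.normal(size=hidden), np.zeros(hidden)),
        (rng.normal(size=(hidden, hidden)) / np.sqrt(hidden), np.zeros(hidden)),
        (rng.normal(size=(hidden, token_dim)) / np.sqrt(hidden), np.zeros(token_dim)),
    ]

def mlp3(z, params):
    """Map one scalar factor z to one concept token (ReLU is an assumption)."""
    (w1, b1), (w2, b2), (w3, b3) = params
    h = np.maximum(z * w1 + b1, 0.0)
    h = np.maximum(h @ w2 + b2, 0.0)
    return h @ w3 + b3

def concept_tokens(factors, mlps):
    """Apply a separate (non-shared) MLP to each scalar factor -> (N, token_dim)."""
    return np.stack([mlp3(z, p) for z, p in zip(factors, mlps)])

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(feats, tokens, wq, wk, wv):
    """Queries come from spatial features; keys and values come from concept tokens."""
    q = feats @ wq                 # (HW, attn_dim)
    k = tokens @ wk                # (N, attn_dim)
    v = tokens @ wv                # (N, attn_dim)
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1)  # (HW, N)
    return attn @ v, attn

# Toy run with hypothetical sizes.
N, hidden, token_dim = 6, 16, 32
hw, feat_dim, attn_dim = 64, 48, 32          # an 8x8 feature grid, flattened
factors = rng.normal(size=N)                 # stand-in for the encoder output
mlps = [init_mlp(hidden, token_dim) for _ in range(N)]
tokens = concept_tokens(factors, mlps)
feats = rng.normal(size=(hw, feat_dim))      # stand-in for U-Net features
wq = rng.normal(size=(feat_dim, attn_dim)) / np.sqrt(feat_dim)
wk = rng.normal(size=(token_dim, attn_dim)) / np.sqrt(token_dim)
wv = rng.normal(size=(token_dim, attn_dim)) / np.sqrt(token_dim)
out, attn = cross_attention(feats, tokens, wq, wk, wv)
```

Each column of `attn` is a spatial map for one concept token, the analogue of the per-word attention maps that motivate the work. Because every spatial location distributes a fixed attention budget over the `N` tokens, each token can only influence the output through its own key and value vectors.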

Paper Structure (24 sections, 1 theorem, 14 equations, 9 figures, 9 tables)

This paper contains 24 sections, 1 theorem, 14 equations, 9 figures, 9 tables.

Key Result

Theorem C.1

The Kullback-Leibler divergence is invariant under a differentiable mapping $f$, i.e., $D_{\mathrm{KL}}\big(p(x)\,\|\,q(x)\big) = D_{\mathrm{KL}}\big(p(S)\,\|\,q(S)\big)$, where $x = f(S)$ is a differentiable function between $x$ and $S$, and $p(x)$ and $q(x)$ are the probability density functions of the probability distributions $P$ and $Q$, respectively.
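
A short derivation sketch of why this invariance holds, under the additional assumption (not spelled out in the statement) that $f$ is invertible with a differentiable inverse. Here $p(S)$ and $q(S)$ denote the densities of $S$ induced by $P$ and $Q$, and $J_f$ is the Jacobian of $f$; both symbols are introduced only for this sketch. The change-of-variables formula gives $p(x) = p(S)\,|\det J_f(S)|^{-1}$, $q(x) = q(S)\,|\det J_f(S)|^{-1}$, and $dx = |\det J_f(S)|\,dS$, so the Jacobian factors cancel inside the logarithm:

$$
\begin{aligned}
D_{\mathrm{KL}}\big(p(x)\,\|\,q(x)\big)
&= \int p(x)\,\log\frac{p(x)}{q(x)}\,dx \\
&= \int p(S)\,\log\frac{p(S)\,|\det J_f(S)|^{-1}}{q(S)\,|\det J_f(S)|^{-1}}\,dS \\
&= \int p(S)\,\log\frac{p(S)}{q(S)}\,dS
 = D_{\mathrm{KL}}\big(p(S)\,\|\,q(S)\big).
\end{aligned}
$$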

Figures (9)

  • Figure 1: Average attention map across all time steps in Stable Diffusion. We draw inspiration from text-to-image generation using a diffusion model with cross-attention. When highly 'disentangled' words are used as the condition for image generation, the cross-attention maps of the diffusion model exhibit strong semantic and spatial alignment with the text, indicating that the model can incorporate each individual word into the generation process to produce a semantically aligned result. This leads us to ask whether such a diffusion structure could serve as an inductive bias for disentangled representation learning.
  • Figure 2: (a) Illustration of our framework EncDiff. We employ an image encoder $\tau_{\phi}$ to transform an image $I$ into a set of disentangled representations, which we treat as the conditional input to the latent diffusion model with cross-attention. Here, cross-attention bridges the interaction between the diffusion network and the image encoder. For simplicity, we only briefly show the diffusion model, which consists of an encoder $E$, a denoising U-Net, and a decoder $D$ that reconstructs the image from the latent $x_t$. (b) Information bottleneck reflected by the KL divergence in the reverse diffusion process: the KL divergence between the data distribution $q(x_{t-1}|x_t, x_0)$ and the Gaussian prior distribution $\mathcal{N}(0,\mathbf{I})$ under four different variance ($\beta$) schedules (cosine, linear, sqrt linear, and sqrt). The results are normalized by the number of dimensions.
  • Figure 3: Illustration of the encoder $\tau_{\phi}$, which transforms an image into a feature vector of dimension $N$, with each dimension (a scalar) encoding a disentangled factor. Non-shared three-layer MLPs then map each scalar to a vector (a concept token). The concept tokens are treated as the conditional input to the latent diffusion model with cross-attention.
  • Figure 4: Comparison of disentanglement performance and generation quality in terms of the TAD and FID metrics (mean $\pm$ std) on the real-world dataset CelebA. EncDiff achieves state-of-the-art performance on both aspects compared with all baselines.
  • Figure 5: Qualitative results on Shapes3D. The source (SRC) images provide the representations of the generated image, and the target (TRT) image provides the representation for swapping. The other images are generated by swapping the representation of the corresponding factor. The learned factors on Shapes3D are wall color (Wall), floor color (Floor), object color (Color), object shape (Shape), orientation (Orien), and scale. See the appendix for more visualizations.
  • ...and 4 more figures
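
The quantity plotted in Figure 2(b) can be approximated with a short NumPy sketch. The exact procedure behind the figure is not given here, so everything below is an assumption made for illustration: the helper names (`linear_betas`, `cosine_betas`, `expected_kl_to_prior`) are hypothetical, only two commonly used schedule definitions are implemented (a DDPM-style linear schedule and the standard cosine schedule, which may differ from the four schedules in the figure), and the data are assumed to be zero-mean with unit second moment per dimension so that the mean term can be averaged analytically. The code uses the standard closed-form Gaussian posterior $q(x_{t-1}|x_t, x_0)$ and the per-dimension KL divergence between a Gaussian and $\mathcal{N}(0, 1)$.

```python
import numpy as np

def linear_betas(T=1000, beta_start=1e-4, beta_end=2e-2):
    """DDPM-style linear variance schedule (assumed definition)."""
    return np.linspace(beta_start, beta_end, T)

def cosine_betas(T=1000, s=0.008, max_beta=0.999):
    """Cosine schedule defined through the cumulative product alpha_bar."""
    steps = np.arange(T + 1) / T
    f = np.cos((steps + s) / (1 + s) * np.pi / 2) ** 2
    alpha_bar = f / f[0]
    betas = 1.0 - alpha_bar[1:] / alpha_bar[:-1]
    return np.clip(betas, 0.0, max_beta)

def expected_kl_to_prior(betas, x0_second_moment=1.0):
    """Per-dimension KL(q(x_{t-1}|x_t,x_0) || N(0,1)), averaged over x_0 and noise.

    Returned for t = 2..T; t = 1 is skipped because the posterior variance is zero.
    Assumes E[x_0] = 0 and E[x_0^2] = x0_second_moment in every dimension.
    """
    alphas = 1.0 - betas
    abar = np.cumprod(alphas)
    abar_prev, abar_t = abar[:-1], abar[1:]
    beta_t, alpha_t = betas[1:], alphas[1:]

    # Posterior variance and the coefficients of the posterior mean
    # mu = c0 * x_0 + ct * x_t.
    post_var = (1.0 - abar_prev) / (1.0 - abar_t) * beta_t
    c0 = np.sqrt(abar_prev) * beta_t / (1.0 - abar_t)
    ct = np.sqrt(alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t)

    # Substitute x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps and take E[mu^2].
    mean_sq = (c0 + ct * np.sqrt(abar_t)) ** 2 * x0_second_moment \
        + ct ** 2 * (1.0 - abar_t)

    # KL(N(mu, var) || N(0, 1)) = 0.5 * (var + mu^2 - 1 - log var), per dimension.
    return 0.5 * (post_var + mean_sq - 1.0 - np.log(post_var))

kl_linear = expected_kl_to_prior(linear_betas())
kl_cosine = expected_kl_to_prior(cosine_betas())
```

Under the unit-second-moment assumption, `post_var + mean_sq` equals 1, so the expression reduces to `-0.5 * log(post_var)`: the divergence from the prior is large at early time steps, where the posterior is sharply concentrated, and shrinks as `t` grows. How quickly it shrinks depends on the schedule, which is the comparison the figure makes.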

Theorems & Definitions (1)

  • Theorem C.1