
CORAL: Disentangling Latent Representations in Long-Tailed Diffusion

Esther Rodriguez, Monica Welfert, Samuel McDowell, Nathan Stromberg, Julian Antolin Camarena, Lalitha Sankar

TL;DR

The paper investigates why diffusion models struggle with tail classes in long-tailed datasets and identifies entanglement in the U-Net bottleneck latent representations as a key cause. It introduces CORAL, which adds a bottleneck projection head and a supervised contrastive loss to explicitly align and separate latent class representations, with a time-aware weighting scheme. Across CIFAR-LT, CelebA-5, and ImageNet-LT, CORAL consistently improves tail-class diversity and fidelity, outperforming state-of-the-art methods like CBDM and T2H. The work demonstrates that latent-space disentanglement within the diffusion model can yield superior tail-class generation compared to ambient-space rebalancing, offering a principled approach for equitable long-tailed diffusion.
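The objective summarized above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation. The projection head is assumed to be one dense layer followed by L2 normalization. The contrastive term is the standard supervised contrastive (SupCon) loss of Khosla et al. (2020), applied to pooled bottleneck features. The schedule `time_weight` is a hypothetical linear decay, because the summary only states that the weight $\lambda(t)$ depends on the timestep. All function names are illustrative.

```python
import numpy as np


def projection_head(h, W, b):
    """Assumed projection head: a single dense layer followed by L2 normalization.

    h: (N, D) pooled U-Net bottleneck features; W: (D, P); b: (P,).
    """
    z = h @ W + b
    norms = np.linalg.norm(z, axis=1, keepdims=True)
    return z / np.maximum(norms, 1e-12)


def supcon_per_anchor(z, y, tau=0.1):
    """Supervised contrastive loss for each anchor (Khosla et al., 2020).

    Positives share the anchor's label; the anchor itself is excluded everywhere.
    Returns (losses, valid), where valid marks anchors with at least one positive.
    """
    n = z.shape[0]
    self_mask = np.eye(n, dtype=bool)
    sim = np.where(self_mask, -np.inf, z @ z.T / tau)
    # Numerically stable log-softmax over all non-anchor samples.
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max + np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))
    log_prob = sim - log_denom
    pos_mask = (y[:, None] == y[None, :]) & ~self_mask
    n_pos = pos_mask.sum(axis=1)
    pos_sum = np.where(pos_mask, log_prob, 0.0).sum(axis=1)
    losses = -pos_sum / np.maximum(n_pos, 1)
    return losses, n_pos > 0


def time_weight(t, T=1000, lam_max=1.0):
    """Hypothetical lambda(t): linear decay from lam_max at t=0 to 0 at t=T."""
    return lam_max * (1.0 - np.asarray(t, dtype=float) / T)


def coral_style_loss(eps_pred, eps, h, y, t, W, b, T=1000, tau=0.1):
    """Diffusion MSE plus a time-weighted supervised contrastive term on latents."""
    diffusion = np.mean((eps_pred - eps) ** 2)  # standard epsilon-prediction loss
    z = projection_head(h, W, b)
    per_anchor, valid = supcon_per_anchor(z, y, tau)
    if not valid.any():
        return float(diffusion)  # no same-class pairs in this batch
    weights = time_weight(t, T)  # one weight per sample, from its own timestep
    contrastive = np.mean(weights[valid] * per_anchor[valid])
    return float(diffusion + contrastive)


rng = np.random.default_rng(0)
N, D, P = 8, 16, 4
h = rng.normal(size=(N, D))               # stand-in for pooled bottleneck features
y = np.array([0, 0, 0, 0, 1, 1, 2, 2])    # head class 0, tail classes 1 and 2
t = rng.integers(0, 1000, size=N)         # per-sample diffusion timesteps
W, b = rng.normal(size=(D, P)), np.zeros(P)
eps = rng.normal(size=(N, 3, 8, 8))
eps_pred = eps + 0.1 * rng.normal(size=eps.shape)
print(coral_style_loss(eps_pred, eps, h, y, t, W, b))
```

Weighting each anchor's contrastive term by its own $\lambda(t_i)$ is one way to combine a batch-level contrastive loss with per-sample timesteps; other combinations are equally plausible.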

Abstract

Diffusion models have achieved impressive performance in generating high-quality and diverse synthetic data. However, their success typically assumes a class-balanced training distribution. In real-world settings, multi-class data often follow a long-tailed distribution, where standard diffusion models struggle, producing low-diversity and lower-quality samples for tail classes. While this degradation is well-documented, its underlying cause remains poorly understood. In this work, we investigate the behavior of diffusion models trained on long-tailed datasets and identify a key issue: the latent representations (from the bottleneck layer of the U-Net) of tail classes occupy subspaces that overlap significantly with those of head classes, leading to feature borrowing and poor generation quality. Importantly, we show that this is not merely due to limited data per class; the relative class imbalance also contributes significantly to this phenomenon. To address this, we propose COntrastive Regularization for Aligning Latents (CORAL), a contrastive latent alignment framework that leverages supervised contrastive losses to encourage well-separated latent class representations. Experiments demonstrate that CORAL significantly improves both the diversity and visual quality of samples generated for tail classes relative to state-of-the-art methods.

Paper Structure

This paper contains 47 sections, 8 equations, 13 figures, 5 tables, 1 algorithm.

Figures (13)

  • Figure 1: t-SNE visualizations of U-Net bottleneck features. The dataset visualized is CIFAR10-LT with a tail-to-head ratio of $0.01$, i.e., the head class (airplane) has 100 times as many samples as the tail class (truck), with an exponential decay in between. Real CIFAR10-LT samples are passed through models trained under different settings: (left) DDPM (ho2021cfg) trained on the original balanced CIFAR-10 dataset, (middle) DDPM trained on CIFAR10-LT with an imbalance ratio of 0.01, and (right) CORAL trained under the same imbalanced setting. In the balanced case, class representations are moderately separated, though some overlap remains. Under imbalance, DDPM exhibits substantial overlap between head and tail classes, an effect we refer to as representation entanglement, which degrades generation quality for tail classes. CORAL mitigates this effect by promoting class-wise separation in the latent space.
  • Figure 2: CORAL architecture and workflow on CelebA-5. (a) The five-class CelebA-5 training data is input to the U-Net architecture. (b) Denoising U-Net. The white inset shows an actual t-SNE visualization of the U-Net latent representations learned by CORAL. (c) CORAL's addition to the standard DDPM architecture: a projection head MLP consisting of a single dense layer followed by normalization. (d) The outputs of the U-Net and the projection head are used to compute the diffusion and contrastive losses, respectively. (e) The contrastive loss is scaled by a time-dependent weighting function, $\lambda(t)$, and added to the standard diffusion loss to obtain the CORAL loss. (f) Samples are obtained from a trained CORAL model.
  • Figure 3: Per-class FID ($\downarrow$) for the CIFAR10-LT dataset with an imbalance factor $\rho = 0.001$.
  • Figure 4: Comparison of generated samples from the class tulips (class 92) in CIFAR100-LT, $\rho=0.01$. CBDM (left), T2H (middle), and CORAL (right). CORAL shows increased diversity and fidelity relative to existing approaches.
  • Figure 5: Generated samples produced by CORAL on the CIFAR10-LT dataset with $\rho=0.01$.
  • ...and 8 more figures
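Figure 1 describes CIFAR10-LT as an exponential decay in class frequency from the head class to the tail class, governed by the tail-to-head ratio $\rho$. A minimal sketch of that construction, assuming the widely used formula $n_c = \lfloor n_{\max} \rho^{c/(C-1)} \rfloor$ with $n_{\max} = 5000$ (the number of training images per class in CIFAR-10). The paper's exact subsampling procedure may differ, and `long_tailed_counts` is an illustrative name.

```python
def long_tailed_counts(n_max, num_classes, rho):
    """Per-class counts decaying exponentially from n_max (head) to about rho * n_max (tail).

    rho is the tail-to-head ratio: rho = 0.01 means the head class has 100x the tail class.
    """
    if num_classes < 2:
        raise ValueError("need at least two classes")
    return [int(n_max * rho ** (c / (num_classes - 1))) for c in range(num_classes)]


# CIFAR10-LT with rho = 0.01: class 0 (airplane) is the head, class 9 (truck) the tail.
counts = long_tailed_counts(5000, 10, 0.01)
print(counts)
print(counts[0] / counts[-1])  # head-to-tail ratio, equal to 1 / rho here
```

Under this construction, the same function with $\rho = 0.001$ leaves the tail class with only 5 training images, which is the more extreme setting of Figure 3.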