Aligning Latent Spaces with Flow Priors

Yizhuo Li, Yuying Ge, Yixiao Ge, Ying Shan, Ping Luo

Abstract

This paper presents a novel framework for aligning learnable latent spaces to arbitrary target distributions by leveraging flow-based generative models as priors. Our method first pretrains a flow model on the target features to capture the underlying distribution. This fixed flow model subsequently regularizes the latent space via an alignment loss, which reformulates the flow matching objective to treat the latents as optimization targets. We formally prove that minimizing this alignment loss establishes a computationally tractable surrogate objective for maximizing a variational lower bound on the log-likelihood of latents under the target distribution. Notably, the proposed method eliminates computationally expensive likelihood evaluations and avoids ODE solving during optimization. As a proof of concept, we demonstrate in a controlled setting that the alignment loss landscape closely approximates the negative log-likelihood of the target distribution. We further validate the effectiveness of our approach through large-scale image generation experiments on ImageNet with diverse target distributions, accompanied by detailed discussions and ablation studies. With both theoretical and empirical validation, our framework paves a new way for latent space alignment.
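The idea of treating the latent as the optimization target can be illustrated with a minimal sketch in one dimension. Everything below is an assumption made for illustration and is not the authors' implementation: the target is a single Gaussian $N(\mu, s^2)$, the base distribution is $N(0, 1)$, the path is the straight interpolation $z_t = (1-t)x_0 + t y$, and the "pretrained" flow model is replaced by the closed-form marginal velocity for that Gaussian pair. The function names (`gaussian_velocity`, `align_loss`, `align`) are hypothetical. The velocity field stays fixed and only `y` is updated, using finite differences instead of autograd to keep the sketch dependency-free.

```python
import random


def gaussian_velocity(z, t, mu, s):
    """Stand-in for a pretrained flow model (assumption, closed form).

    Returns E[x1 - x0 | z_t = z] for x0 ~ N(0, 1), x1 ~ N(mu, s^2),
    and the straight path z_t = (1 - t) * x0 + t * x1.
    """
    var_t = (1.0 - t) ** 2 + (t * s) ** 2   # Var(z_t)
    cov = t * s ** 2 - (1.0 - t)            # Cov(x1 - x0, z_t)
    return mu + cov / var_t * (z - t * mu)


def align_loss(y, mu, s, n=2000, seed=0):
    """Monte Carlo estimate of the flow matching residual with y as target.

    A fixed seed gives common random numbers, so the estimate is a
    deterministic function of y and can be differentiated numerically.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        t = rng.random()                 # t ~ U[0, 1)
        x0 = rng.gauss(0.0, 1.0)         # sample from the base distribution
        z_t = (1.0 - t) * x0 + t * y     # point on the straight path to y
        residual = gaussian_velocity(z_t, t, mu, s) - (y - x0)
        total += residual * residual
    return total / n


def align(y, mu, s, steps=60, lr=0.1, h=1e-3):
    """Gradient descent on y alone; the velocity field is never updated."""
    for _ in range(steps):
        grad = (align_loss(y + h, mu, s) - align_loss(y - h, mu, s)) / (2.0 * h)
        y -= lr * grad
    return y
```

Calling `align(4.0, mu=1.0, s=0.5)` moves a latent that starts far from the target mean toward it. In this Gaussian case the expected loss is quadratic in `y - mu`, as is the negative log-likelihood of the target, which is the simplest instance of the loss landscape tracking the NLL.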

Paper Structure

This paper contains 45 sections, 5 theorems, 30 equations, 6 figures, 3 tables.

Key Result

Proposition 1

Let $\bm{v}_\theta: \mathbb{R}^{d_1} \times [0,1] \to \mathbb{R}^{d_1}$ be a given velocity field, and $p_{\mathrm{init}}$ be a base distribution. For $\bm{y} \in \mathbb{R}^{d_1}$, the log-likelihood $\log p_1^{\bm{v}_\theta}(\bm{y})$ is lower-bounded as $$\log p_1^{\bm{v}_\theta}(\bm{y}) \geq C(\bm{y}) - \lambda\, \mathcal{L}_{\text{align}}(\bm{y}; \theta),$$ where $\lambda > 0$ is a constant, $\mathcal{L}_{\text{align}}(\bm{y}; \theta)$ is defined in Eq. (eq:loss_y), and $C(\bm{y})$ is dependent on ...
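
The proposition refers to the alignment loss only by its equation label, and that equation is not reproduced in this summary. One form consistent with the abstract and the Figure 2 caption, assuming a straight interpolation path between a base sample and the latent and a uniform distribution over $t$ (both assumptions here, not confirmed by this page), would be:

$$\mathcal{L}_{\text{align}}(\bm{y}; \theta) = \mathbb{E}_{t \sim \mathcal{U}[0,1],\; \bm{x}_0 \sim p_{\mathrm{init}}} \left\| \bm{v}_\theta(\bm{z}_t, t) - (\bm{y} - \bm{x}_0) \right\|^2, \qquad \bm{z}_t = (1-t)\,\bm{x}_0 + t\,\bm{y}.$$

Under this reading, $\theta$ is held fixed and $\bm{y}$ is the only quantity that receives gradients, which matches the abstract's description of the latents as the optimization targets.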

Figures (6)

  • Figure 1: (a) Conventional alignment works with only known priors (e.g., Gaussian or categorical) using KL or cross-entropy losses. (b) Our proposed method can align the latent distribution to an arbitrary target distribution captured by a pre-trained flow model.
  • Figure 2: Intuitive illustration of latent space alignment via flow matching, best viewed in color. (a) A "good" $\bm{y}_{\text{good}}$ in $p_{\mathrm{data}}$ (green) aligns the straight path velocity (red solid arrow) with the pre-trained flow model's velocity $\bm{v}_\theta(\bm{z}_t,t)$ (overlapped and omitted), yielding low loss. (b) A "bad" $\bm{y}_{\text{bad}}$ outside $p_{\mathrm{data}}$ causes a mismatch between the path velocity and $\bm{v}_\theta(\bm{z}_t,t)$ (green solid arrow), resulting in high loss. Minimizing this loss steers $\bm{y}_{\text{bad}}$ to $p_{\mathrm{data}}$ (blue dotted arrow).
  • Figure 3: Illustration with a Mixture of Gaussians distribution. (a) Aligned latent variables $\bm{y}$ (red triangles) concentrate in low negative log-likelihood (NLL) regions of $p_\text{data}$ (blue dots; heatmap shows $-\log p_\text{data}$). (b) Alignment loss $\mathcal{L}_{\text{align}}$ heatmap mirrors the NLL landscape of $p_\text{data}$, with $p_\text{data}$ samples in low-$\mathcal{L}_{\text{align}}$ areas. (c) $\mathcal{L}_{\text{align}}$ (blue solid) and $-\log p_\text{data}(\bm{y})$ (red dashed) decline simultaneously in training, showing $\mathcal{L}_{\text{align}}$ serves as a proxy for maximizing the log-likelihood of $\bm{y}$ under $p_\text{data}$.
  • Figure 4: Aligning autoencoders on ImageNet-1K with different target distributions. The alignment loss $\mathcal{L}_{\text{align}}$ (blue solid) and the $k$-NN distance $\log r_k(\bm{y})$ (red dashed) are proportional throughout training, confirming that $\mathcal{L}_{\text{align}}$ serves as a good proxy for the NLL of the latents under $p_{\text{data}}$.
  • Figure 5: Further illustrations of our method's performance on various 2D toy examples. Each row corresponds to a different target distribution $p_{\text{data}}$ (Grid of Gaussians, Two Moons, Concentric Rings, Spiral, and Swiss Roll). Left column (a,d,g,j,m): Optimized variables $\bm{y}$ (red triangles) and samples from $p_{\text{data}}$ (blue dots). The background heatmap visualizes the negative log-likelihood (NLL) $-\log p_{\text{data}}(\cdot)$, with $\bm{y}$ converging to low-NLL (high-density) regions. Middle column (b,e,h,k,n): The landscape of the alignment loss $\mathcal{L}_{\text{align}}$ (heatmap) with $p_{\text{data}}$ samples (blue dots). This landscape mirrors the NLL surface, and $p_{\text{data}}$ samples are concentrated in areas of low $\mathcal{L}_{\text{align}}$. Right column (c,f,i,l,o): Training curves for $\mathcal{L}_{\text{align}}(\bm{y}; \theta)$ (blue solid line) and NLL $-\log p_{\text{data}}(\bm{y})$ (red dashed line). Their strong positive correlation and concurrent decrease during optimization demonstrate that $\mathcal{L}_{\text{align}}$ effectively serves as a proxy for maximizing the log-likelihood of $\bm{y}$ under $p_{\text{data}}$.
  • ...and 1 more figure

Theorems & Definitions (11)

  • Proposition 1
  • proof
  • Proposition 1
  • Remark 1: Optimality of $\bm{v}_\theta$
  • proof
  • Lemma 1: Consistency of Variational Paths
  • proof
  • Theorem 1: Monotonic Behavior of the ELBO
  • proof
  • Proposition 1: Regularization by Neural Network Parameterization
  • ...and 1 more