Analysis of learning a flow-based generative model from limited sample complexity

Hugo Cui, Florent Krzakala, Eric Vanden-Eijnden, Lenka Zdeborová

TL;DR

This work provides a tight end-to-end asymptotic analysis of learning a flow-based generative model for a Gaussian mixture using a shallow two-layer autoencoder to parameterize the velocity field. It derives closed-form expressions for the learned velocity and its associated summary statistics, and shows that the generated mixture’s mean converges to the target mean at a rate of Θ(1/n), which is Bayes-optimal. The study reduces the high-dimensional transport to a small set of scalar ODEs, revealing how finite sample effects and architectural biases shape the generative process and highlighting when memorization occurs. Overall, it offers precise guidance on the interplay between learning from limited data and the quality of the resulting density sampling, including insights on when skip connections are beneficial.

Abstract

We study the problem of training a flow-based generative model, parametrized by a two-layer autoencoder, to sample from a high-dimensional Gaussian mixture. We provide a sharp end-to-end analysis of the problem. First, we provide a tight closed-form characterization of the learnt velocity field, when parametrized by a shallow denoising auto-encoder trained on a finite number $n$ of samples from the target distribution. Building on this analysis, we provide a sharp description of the corresponding generative flow, which pushes the base Gaussian density forward to an approximation of the target density. In particular, we provide closed-form formulae for the distance between the mean of the generated mixture and the mean of the target mixture, which we show decays as $\Theta_n(\frac{1}{n})$. Finally, this rate is shown to be in fact Bayes-optimal.
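To give a feel for the $\Theta_n(\frac{1}{n})$ scaling, here is a minimal NumPy sketch. It is not the paper's estimator or its flow. It assumes a balanced two-cluster target $\frac{1}{2}\mathcal{N}(\boldsymbol{\mu},\sigma^2 I)+\frac{1}{2}\mathcal{N}(-\boldsymbol{\mu},\sigma^2 I)$, treats $\sigma$ as a standard deviation, and uses a hypothetical label-aware sample mean that is given the cluster signs. The function name `empirical_mean_mse` and all parameter values are illustrative choices.

```python
import numpy as np

def empirical_mean_mse(n, d=2000, sigma=1.0, seed=0):
    """Return ||mu_emp - mu||^2 / d for a sign-corrected sample mean (illustration only)."""
    rng = np.random.default_rng(seed)
    mu = rng.standard_normal(d)              # hypothetical cluster mean
    s = rng.choice([-1.0, 1.0], size=n)      # cluster signs, assumed known in this sketch
    z = rng.standard_normal((n, d))
    x = s[:, None] * mu + sigma * z          # n samples from the assumed mixture
    mu_emp = (s[:, None] * x).mean(axis=0)   # equals mu + (sigma / n) * sum_i s_i z_i
    return float(np.sum((mu_emp - mu) ** 2) / d)

# Per-coordinate squared error for a few sample sizes; it should scale like sigma^2 / n.
mse_by_n = {n: empirical_mean_mse(n) for n in (4, 16, 64, 256)}
for n, mse in mse_by_n.items():
    print(f"n={n:4d}  mse={mse:.4f}  n*mse={n * mse:.3f}")
```

Under these assumptions the product `n * mse` stays close to $\sigma^2$ as $n$ grows, which is the $1/n$ decay in its simplest form. The paper's contribution is to show that the mean of the generated density attains this order of decay without access to the labels.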

Paper Structure (32 sections, 3 theorems, 120 equations, 6 figures)

Key Result

Corollary 3.3

(Mean squared error of the mean estimate) Let $\hat{\boldsymbol{\mu}}$ be the cluster mean of the density $\hat{\rho}_1$ generated by the (continuous) learnt flow \ref{eq:empirical_ODE}. In the asymptotic limit described by Result \ref{res:w}, the squared distance between $\hat{\boldsymbol{\mu}}$ and the true mean $\boldsymbol{\mu}$ is given by […], with $M_1,Q^\xi_1,Q^\eta_1$ being the solutions of the ordinary differential equations \ref{eq:Xt_compo}.

Figures (6)

  • Figure 1: $n=4,\sigma=0.9,\lambda=0.1,\alpha(t)=1-t,\beta(t)=t,\varphi=\tanh$. Solid lines: theoretical predictions of Result \ref{res:w}: squared norm of the DAE weight vector $\lVert \hat{\boldsymbol{w}}_t\rVert^2$ (red), skip connection strength $\hat{c}_t$ (blue), cosine similarity between the weight vector $\hat{\boldsymbol{w}}_t$ and the target cluster mean $\boldsymbol{\mu}$, $\hat{\boldsymbol{w}}_t\angle \boldsymbol{\mu}\equiv \hat{\boldsymbol{w}}_t^\top \boldsymbol{\mu}/(\lVert \boldsymbol{\mu}\rVert \lVert\hat{\boldsymbol{w}}_t\rVert)$ (green), components $m_t,q^\xi_t$ of $\hat{\boldsymbol{w}}_t$ along the vectors $\boldsymbol{\mu}_{\mathrm{emp.}},\boldsymbol{\xi}$ (purple, pink, orange). Dots: numerical simulations in dimension $d=5\times 10^4$, corresponding to training the DAE \ref{eq:DAE} on the risk \ref{eq:single_time_obj} using the PyTorch implementation of full-batch Adam, with learning rate $0.0001$ over $4\times 10^4$ epochs and weight decay $\lambda=0.1$. The experimental points correspond to a single instance of the model.
  • Figure 2: In all three plots, $\lambda=0.1,\alpha(t)=1-t,\beta(t)=t,\varphi=\mathrm{sign}$. (left) $\sigma=1.5,n=8$. Temporal evolution of the summary statistics $M_t,Q^\xi_t,Q_t, \boldsymbol{X}_t\angle\boldsymbol{\mu}$ (\ref{eq:components_Xt}). Solid lines correspond to the theoretical prediction of \ref{eq:components_Xt} in Result \ref{res:stats}, while dashed lines correspond to numerical simulations of the generative model, obtained by discretizing the differential equation \ref{eq:empirical_ODE} with step size $\delta t=0.01$ and training a separate DAE for each time step using Adam with learning rate $0.01$ for $2000$ epochs. All experiments were conducted in dimension $d=5000$, and a single run is represented. (middle) $\sigma=2,n=16$. Projection of the distribution of $\boldsymbol{X}_t$ (\ref{eq:empirical_ODE}) in $\mathrm{span}(\boldsymbol{\mu}_{\mathrm{emp.}},\boldsymbol{\xi})$, transported by the velocity field $\hat{\boldsymbol{b}}$ (\ref{eq:bhat}) learnt from data. The point clouds correspond to numerical simulations. The dashed line corresponds to the theoretical prediction of the means of the clusters, as given by equation \ref{eq:Xt_compo} of Result \ref{res:stats}. The target Gaussian mixture $\rho_1$ is represented in red. The base zero-mean Gaussian density $\rho_0$ (dark blue) is split by the flow \ref{eq:empirical_ODE} into two clusters, which approach the target clusters (red) as time progresses. (right) $\sigma=2$. PCA visualization of the generated density $\hat{\rho}_1$, obtained by training the generative model on $n$ samples, for $n\in\{4,8,16,32,64\}$. Point clouds represent numerical simulations of the generative model. Crosses represent the theoretical predictions for the means of the clusters of $\hat{\rho}_1$, as given by equation \ref{eq:Xt_compo} of Result \ref{res:stats} for $t=1$. As the number of training samples $n$ increases, the generated clusters of $\hat{\rho}_1$ approach the target clusters of $\rho_1$, represented in red.
  • Figure 3: $\alpha(t)=1-t,\beta(t)=t,\varphi=\mathrm{sign}$. Cosine similarity (left) and mean squared distance (right) between the mean $\hat{\boldsymbol{\mu}}$ of the generated mixture $\hat{\rho}_1$ and the mean $\boldsymbol{\mu}$ of the target density $\rho_1$, as a function of the number of training samples $n$, for various variances $\sigma$ of $\rho_1$. Solid lines represent the theoretical characterization of Corollary \ref{cor:MSE}. Crosses represent numerical simulations of the generative model, obtained by discretizing the differential equation \ref{eq:empirical_ODE} with step size $\delta t=0.01$ and training a separate DAE for each time step using the PyTorch implementation of the full-batch Adam optimizer, with learning rate $0.04$ and weight decay $\lambda=0.1$ for $6000$ epochs. All experiments were conducted in dimension $d=5\times 10^4$, and a single run is represented. Dashed lines indicate the performance of the Bayes-optimal estimator $\hat{\boldsymbol{\mu}}^\star$, as theoretically characterized in Remark \ref{rem:BO}. Dots indicate the performance of the PCA estimator, which is found, as in cui2023high, to yield performance nearly identical to that of the Bayes-optimal estimator.
  • Figure 4: $\sigma=0.3,\lambda=0.1,\alpha(t)=\cos(\pi t/2),\beta(t)=\sin(\pi t/2)$. Solid lines: theoretical predictions for the MSE of Result \ref{App:res:MSE} (left) and the cosine similarity of Result \ref{res:App:cosine} (right). Different colors correspond to different numbers of samples $n$. Dots: numerical simulations, corresponding to training the DAE \ref{eq:DAE} on the risk \ref{eq:single_time_obj} using the PyTorch implementation of full-batch Adam, with learning rate $0.01$ over $2000$ epochs and weight decay $\lambda=0.1$. The experimental points correspond to a single instance of the model, and were collected in dimension $d=500$. In the left plot, the dashed line represents the oracle baseline \ref{eq:App:singletime:mse_star}.
  • Figure 5: $n=4,\sigma=0.9,\lambda=0.1,\alpha(t)=1-t,\beta(t)=t$. Imbalanced mixture with relative weights $\rho=0.24$ and $1-\rho=0.76$. Solid lines: theoretical predictions of Result \ref{res:w}: squared norm of the DAE weight vector $\lVert \hat{\boldsymbol{w}}_t\rVert^2$ (red), skip connection strength $\hat{c}_t$ (blue), cosine similarity between the weight vector $\hat{\boldsymbol{w}}_t$ and the target cluster mean $\boldsymbol{\mu}$, $\hat{\boldsymbol{w}}_t\angle \boldsymbol{\mu}\equiv \hat{\boldsymbol{w}}_t^\top \boldsymbol{\mu}/(\lVert \boldsymbol{\mu}\rVert \lVert\hat{\boldsymbol{w}}_t\rVert)$ (green), components $m_t,q^\xi_t,q^\eta_t$ of $\hat{\boldsymbol{w}}_t$ along the vectors $\boldsymbol{\mu},\boldsymbol{\xi},\boldsymbol{\eta}$ (purple, pink, orange). Dots: numerical simulations in dimension $d=5\times 10^4$, corresponding to training the DAE \ref{eq:DAE} on the risk \ref{eq:single_time_obj} using the PyTorch implementation of full-batch Adam, with learning rate $0.001$ over $20000$ epochs and weight decay $\lambda=0.1$. The experimental points correspond to a single instance of the model.
  • ...and 1 more figure
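
The captions above describe simulating the generative model by discretizing the flow ODE with step size $\delta t=0.01$ and then comparing the generated cluster mean to $\boldsymbol{\mu}$ through a cosine similarity and a squared distance. The NumPy sketch below shows only that pipeline. It does not train a DAE. As a stand-in for the learnt velocity $\hat{\boldsymbol{b}}$ it uses the population velocity $\mathbb{E}[\boldsymbol{x}_1-\boldsymbol{x}_0\mid\boldsymbol{x}_t]$, derived here under assumptions that are not taken from the text: a balanced mixture $\frac{1}{2}\mathcal{N}(\pm\boldsymbol{\mu},\sigma^2 I)$ with $\sigma$ a standard deviation, and the linear interpolant $\alpha(t)=1-t,\beta(t)=t$. All function names, the dimension, and the sample count are illustrative.

```python
import numpy as np

def population_velocity(x, t, mu, sigma):
    """E[x1 - x0 | x_t = x] for x_t = (1 - t) x0 + t x1, under the assumed balanced mixture."""
    v = (1.0 - t) ** 2 + (t * sigma) ** 2       # variance of x_t within one cluster
    skip = (t * sigma**2 - (1.0 - t)) / v       # coefficient of the linear (skip-like) term
    gate = np.tanh(t * (x @ mu) / v)            # posterior mean of the cluster sign
    return skip * x + ((1.0 - t) / v) * gate[:, None] * mu

def euler_flow(x0, velocity, dt=0.01):
    """Forward-Euler integration of dX/dt = velocity(X, t) from t = 0 to t = 1."""
    x = x0.copy()
    for k in range(int(round(1.0 / dt))):
        x = x + dt * velocity(x, k * dt)
    return x

def cosine_similarity(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
d, sigma = 20, 1.0                              # illustrative values, far smaller than in the figures
mu = np.ones(d)                                 # hypothetical target cluster mean
x0 = rng.standard_normal((2000, d))             # samples from the base Gaussian density
x1 = euler_flow(x0, lambda x, t: population_velocity(x, t, mu, sigma))

# Estimate the generated cluster mean by flipping samples onto the +mu side.
signs = np.sign(x1 @ mu)
mu_hat = (signs[:, None] * x1).mean(axis=0)
cos_hat = cosine_similarity(mu_hat, mu)
mse_hat = float(np.sum((mu_hat - mu) ** 2) / d)
print(f"cosine similarity = {cos_hat:.3f}, squared distance / d = {mse_hat:.4f}")
```

Because the stand-in velocity is the population one, the generated clusters should land close to the target clusters. Replacing `population_velocity` with a velocity learnt from $n$ samples is what introduces the finite-sample error that the figures quantify.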

Theorems & Definitions (10)

  • Remark 3.2
  • Corollary 3.3
  • Remark 4.1
  • Remark B.1
  • Corollary B.2
  • Remark B.3
  • Corollary B.4
  • Remark D.3
  • Remark D.4
  • Definition F.1