A Note on Generalization in Variational Autoencoders: How Effective Is Synthetic Data & Overparameterization?

Tim Z. Xiao, Johannes Zenn, Robert Bamler

TL;DR

This work addresses encoder overfitting in variational autoencoders by examining two generalization pathways: leveraging synthetic data from pre-trained diffusion models (DMaaPx) and increasing parameters, particularly near the latent variable. DMaaPx replaces finite training data with unlimited, high-fidelity samples from a diffusion model trained on the dataset, yielding improved generalization, tighter amortization gaps, and greater robustness without altering the standard inference pipeline. The study also shows that expanding latent-adjacent parameters boosts performance, while excessive growth in other parts can degrade it, and provides evidence of double descent under certain parameter-growth trajectories. Collectively, the results offer practical guidance for mitigating encoder overfitting in VAEs and highlight the nuanced effects of model scaling in the presence of synthetic data.
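The three sources of training batches compared in this work (the finite training set, augmented training data, and samples from a pre-trained diffusion model) can be sketched roughly as follows. This is only an illustration, not the authors' code: the function names, the horizontal-flip augmentation, and `toy_sampler` (a placeholder that returns random arrays instead of calling a real diffusion model) are all assumptions made for the example.

```python
import numpy as np

def sample_train(data, batch_size, rng):
    """Normal training: draw a batch uniformly from the finite training set."""
    idx = rng.integers(0, len(data), size=batch_size)
    return data[idx]

def random_hflip(x, rng):
    """Hypothetical augmentation: flip each image left-right with probability 0.5."""
    flip = rng.random(len(x)) < 0.5
    out = x.copy()
    out[flip] = out[flip][:, :, ::-1]
    return out

def sample_aug(data, batch_size, rng, augment=random_hflip):
    """Augmentation: pick x from the training set, then draw x' from p_aug(x' | x)."""
    return augment(sample_train(data, batch_size, rng), rng)

def sample_dm(sampler, batch_size, rng):
    """DMaaPx-style training: every batch is freshly drawn from a generative sampler."""
    return sampler(batch_size, rng)

def toy_sampler(batch_size, rng, shape=(8, 8)):
    """Placeholder for a pre-trained diffusion model (assumption: NOT a real model)."""
    return rng.random((batch_size, *shape))

rng = np.random.default_rng(0)
data = rng.random((20, 8, 8))          # stand-in for a small image dataset
batch_plain = sample_train(data, 4, rng)
batch_aug = sample_aug(data, 4, rng)
batch_dm = sample_dm(toy_sampler, 4, rng)
```

In this sketch the VAE training step itself is unchanged; only the function that supplies the batch differs, which mirrors the claim that the standard inference pipeline is left untouched.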

Abstract

Variational autoencoders (VAEs) are deep probabilistic models that are used in scientific applications. Many works try to mitigate overfitting in VAEs from the probabilistic methods perspective by new inference techniques or training procedures. In this paper, we approach the problem instead from the deep learning perspective by investigating the effectiveness of using synthetic data and overparameterization for improving the generalization performance. Our motivation comes from (1) the recent discussion on whether the increasing amount of publicly accessible synthetic data will improve or hurt currently trained generative models; and (2) the modern deep learning insight that overparameterization improves generalization. Our investigation shows that both training on samples from a pre-trained diffusion model and using more parameters at certain layers effectively mitigate overfitting in VAEs, thereby improving their generalization, amortized inference, and robustness performance. Our study provides timely insights in the current era of synthetic data and scaling laws.

Paper Structure (53 sections, 10 equations, 19 figures, 11 tables, 2 algorithms)

Figures (19)

  • Figure 1: Training distributions. $p_{\rm{aug}}({\bm{x}}')=\mathbb{E}_{{\bm{x}}\sim\mathcal{D}_{\mathrm{train}}}[p_{\rm{aug}}({\bm{x}}'\,|\,{\bm{x}})]$ only extrapolates from individual data points ${\bm{x}}\sim\mathcal{D}_{\mathrm{train}}$ and has density outside the support of $p_{\rm{data}}({\bm{x}})$ (e.g., when flipping the digit "2"). By contrast, $p_{\rm{DM}}({\bm{x}}')$ can interpolate between data ${\bm{x}}\sim\mathcal{D}_{\mathrm{train}}$.
  • Figure 2: Generalization (left), amortized inference (mid), and robustness (right) performance for VAEs trained with Eqs. (\ref{eq:dtrain})-(\ref{eq:paug}). Being slightly better than augmentations, DMaaPx consistently has the best performance on the test set and the smallest generalization, amortization, and robustness gap.
  • Figure 3: Density of $\log q_\phi$ evaluated on a line that linearly interpolates between two data samples from the test set of CIFAR-10. DMaaPx is smoother than normal training.
  • Figure 4: Left: Improvements in classification accuracy over normal training for various classifiers trained on the latent representations of CIFAR-10-C. Each column contains $19\times 3$ points (i.e., 19 corruptions, each VAE trained with 3 random seeds). Right: Generalization performance as a function of the amount $k$ of training data sampled from a diffusion model. Horizontal blue lines show baseline performance (VAE trained directly on $\mathcal{D}_{\mathrm{train}}$). All VAEs were trained for $1000$ effective epochs. $k \approx 10$ suffices.
  • Figure 5: Comparison between DMaaPx and Reverse Half Asleep (RHA) \cite{zhang2022generalization}. The dotted vertical line shows the epoch at which the decoder is frozen for RHA. "Normal Training + RHA" improves upon the baseline but is still worse than DMaaPx. Also, adding RHA on top of DMaaPx does not lead to improvements.
  • ...and 14 more figures
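
The kind of evaluation described in the Figure 3 caption (a log-density traced along the straight line between two test images) might look like the sketch below. It rests on several assumptions that the caption does not confirm: the encoder is taken to be a diagonal Gaussian $q_\phi(z\,|\,x)$, the latent point `z` at which the density is evaluated is held fixed, and `make_toy_encoder` is a hypothetical linear stand-in rather than a trained VAE encoder.

```python
import numpy as np

def gaussian_log_density(z, mean, log_var):
    """Log-density of a diagonal Gaussian, summed over latent dimensions."""
    return -0.5 * np.sum(
        np.log(2.0 * np.pi) + log_var + (z - mean) ** 2 / np.exp(log_var),
        axis=-1,
    )

def log_q_along_line(encode, x_a, x_b, z, num_points=11):
    """Evaluate log q(z | x_t) for x_t = (1 - t) * x_a + t * x_b, t in [0, 1]."""
    ts = np.linspace(0.0, 1.0, num_points)
    xs = (1.0 - ts)[:, None] * x_a[None, :] + ts[:, None] * x_b[None, :]
    mean, log_var = encode(xs)
    return ts, gaussian_log_density(z, mean, log_var)

def make_toy_encoder(input_dim, latent_dim, rng):
    """Hypothetical encoder: linear mean, unit variance (assumption, not a VAE)."""
    weights = rng.normal(size=(input_dim, latent_dim)) / np.sqrt(input_dim)

    def encode(x):
        mean = x @ weights
        log_var = np.zeros_like(mean)
        return mean, log_var

    return encode

rng = np.random.default_rng(0)
encode = make_toy_encoder(input_dim=16, latent_dim=2, rng=rng)
x_a, x_b = rng.random(16), rng.random(16)   # stand-ins for two flattened test images
z_fixed, _ = encode(x_a[None, :])           # evaluate at the posterior mean of x_a
ts, log_q = log_q_along_line(encode, x_a, x_b, z_fixed[0])
```

Plotting `log_q` against `ts` for encoders trained in different ways would give curves of the type the caption compares; a smoother curve means the encoder's output changes more gradually between the two inputs.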