Optimal Stochastic Trace Estimation in Generative Modeling

Xinyang Liu, Hengrong Du, Wei Deng, Ruqi Zhang

TL;DR

The paper tackles the high variance of Hutchinson trace estimation in OT-guided generative modeling and diffusion-based objectives. It introduces Hutch++, a variance-reduced trace estimator that splits the trace into an exact large-eigenvalue component and a stochastic remainder, with an acceleration scheme that reuses the top eigenvectors by amortizing QR factorizations over time. The theoretical analysis establishes unbiasedness, variance bounds, and complexity estimates, showing substantial variance reductions over the vanilla Hutchinson estimator. Empirically, Hutch++ improves training efficiency and generation quality across neural ODE-based models, Schrödinger-bridge diffusion, time-series tasks, and image generation, demonstrating scalable OT guarantees on high-dimensional data. The approach enables faster, more accurate transport maps in diverse generative settings and holds potential for broad applicability in OT-based learning and simulations.

Abstract

Hutchinson estimators are widely employed in training divergence-based likelihoods for diffusion models to ensure optimal transport (OT) properties. However, this estimator often suffers from high variance and scalability concerns. To address these challenges, we investigate Hutch++, an optimal stochastic trace estimator for generative models, designed to minimize training variance while maintaining transport optimality. Hutch++ is particularly effective for handling ill-conditioned matrices with large condition numbers, which commonly arise when high-dimensional data exhibits a low-dimensional structure. To mitigate the need for frequent and costly QR decompositions, we propose practical schemes that balance frequency and accuracy, backed by theoretical guarantees. Our analysis demonstrates that Hutch++ leads to generations of higher quality. Furthermore, this method exhibits effective variance reduction in various applications, including simulations, conditional time series forecasts, and image generation.
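
The split described above (an exactly computed large-eigenvalue part plus a stochastic remainder) can be sketched in a few lines of NumPy. This is a minimal illustration of the generic Hutch++ recipe, not the paper's training code: `matvec` is a hypothetical callable standing in for whatever matrix-vector product is available (in a generative model this would typically be a Jacobian-vector product), and the Rademacher probes and the even three-way split of the budget `m` are assumptions of the sketch.

```python
import numpy as np

def hutchinson(matvec, d, m, rng):
    """Vanilla Hutchinson estimate of tr(A) from m Rademacher probes."""
    G = rng.choice([-1.0, 1.0], size=(d, m))
    # Average of g_i^T A g_i over the m probe columns.
    return np.sum(G * matvec(G)) / m

def hutchpp(matvec, d, m, rng):
    """Hutch++ estimate of tr(A) using about m matrix-vector products."""
    k = m // 3  # one third each for the sketch, the exact part, the remainder
    S = rng.choice([-1.0, 1.0], size=(d, k))
    G = rng.choice([-1.0, 1.0], size=(d, k))
    # Orthonormal basis Q for the range of A S (captures the top directions).
    Q, _ = np.linalg.qr(matvec(S))
    # Exact part: tr(Q^T A Q).
    exact = np.trace(Q.T @ matvec(Q))
    # Stochastic remainder: Hutchinson on A restricted to the complement of Q.
    G_perp = G - Q @ (Q.T @ G)
    remainder = np.sum(G_perp * matvec(G_perp)) / k
    return exact + remainder

# Usage on a hypothetical symmetric matrix with a fast-decaying spectrum.
rng = np.random.default_rng(0)
d = 100
U, _ = np.linalg.qr(rng.standard_normal((d, d)))
A = (U * (1.0 / np.arange(1, d + 1) ** 2)) @ U.T
print(np.trace(A), hutchinson(lambda X: A @ X, d, 30, rng),
      hutchpp(lambda X: A @ X, d, 30, rng))
```

The QR factorization is the step whose cost the abstract proposes to amortize; in this sketch it is recomputed on every call.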

Paper Structure

This paper contains 47 sections, 8 theorems, 47 equations, 11 figures, 3 tables.

Key Result

Lemma 5.1

where $H_m(\mathbf{A})$ is the Hutchinson estimator with $m$ random vectors, $\mathbb{E}[\cdot]$ denotes expectation, and $\mathrm{Var}[\cdot]$ denotes variance.
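
For orientation, the textbook form of these two quantities for i.i.d. standard Gaussian probes $\mathbf{g}_i$ is sketched below. The Gaussian probe distribution is an assumption of this sketch; the paper's Lemma 5.1 may be stated for a different probe distribution or with a different constant.

```latex
H_m(\mathbf{A}) = \frac{1}{m}\sum_{i=1}^{m} \mathbf{g}_i^\top \mathbf{A}\,\mathbf{g}_i,
\qquad
\mathbb{E}[H_m(\mathbf{A})] = \mathrm{tr}(\mathbf{A}),
\qquad
\mathrm{Var}[H_m(\mathbf{A})]
  = \frac{1}{m}\left(\|\mathbf{A}\|_F^2 + \mathrm{tr}(\mathbf{A}^2)\right)
  \le \frac{2}{m}\,\|\mathbf{A}\|_F^2 .
```

The inequality holds with equality when $\mathbf{A}$ is symmetric.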

Figures (11)

  • Figure 1: Visualization of density estimation obtained by (Top) FFJORD and (Bottom) FFJORD++ during the training phase on 2spirals. The proposed variance-reduced model, FFJORD++, improves both convergence and training stability, resulting in higher-quality estimated densities.
  • Figure 2: Comparison between FFJORD++ and FFJORD on 2spirals with various scales. From left to right, each row represents the data with a default scale, as well as shapes stretched 2x and 4x along the X-axis, respectively. As the scale along the X-axis increases, FFJORD++ consistently exhibits superior convergence rates, further widening the performance gap with FFJORD.
  • Figure 3: (Top) Comparison of test loss during training on FFJORD++ trained with different $L_s$. (Left) Time cost for 100 iterations on models trained with different $L_s$.
  • Figure 4: (a) Imputation examples for PM25 for 1 (out of 36) dimensions. (b) Forecasting examples for Electricity for 1 (out of 370) dimensions.
  • Figure 5: Visualization of density estimation obtained by (Top) FFJORD and (Bottom) FFJORD++ during the training phase on four toy distributions in each sub-figure. The proposed variance-reduced model, FFJORD++, improves both convergence and training stability, resulting in higher-quality estimated densities.
  • ...and 6 more figures

Theorems & Definitions (8)

  • Lemma 5.1: Hutchinson89
  • Lemma 5.2: hutch_pp
  • Proposition 5.3: hutch_pp
  • Proposition 5.4
  • Proposition 5.5
  • Lemma 1.1
  • Theorem 1.2: halko2011finding
  • Lemma 1.3