Integrating Random Effects in Variational Autoencoders for Dimensionality Reduction of Correlated Data
Giora Simchoni, Saharon Rosset
TL;DR
VAEs assume IID observations, which limits their performance on correlated datasets. LMMVAE addresses this by splitting the latent space into a fixed part $\mathbf{U}$ and a correlated random part $\mathbf{B}$, with a design matrix $\mathbf{Z}$ that maps the random part into the data through $\mathbf{Z}\mathbf{B}$. The generative model becomes $\mathbf{X} = f(\mathbf{U}) + \mathbf{Z}\mathbf{B} + \mathcal{E}$, and the resulting modified ELBO includes two KL terms. The framework generalizes to high-cardinality categorical data, longitudinal measurements, and spatial locations through appropriate covariance structures (e.g., matrix-normal, $\Phi$, and kernel $\mathbf{K}$) and BLUP-style updates. Across extensive simulations and real datasets, LMMVAE achieves lower reconstruction error and negative log-likelihood (NLL) on unseen data and yields more informative latent representations for downstream tasks, outperforming several state-of-the-art alternatives. This enables scalable, principled handling of structured correlation in large tabular and image datasets, improving representation learning and predictive performance in practical settings.
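To make the roles of $\mathbf{U}$, $\mathbf{Z}$ and $\mathbf{B}$ concrete, the following sketch samples data from the generative model $\mathbf{X} = f(\mathbf{U}) + \mathbf{Z}\mathbf{B} + \mathcal{E}$ for a single high-cardinality categorical feature. It is illustrative only and not the authors' implementation: the decoder `f` (a random linear map followed by `tanh`), the function name `simulate_lmmvae_data`, the i.i.d. Gaussian entries of $\mathbf{B}$ and $\mathcal{E}$, and all sizes are assumptions.

```python
import numpy as np


def simulate_lmmvae_data(n=1000, q=50, p=10, d=2, sigma_b=1.0, sigma_e=0.5, seed=0):
    """Sample X = f(U) + Z B + E for one categorical feature with q levels.

    Illustrative assumptions (not from the paper): f is a fixed random
    nonlinear map, B has i.i.d. N(0, sigma_b^2) entries, and E has
    i.i.d. N(0, sigma_e^2) entries.
    """
    rng = np.random.default_rng(seed)

    # Fixed part: independent latent variables U (n x d), as in a standard VAE
    U = rng.standard_normal((n, d))

    # Hypothetical decoder f: random linear map followed by tanh (n x p)
    W = rng.standard_normal((d, p))
    fU = np.tanh(U @ W)

    # Cluster label of each observation and its one-hot design matrix Z (n x q)
    clusters = rng.integers(0, q, size=n)
    Z = np.zeros((n, q))
    Z[np.arange(n), clusters] = 1.0

    # Random part: one p-dimensional random-effect row per cluster (q x p)
    B = sigma_b * rng.standard_normal((q, p))

    # Observation noise E (n x p)
    E = sigma_e * rng.standard_normal((n, p))

    X = fU + Z @ B + E
    return X, Z, B, clusters


X, Z, B, clusters = simulate_lmmvae_data()
```

Because each row of $\mathbf{Z}$ is one-hot, $\mathbf{Z}\mathbf{B}$ adds the random-effect row of the observation's cluster, so all observations in a cluster share the same shift; this shared shift is exactly the correlation an IID model ignores. With many levels a dense $\mathbf{Z}$ is wasteful, and indexing `B[clusters]` or using a sparse matrix gives the same result.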
Abstract
Variational Autoencoders (VAEs) are widely used for dimensionality reduction of large-scale tabular and image datasets, under the assumption that data observations are independent. In practice, however, datasets are often correlated, with typical sources of correlation including spatial, temporal and clustering structures. Inspired by the literature on linear mixed models (LMM), we propose LMMVAE, a novel model that separates the classic VAE latent model into fixed and random parts. While the fixed part assumes the latent variables are independent as usual, the random part consists of latent variables that are correlated between similar clusters in the data, such as nearby locations or successive measurements. The classic VAE architecture and loss are modified accordingly. LMMVAE is shown to significantly reduce squared reconstruction error and negative log-likelihood loss on unseen data, with simulated as well as real datasets from various applications and correlation scenarios. It also improves the performance of downstream tasks such as supervised classification on the learned representations.
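The modified loss can be sketched as a negative ELBO made of a reconstruction term plus two KL terms. This is a minimal sketch under assumptions that may differ from the paper: diagonal Gaussian posteriors for both $\mathbf{U}$ and $\mathbf{B}$, a standard normal prior on $\mathbf{U}$, an $N(0, \sigma_b^2)$ prior on each entry of $\mathbf{B}$ (one row per cluster), and a Gaussian likelihood with known noise variance. The names `kl_diag_gaussians` and `lmmvae_style_loss` are hypothetical.

```python
import numpy as np


def kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """Closed-form KL(N(mu_q, diag(exp(logvar_q))) || N(mu_p, diag(exp(logvar_p)))),
    summed over all entries."""
    var_q = np.exp(logvar_q)
    var_p = np.exp(logvar_p)
    return 0.5 * np.sum(
        logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0
    )


def lmmvae_style_loss(X, X_hat, mu_u, logvar_u, mu_b, logvar_b,
                      sigma_b2=1.0, sigma_e2=1.0):
    """Hypothetical negative ELBO: Gaussian reconstruction + KL for U + KL for B.

    X_hat is assumed to be the decoded f(U) + Z B for sampled U and B.
    mu_u/logvar_u are per-observation (n x d); mu_b/logvar_b are per-cluster (q x p).
    """
    # Gaussian negative log-likelihood up to an additive constant
    recon = 0.5 * np.sum((X - X_hat) ** 2) / sigma_e2

    # KL term for the fixed latents U against a standard normal prior
    kl_u = kl_diag_gaussians(mu_u, logvar_u,
                             np.zeros_like(mu_u), np.zeros_like(logvar_u))

    # KL term for the random effects B against an N(0, sigma_b2) prior per entry
    kl_b = kl_diag_gaussians(mu_b, logvar_b,
                             np.zeros_like(mu_b),
                             np.full_like(logvar_b, np.log(sigma_b2)))

    return recon + kl_u + kl_b
```

Both KL terms between diagonal Gaussians have closed forms, so only the reconstruction term depends on sampled latents; a real model would compute `X_hat` from a decoder and backpropagate through this loss with an autodiff framework rather than NumPy.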
