Energy-Based Modelling for Discrete and Mixed Data via Heat Equations on Structured Spaces
Tobias Schröder, Zijing Ou, Yingzhen Li, Andrew B. Duncan
TL;DR
This work addresses the challenge of training energy-based models on discrete and mixed-state data by introducing Energy Discrepancy (ED), a contrastive loss that requires only energy evaluations on data points and their perturbed counterparts, thereby removing the need for MCMC. ED relies on discrete diffusion on structured spaces, implemented as a heat equation on graphs with rate matrices $R$ (e.g., uniform, cyclical, ordinal, and absorbing structures), and provides an MCMC-free training objective that approaches maximum-likelihood estimation as the diffusion time $t$ grows. The authors extend ED to tabular data by combining geometric perturbations for discrete features with Gaussian noise for numeric features, and demonstrate strong performance across discrete density estimation, tabular data synthesis (including calibration tasks), and discrete image modelling, often at lower computational cost than CD-based methods. The approach yields robust generation and improved calibration on real-world datasets, suggesting practical impact for synthetic data generation, data imputation, and calibrated classification in tabular and mixed data domains. The work also offers theoretical guarantees linking ED to maximum-likelihood estimation in the limit, and provides scalable, parallelizable procedures via eigendecompositions of structure-specific rate matrices. Central to the methodology are the model $$p_{\theta}(x) \propto \exp(-U_{\theta}(x))$$ and the ED loss $$\mathrm{ED}_q(p_{\mathrm{data}},U) = \mathbb{E}_{p_{\mathrm{data}}(x)}[U(x)] - \mathbb{E}_{p_{\mathrm{data}}(x)}\mathbb{E}_{q(y|x)}[U_q(y)],$$ where $U_q(y) = -\log \sum_{x'} q(y|x') e^{-U(x')}$.
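To make the ED loss concrete, the following minimal sketch evaluates it exactly on a four-state space where all sums can be computed in closed form. It is an illustration under stated assumptions, not the authors' implementation: the function names `uniform_heat_kernel` and `energy_discrepancy`, the toy distribution `p_data`, and the normalisation of the uniform rate matrix as $R = \frac{1}{S}\mathbf{1}\mathbf{1}^\top - I$ are all chosen here for illustration.

```python
import numpy as np

def uniform_heat_kernel(num_states, t):
    """Transition matrix exp(tR) for the uniform rate matrix.

    Assumes the normalisation R = (1/S) * ones - I, for which
    exp(tR) = e^{-t} I + (1 - e^{-t}) / S * ones in closed form.
    """
    S = num_states
    return np.exp(-t) * np.eye(S) + (1.0 - np.exp(-t)) * np.full((S, S), 1.0 / S)

def energy_discrepancy(U, p_data, Q):
    """Exact ED on a finite state space.

    U:      (S,) energies U(x)
    p_data: (S,) data probabilities
    Q:      (S, S) perturbation kernel with Q[x, y] = q(y | x)
    """
    # U_q(y) = -log sum_{x'} q(y | x') exp(-U(x'))
    U_q = -np.log(Q.T @ np.exp(-U))
    # E_{p_data}[U(x)] - E_{p_data} E_{q(y|x)}[U_q(y)]
    return float(p_data @ U - p_data @ (Q @ U_q))

# Hypothetical toy problem: four states with a non-uniform data distribution.
p_data = np.array([0.1, 0.2, 0.3, 0.4])
Q = uniform_heat_kernel(4, t=1.0)

U_star = -np.log(p_data)   # energy of the data distribution itself
U_other = np.zeros(4)      # energy of a uniform model

ed_star = energy_discrepancy(U_star, p_data, Q)
ed_other = energy_discrepancy(U_other, p_data, Q)
print(ed_star, ed_other)
```

In this toy setting the loss is lower at `U_star` than at `U_other`, and adding a constant to the energy leaves the loss unchanged, which is consistent with the model being defined only up to normalisation. In practice the expectations would be estimated from samples rather than summed exactly.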
Abstract
Energy-based models (EBMs) offer a flexible framework for probabilistic modelling across various data domains. However, training EBMs on data in discrete or mixed state spaces poses significant challenges due to the lack of robust and fast sampling methods. In this work, we propose to train discrete EBMs with Energy Discrepancy, a loss function which only requires the evaluation of the energy function at data points and their perturbed counterparts, thus eliminating the need for Markov chain Monte Carlo. We introduce perturbations of the data distribution by simulating a diffusion process on the discrete state space endowed with a graph structure. This allows us to inform the choice of perturbation from the structure of the modelled discrete variable, while the continuous time parameter enables fine-grained control of the perturbation. Empirically, we demonstrate the efficacy of the proposed approaches in a wide range of applications, including the estimation of discrete densities with non-binary vocabulary and binary image modelling. Finally, we train EBMs on tabular data sets with applications in synthetic data generation and calibrated classification.
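The perturbation by diffusion on a graph-structured state space can be sketched as follows. This is a hedged illustration rather than the paper's code: it assumes a cyclical structure whose rate matrix is the negative graph Laplacian of a cycle, computes the heat kernel $\exp(tR)$ through an eigendecomposition, and samples perturbed states from its rows. The helper names `cyclic_rate_matrix`, `heat_kernel`, and `perturb` are hypothetical.

```python
import numpy as np

def cyclic_rate_matrix(num_states):
    """Rate matrix of a random walk on a cycle (assumed: negative graph Laplacian)."""
    S = num_states
    R = np.zeros((S, S))
    for i in range(S):
        R[i, (i + 1) % S] += 1.0   # jump to the right neighbour
        R[i, (i - 1) % S] += 1.0   # jump to the left neighbour
        R[i, i] -= 2.0             # rows of a rate matrix sum to zero
    return R

def heat_kernel(R, t):
    """Solve the heat equation dQ/dt = Q R, Q(0) = I, for a symmetric R."""
    eigvals, eigvecs = np.linalg.eigh(R)
    # exp(tR) = V diag(exp(t * lambda)) V^T
    return (eigvecs * np.exp(t * eigvals)) @ eigvecs.T

def perturb(x, Q, rng):
    """Draw one perturbed state y ~ q_t(. | x) for each entry of the integer array x."""
    probs = np.clip(Q[x], 0.0, None)              # guard against tiny negative round-off
    probs /= probs.sum(axis=1, keepdims=True)
    return np.array([rng.choice(len(p), p=p) for p in probs])

rng = np.random.default_rng(0)
R = cyclic_rate_matrix(5)
Q_small = heat_kernel(R, t=0.1)    # mass stays near the original state
Q_large = heat_kernel(R, t=50.0)   # close to the uniform distribution
x = np.array([0, 2, 4, 4])
y = perturb(x, Q_small, rng)
```

The time parameter `t` controls the strength of the perturbation continuously: small values keep most mass on the original state and its graph neighbours, while large values approach the uniform distribution on the cycle.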
