
MIST: Mutual Information Via Supervised Training

German Gritsai, Megan Richards, Maxime Méloux, Kyunghyun Cho, Maxime Peyrard

TL;DR

MIST reframes mutual information estimation as a supervised, data-driven task: it learns a neural predictor that maps sets of paired samples to $I(X;Y)$ by training on a large meta-dataset of distributions with known MI. The method uses a SetTransformer++-based architecture with 2D attention to handle variable sample sizes and dimensions, and incorporates quantile regression to provide calibrated uncertainty estimates without resorting to costly bootstrapping. Empirically, MIST and its quantile-augmented variant substantially outperform classical estimators in challenging regimes (small sample sizes, high dimensions, diverse distributions), while offering orders-of-magnitude faster inference and seamless integration into larger pipelines. Because MI is invariant to invertible transformations, the framework also supports modality-agnostic training via normalizing flows, allowing adaptation to arbitrary data domains. An open-source library for training and evaluating meta-learned MI estimators accompanies the work.
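
Quantile regression is usually trained with the pinball loss, whose minimizer over constants is an empirical quantile of the target. The sketch below shows this loss for a scalar MI target in plain Python. It assumes the model emits one prediction per quantile level, and the helper names (`pinball_loss`, `multi_quantile_loss`, `best_constant_quantile`) are hypothetical; it illustrates the technique rather than reproducing the MIST library's code.

```python
def pinball_loss(tau: float, y_true: float, y_pred: float) -> float:
    """Pinball (quantile) loss for one quantile level tau in (0, 1)."""
    diff = y_true - y_pred
    # Under-prediction (diff > 0) costs tau per unit; over-prediction costs (1 - tau).
    return max(tau * diff, (tau - 1.0) * diff)


def multi_quantile_loss(taus: list[float], y_true: float, preds: list[float]) -> float:
    """Average pinball loss when a model emits one prediction per quantile level."""
    if len(taus) != len(preds):
        raise ValueError("expected one prediction per quantile level")
    return sum(pinball_loss(t, y_true, p) for t, p in zip(taus, preds)) / len(taus)


def best_constant_quantile(samples: list[float], tau: float) -> float:
    """Among the sample values, return the constant with the lowest total pinball loss.

    The minimizer is an empirical tau-quantile, which is why training on this loss
    makes each output head track a quantile of the target distribution.
    """
    return min(samples, key=lambda c: sum(pinball_loss(tau, y, c) for y in samples))


# Example: three hypothetical quantile heads (5%, 50%, 95%) for a true MI of 0.8 nats.
taus = [0.05, 0.5, 0.95]
print(multi_quantile_loss(taus, 0.8, [0.6, 0.8, 1.1]))
```

Reading off the 5% and 95% heads gives a 90% interval directly from a single forward pass, which is the property that lets a quantile-trained estimator skip bootstrap resampling.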

Abstract

We propose a fully data-driven approach to designing mutual information (MI) estimators. Since any MI estimator is a function of the observed sample from two random variables, we parameterize this function with a neural network (MIST) and train it end-to-end to predict MI values. Training is performed on a large meta-dataset of 625,000 synthetic joint distributions with known ground-truth MI. To handle variable sample sizes and dimensions, we employ a two-dimensional attention scheme ensuring permutation invariance across input samples. To quantify uncertainty, we optimize a quantile regression loss, enabling the estimator to approximate the sampling distribution of MI rather than return a single point estimate. This research program departs from prior work by taking a fully empirical route, trading universal theoretical guarantees for flexibility and efficiency. Empirically, the learned estimators largely outperform classical baselines across sample sizes and dimensions, including on joint distributions unseen during training. The resulting quantile-based intervals are well-calibrated and more reliable than bootstrap-based confidence intervals, while inference is orders of magnitude faster than existing neural baselines. Beyond immediate empirical gains, this framework yields trainable, fully differentiable estimators that can be embedded into larger learning pipelines. Moreover, by exploiting MI's invariance to invertible transformations, meta-datasets can be adapted to arbitrary data modalities via normalizing flows, enabling flexible training for diverse target meta-distributions.
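
To make the idea of a meta-dataset with known ground-truth MI concrete, the sketch below builds one meta-datapoint from a family whose MI has a closed form: `dim` independent pairs of standard normals with correlation `rho`, for which $I(X;Y) = -\tfrac{d}{2}\ln(1-\rho^2)$ nats. The optional `transform` argument applies the same invertible, strictly monotone map to every coordinate of both variables, which leaves the MI label unchanged. The function names and the Gaussian family are illustrative assumptions, not a description of how the paper's 625,000 distributions were generated.

```python
import math
import random
from typing import Callable, Optional


def gaussian_mi(rho: float, dim: int = 1) -> float:
    """Ground-truth MI in nats for `dim` independent bivariate-normal pairs with correlation rho."""
    if not -1.0 < rho < 1.0:
        raise ValueError("rho must lie strictly between -1 and 1")
    return -0.5 * dim * math.log(1.0 - rho * rho)


def sample_meta_datapoint(
    n_rows: int,
    dim: int,
    rho: float,
    transform: Optional[Callable[[float], float]] = None,
    seed: Optional[int] = None,
) -> tuple[list[list[float]], list[list[float]], float]:
    """Draw n_rows paired samples (x, y) with x, y in R^dim, plus their true MI label."""
    rng = random.Random(seed)
    noise_scale = math.sqrt(1.0 - rho * rho)
    xs, ys = [], []
    for _ in range(n_rows):
        x_row = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        # y_i = rho * x_i + sqrt(1 - rho^2) * eps_i has unit variance and corr(x_i, y_i) = rho.
        y_row = [rho * x + noise_scale * rng.gauss(0.0, 1.0) for x in x_row]
        if transform is not None:
            # An invertible per-coordinate map (e.g. v -> v**3) leaves I(X;Y) unchanged,
            # so the label computed below remains the ground truth.
            x_row = [transform(v) for v in x_row]
            y_row = [transform(v) for v in y_row]
        xs.append(x_row)
        ys.append(y_row)
    return xs, ys, gaussian_mi(rho, dim)


# Example: 200 rows in 3 dimensions, with a cubic warp applied to both variables.
xs, ys, mi = sample_meta_datapoint(200, 3, 0.6, transform=lambda v: v ** 3, seed=0)
print(len(xs), len(xs[0]), round(mi, 4))
```

Varying `n_rows` and `dim` across many such draws yields the kind of variable-size inputs a set-based estimator has to handle, each paired with an exact regression target.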


Paper Structure

This paper contains 34 sections, 10 equations, 20 figures, 8 tables.

Figures (20)

  • Figure 1: We propose a fully data-driven, empirical approach to designing mutual information estimators. We design a large empirical meta-dataset composed of samples from a set of distributions with known MI (left), and train a SetTransformer-based model to predict MI directly from sets of samples (right).
  • Figure 2: Heatmaps providing a detailed analysis of the performance of the two proposed MI estimators and the KSG baseline. The data dimensions are grouped into three categories, and within each group, meta-datapoints are sorted by the number of samples. Shown are the average values of MSE, bias, variance, confidence interval (CI) width, and CI coverage (CI cov.) across $\mathcal{M}_{\text{test-extended}}$. CI coverage denotes the fraction of meta-datapoints per group for which the true MI lies within the 95% bootstrap confidence interval.
  • Figure 3: Predicted MI as a function of the true MI on $\mathcal{M}_{\text{test-extended}}$. For each estimator, predictions are aggregated into bins of true MI values.
  • Figure 4: The MIST and MIST$_{\text{QR}}$ models scale substantially better to higher dimensions, requiring roughly half as many samples as KSG to achieve an MSE below the selected thresholds. The "+" marker indicates that over 500 samples ($n_{\text{row}}$) are required to obtain accurate estimates in all higher-dimensional settings.
  • Figure 5: Calibration of the proposed models on $\mathcal{M}_{\text{test,IMD}}$ (left) and $\mathcal{M}_{\text{test,OoMD}}$ (right), computed directly from quantiles for MIST$_{\text{QR}}$ and through bootstrap resampling for MIST and KSG. We report the mean absolute error in the legend.
  • ...and 15 more figures
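
Figures 2 and 5 evaluate interval estimates by coverage (how often the true MI falls inside an interval) and by calibration error. The sketch below computes both in plain Python, together with a simple percentile bootstrap interval, one common way to build bootstrap confidence intervals like those the captions mention for MIST and KSG. The helper names, the nearest-rank percentile rule, and the definition of calibration error as the mean absolute gap between nominal level and empirical coverage are assumptions for illustration; the paper's exact protocol may differ.

```python
import random
from statistics import mean
from typing import Callable, Optional, Sequence


def percentile_bootstrap_ci(
    rows: Sequence,
    estimator: Callable[[list], float],
    level: float = 0.95,
    n_boot: int = 200,
    seed: Optional[int] = None,
) -> tuple[float, float]:
    """Percentile bootstrap interval: re-run `estimator` on rows resampled with replacement."""
    rng = random.Random(seed)
    n = len(rows)
    stats = sorted(
        estimator([rows[rng.randrange(n)] for _ in range(n)]) for _ in range(n_boot)
    )
    alpha = (1.0 - level) / 2.0
    # Nearest-rank style lookup of the alpha and (1 - alpha) percentiles.
    lo = stats[int(alpha * (n_boot - 1))]
    hi = stats[int((1.0 - alpha) * (n_boot - 1))]
    return lo, hi


def coverage(intervals: Sequence[tuple[float, float]], truths: Sequence[float]) -> float:
    """Fraction of cases whose true value lies inside the corresponding interval."""
    hits = sum(lo <= t <= hi for (lo, hi), t in zip(intervals, truths))
    return hits / len(truths)


def calibration_mae(intervals_by_level: dict, truths: Sequence[float]) -> float:
    """Mean |empirical coverage - nominal level| over the nominal levels provided."""
    return mean(abs(coverage(ivs, truths) - level) for level, ivs in intervals_by_level.items())


# Example: bootstrap interval for the mean of a toy sample (stand-in for an MI estimator).
data = [float(v) for v in range(100)]
print(percentile_bootstrap_ci(data, mean, seed=0))
```

In this setup, a bootstrap interval costs `n_boot` extra estimator calls per dataset, whereas a quantile-regression model produces its interval endpoints in the same forward pass as the point estimate.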