Generating Samples to Probe Trained Models

Eren Mehmet Kıral, Nurşen Aydın, Ş. İlker Birbil

TL;DR

This work proposes a probabilistic framework for interrogating trained models by generating data samples that answer specified probing questions. The authors formulate a data-space objective G and pair it with a parameter-space objective F. They then sample from the Gibbs-type distributions defined by these objectives using the Metropolis-adjusted Langevin algorithm (MALA), optionally working in the latent space of a variational autoencoder (VAE). The resulting samples reveal prediction-risky regions, parameter-sensitive inputs, and contrasts between models. The approach is demonstrated across tasks and data modalities (tabular and image). Experiments expose regions where models disagree, uncertain predictions near decision boundaries, and counterfactuals that stay on the data manifold, and the method is compared with existing counterfactual methods such as DiCE. The framework offers a flexible, interpretable way to analyze model behavior beyond accuracy metrics, with practical relevance to fairness, robustness, and explainability. The authors release code and outline future directions, including domain-specific constraints and richer priors.
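The sampling step can be illustrated with a small NumPy sketch of MALA targeting a Gibbs distribution p(z) ∝ exp(-G(z)). The sketch first checks the sampler on a quadratic G whose Gibbs distribution is a known Gaussian. It then runs the same sampler in a 2-D latent space through a decoder. The linear decoder `A`, the logistic weights `W`, the standard-normal latent prior and all step sizes are hypothetical stand-ins chosen for illustration. A real VAE decoder would be nonlinear, with gradients from automatic differentiation. This is not the authors' released code.

```python
import numpy as np

def mala(G, grad_G, z0, step, n_steps, rng):
    """Metropolis-adjusted Langevin algorithm targeting p(z) ∝ exp(-G(z))."""
    z = np.asarray(z0, dtype=float)
    chain = np.empty((n_steps, z.size))
    n_accept = 0
    for t in range(n_steps):
        # Langevin proposal: gradient step on G plus Gaussian noise of variance 2*step
        mean_fwd = z - step * grad_G(z)
        prop = mean_fwd + np.sqrt(2 * step) * rng.standard_normal(z.size)
        mean_bwd = prop - step * grad_G(prop)
        # Metropolis-Hastings correction for the asymmetric Gaussian proposal
        log_q_fwd = -np.sum((prop - mean_fwd) ** 2) / (4 * step)
        log_q_bwd = -np.sum((z - mean_bwd) ** 2) / (4 * step)
        log_alpha = -G(prop) + G(z) + log_q_bwd - log_q_fwd
        if np.log(rng.uniform()) < log_alpha:
            z, n_accept = prop, n_accept + 1
        chain[t] = z
    return chain, n_accept / n_steps

rng = np.random.default_rng(0)

# Sanity check: G(z) = |z - c|^2 gives the Gibbs distribution N(c, 0.5 I)
center = np.array([1.0, -2.0])
gauss_chain, gauss_acc = mala(
    lambda z: np.sum((z - center) ** 2),
    lambda z: 2 * (z - center),
    np.zeros(2), step=0.1, n_steps=20000, rng=rng)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Hypothetical linear decoder phi(z) = A z from 2-D latents to 3-D data
A = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
# Two hypothetical logistic classifiers, both asked to prefer label 1
W = np.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5]])

def G_latent(z):
    # |z|^2 / 2 (standard-normal latent prior) + summed cross-entropy for label 1
    logits = W @ (A @ z)
    return 0.5 * z @ z + np.sum(np.logaddexp(0.0, -logits))

def grad_G_latent(z):
    logits = W @ (A @ z)
    # d/dx of -log sigmoid(w.x) is (sigmoid(w.x) - 1) w; chain rule through A
    grad_x = W.T @ (sigmoid(logits) - 1.0)
    return z + A.T @ grad_x

latent_chain, latent_acc = mala(G_latent, grad_G_latent, np.zeros(2),
                                step=0.1, n_steps=20000, rng=rng)
walk = latent_chain @ A.T  # decoded samples phi(z_t): a walk in data space
```

Swapping the hypothetical `A` for a trained VAE decoder and the logistic models for trained networks would roughly correspond to the latent-space walk described in Figure 4. In that setting, the hand-written gradients would be replaced by automatic differentiation.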

Abstract

There is a growing need to investigate how machine learning models operate. With this work, we aim to understand trained machine learning models by questioning their data preferences. We propose a mathematical framework that allows us to probe trained models and identify their preferred samples in various scenarios, including prediction-risky, parameter-sensitive, or model-contrastive samples. To showcase our framework, we pose these queries to a range of models trained on various classification and regression tasks, and receive answers in the form of generated data.
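The three query types can each be phrased as a data-space objective G, where lower G marks a more preferred sample. The sketch below writes one hypothetical G per query type for binary logistic models. It averages over parameter draws θ ~ q*(θ) by Monte Carlo. The Gaussian stand-in for q*(θ), the weights, and the specific forms are assumptions for illustration, not the paper's definitions:

- negative predictive entropy for prediction-risky samples;
- negative predictive variance for parameter-sensitive samples;
- negative probability gap for model-contrastive samples.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Hypothetical Gaussian stand-in for q*(theta) over logistic-regression weights
m = np.array([2.0, -1.0])
S = 0.25 * np.eye(2)
thetas = rng.multivariate_normal(m, S, size=2000)  # fixed Monte Carlo draws

def G_risky(x):
    # Negative entropy of the parameter-averaged prediction:
    # lowest where the averaged prediction is 0.5
    p = np.clip(sigmoid(thetas @ x).mean(), 1e-12, 1 - 1e-12)
    return p * np.log(p) + (1 - p) * np.log(1 - p)

def G_sensitive(x):
    # Negative variance of the prediction across parameter draws
    return -sigmoid(thetas @ x).var()

# Two hypothetical models to contrast
w_a, w_b = np.array([1.0, 0.0]), np.array([0.0, 1.0])

def G_contrast(x):
    # Negative gap between the two models' predicted probabilities
    return -abs(sigmoid(w_a @ x) - sigmoid(w_b @ x))

print(G_risky(np.zeros(2)), G_sensitive(np.array([1.0, 1.0])),
      G_contrast(np.array([3.0, -3.0])))
```

The parameter draws are made once, so each G stays deterministic across sampler steps. Any of these objectives could then be passed to a sampler such as MALA, given a gradient (here it would have to come from automatic differentiation or finite differences).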

Paper Structure

This paper contains 16 sections, 30 equations, 16 figures, and 2 tables.

Figures (16)

  • Figure 1: (Left) Overview of model probing by data generation. Samples from $p^*(\boldsymbol{x})$ answer the question posed by $G$. The vertical arrows (a) and (c) start from functions and lead to distributions on the same spaces as those functions, obtained by solving the paper's equations for $F$ and $G$ (labeled eq:BLPforF and eq:BLPforG). The diagonal arrow (b) starts from a distribution on the parameter space and yields a loss function on the data space by integrating out the $\boldsymbol{\theta}$-dependence of a function on $\Theta \times \mathcal{X}$ against $q^*(\boldsymbol{\theta})$. (Right) The special case of Linear Regression (LR) with mean squared error admits an analytic solution. Here $G$ is designed to find data points $\boldsymbol{x}$ whose LR predictions, averaged over $q^*$, are close to a chosen target $y'$. The resulting $p^*(\boldsymbol{x})$ is a Gaussian centered at $\hat{\boldsymbol{f}}$, which is shifted from the mean of the given data by an amount that depends on the desired output $y'$. Explicit forms of $\hat{\boldsymbol{f}}$, $\Sigma$ and $\widehat{\boldsymbol{\theta}}$, together with their derivation, are given in the appendix on linear regression (app:LR).
  • Figure 2: Given a dataset of two concentric circles labeled red and blue, two Support Vector Machine (SVM) models are trained on the binary classification task, one with a Radial Basis Function (RBF) kernel and one with a cubic polynomial kernel. Generated data points are shown in green. In (a), we contrast the two SVM models by looking for samples on which their predictions differ, and find such samples in a region near the origin that contains no training points. In (b) and (c), we query data points that the RBF and cubic-kernel models, respectively, consider risky. In (d), we design $G$ to generate data points that the RBF-SVM classifies with the opposite label of the orange point while staying close to it.
  • Figure 3: The distributions of three representative features in the generated samples. Here, XGBoost predicts "Bad" for RiskPerformance, while logistic regression predicts "Good".
  • Figure 4: (left) Running Langevin dynamics in the latent space yields a sequence of latent vectors that, when passed through the decoder $\varphi$, correspond to a walk on the data manifold. Here $G$ is the sum of the cross-entropy losses of trained MLP and LeNet5 networks for the label '8', evaluated on $\varphi(z)$. (right) The images in the first and second columns are generated to prefer one label under an MLP model and another under a CNN model: upper-left: CNN '0', MLP '1'; upper-middle: CNN '1', MLP '7'; lower-left: CNN '0', MLP '8'; lower-middle: CNN '2', MLP '5'. In the third column, the upper image prefers the label '8' under the MLP model while staying close to a data sample with label '3'; the lower image does the same for the CNN model.
  • Figure 5: Feature distributions in generated parameter-sensitive and prediction-risky samples.
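Figure 1 (right) states that, for linear regression with squared error, $p^*(\boldsymbol{x})$ is Gaussian. The paper's explicit $\hat{\boldsymbol{f}}$, $\Sigma$ and $\widehat{\boldsymbol{\theta}}$ are in its appendix and are not reproduced here. The sketch below only shows why a Gaussian appears, under assumptions of our own:

- a Gaussian data prior $\mathcal{N}(\mu, \Sigma_0)$;
- $q^*(\boldsymbol{\theta}) = \mathcal{N}(m, S)$;
- $p^*(\boldsymbol{x}) \propto \mathcal{N}(\boldsymbol{x}; \mu, \Sigma_0)\exp(-\lambda G(\boldsymbol{x}))$, with $G(\boldsymbol{x}) = \mathbb{E}_{\boldsymbol{\theta}\sim q^*}[(\boldsymbol{\theta}^\top\boldsymbol{x} - y')^2] = (m^\top\boldsymbol{x} - y')^2 + \boldsymbol{x}^\top S \boldsymbol{x}$.

Because $G$ is quadratic in $\boldsymbol{x}$, completing the square gives precision $\Sigma_0^{-1} + 2\lambda(mm^\top + S)$ and mean $\Sigma(\Sigma_0^{-1}\mu + 2\lambda y' m)$. The function name `lr_gibbs_gaussian` and the weight $\lambda$ are hypothetical.

```python
import numpy as np

def lr_gibbs_gaussian(mu, Sigma0, m, S, y_target, lam):
    """Mean f_hat and covariance Sigma of p*(x) ∝ N(x; mu, Sigma0) exp(-lam G(x)),
    where G(x) = E_{theta ~ N(m, S)}[(theta^T x - y_target)^2] (assumed setup)."""
    Sigma0_inv = np.linalg.inv(Sigma0)
    # G(x) = x^T (m m^T + S) x - 2 y' m^T x + const, so log p*(x) is quadratic in x
    precision = Sigma0_inv + 2.0 * lam * (np.outer(m, m) + S)
    Sigma = np.linalg.inv(precision)
    f_hat = Sigma @ (Sigma0_inv @ mu + 2.0 * lam * y_target * m)
    return f_hat, Sigma

# One-dimensional example: data prior N(0, 1), theta known exactly (S = 0)
f_hat, Sigma = lr_gibbs_gaussian(np.array([0.0]), np.array([[1.0]]),
                                 np.array([1.0]), np.array([[0.0]]),
                                 y_target=2.0, lam=0.5)
# Requesting y' = 0 instead leaves the mean at the prior mean
f_hat_zero, _ = lr_gibbs_gaussian(np.array([0.0]), np.array([[1.0]]),
                                  np.array([1.0]), np.array([[0.0]]),
                                  y_target=0.0, lam=0.5)
print(f_hat, Sigma, f_hat_zero)
```

In this example the mean moves from 0 to 1 when $y' = 2$ is requested. In this assumed setup the covariance does not depend on $y'$, which fits the caption's description of a Gaussian shifted by an amount that depends on $y'$.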