
Explaining Datasets in Words: Statistical Models with Natural Language Parameters

Ruiqi Zhong, Heng Wang, Dan Klein, Jacob Steinhardt

TL;DR

The paper introduces predicate-conditioned distributions $p(x|\vec{\phi},\mathbf{w}) \propto e^{\mathbf{w}^\top \llbracket \vec{\phi} \rrbracket(x)}$, where natural language predicates $\phi$ provide interpretable features via their denotation $\llbracket \phi \rrbracket(x) \in \{0,1\}$. It then develops a model-agnostic learning pipeline with a continuous relaxation $\tilde{\phi}$ and a discretization step via prompting language models, iterating to refine the predicates and weights. Three exemplar models (clustering, time series, and multiclass classification) are instantiated and evaluated on multiple text datasets, showing that relaxation and refinement improve performance and can match specialized explainable clustering methods. The framework also demonstrates broad open-ended applications in text and vision, enabling explanations for subareas, temporal dynamics, and cross-model comparisons, albeit with acknowledged computational costs and dependency on LLMs. Overall, this work provides a flexible, interpretable, language-grounded approach to analyzing complex datasets and extracting human-understandable patterns.
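To make the predicate-conditioned distribution concrete, the sketch below evaluates $p(x|\vec{\phi},\mathbf{w})$ over a small, fixed list of texts. It is illustrative only: the paper obtains $\llbracket \phi \rrbracket(x)$ by prompting a language model, whereas here the hypothetical functions `discusses_covid` and `asks_for_code` stand in for LM-evaluated denotations via keyword checks, and normalizing over the listed samples (rather than the whole space of texts) is an assumption made to keep the example finite.

```python
import math
from typing import Callable, List

# A predicate maps a text x to its denotation [[phi]](x) in {0, 1}.
Predicate = Callable[[str], int]

# Hypothetical keyword-based stand-ins; the paper evaluates denotations
# by prompting a language model instead.
def discusses_covid(x: str) -> int:
    return int("covid" in x.lower())

def asks_for_code(x: str) -> int:
    return int("code" in x.lower())

def predicate_distribution(
    samples: List[str], predicates: List[Predicate], w: List[float]
) -> List[float]:
    """p(x | phi, w) proportional to exp(w^T [[phi]](x)), normalized over `samples` only."""
    scores = [sum(w_k * phi_k(x) for w_k, phi_k in zip(w, predicates)) for x in samples]
    m = max(scores)  # subtract the max before exponentiating for numerical stability
    unnorm = [math.exp(s - m) for s in scores]
    z = sum(unnorm)
    return [u / z for u in unnorm]

samples = [
    "How does COVID spread indoors?",
    "Write code to sort a list.",
    "Tell me a joke.",
]
probs = predicate_distribution(samples, [discusses_covid, asks_for_code], [2.0, 1.0])
for p, x in zip(probs, samples):
    print(f"{p:.3f}  {x}")
```

Because the exponent is linear in the denotations, a sample satisfying $\phi_k$ is $e^{w_k}$ times as likely as an otherwise identical sample that does not, which is what lets each weight be read directly alongside its predicate.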

Abstract

To make sense of massive data, we often fit simplified models and then interpret the parameters; for example, we cluster the text embeddings and then interpret the mean parameters of each cluster. However, these parameters are often high-dimensional and hard to interpret. To make model parameters directly interpretable, we introduce a family of statistical models -- including clustering, time series, and classification models -- parameterized by natural language predicates. For example, a cluster of text about COVID could be parameterized by the predicate "discusses COVID". To learn these statistical models effectively, we develop a model-agnostic algorithm that optimizes continuous relaxations of predicate parameters with gradient descent and discretizes them by prompting language models (LMs). Finally, we apply our framework to a wide range of problems: taxonomizing user chat dialogues, characterizing how they evolve across time, finding categories where one language model is better than the other, clustering math problems based on subareas, and explaining visual features in memorable images. Our framework is highly versatile, applicable to both textual and visual domains, can be easily steered to focus on specific properties (e.g. subareas), and explains sophisticated concepts that classical methods (e.g. n-gram analysis) struggle to produce.
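The learning step described in the abstract, optimizing a continuous relaxation of the predicates by gradient descent, can be illustrated with a deliberately small sketch. It assumes a simplification that is not the paper's full pipeline: a binary classifier with a single relaxed predicate, whose relaxed denotation is the cosine between a unit-normalized embedding $e_x$ and a learnable unit vector $\tilde{\phi}$ (the $\cos(e_x, \tilde{\phi}_k)$ scores mentioned in the Figure 2 caption). The function name `fit_relaxed_predicate`, the toy 2-D embeddings, and the projected-gradient update are illustrative choices, and the LM-based discretization step is not shown.

```python
import math
from typing import List, Tuple

Vec = List[float]

def normalize(v: Vec) -> Vec:
    n = math.sqrt(sum(a * a for a in v))
    return [a / n for a in v]

def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))

def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def fit_relaxed_predicate(
    embeddings: List[Vec], labels: List[int], steps: int = 500, lr: float = 0.5
) -> Tuple[Vec, float, float]:
    """Projected gradient ascent on the mean log-likelihood of
    p(y=1 | x) = sigmoid(w * cos(e_x, phi_tilde) + b)."""
    es = [normalize(e) for e in embeddings]  # unit-normalized embeddings e_x
    dim, n = len(es[0]), len(es)
    phi = normalize([1.0] * dim)  # relaxed predicate phi_tilde, kept on the unit sphere
    w, b = 1.0, 0.0
    for _ in range(steps):
        g_phi, g_w, g_b = [0.0] * dim, 0.0, 0.0
        for e, y in zip(es, labels):
            s = dot(e, phi)  # relaxed denotation: a cosine, since both vectors are unit length
            r = y - sigmoid(w * s + b)  # derivative of the log-likelihood w.r.t. the logit
            g_w += r * s
            g_b += r
            for j in range(dim):
                g_phi[j] += r * w * e[j]
        w += lr * g_w / n
        b += lr * g_b / n
        # Gradient step on phi_tilde, then project back onto the unit sphere.
        phi = normalize([p + lr * g / n for p, g in zip(phi, g_phi)])
    return phi, w, b

# Toy 2-D "embeddings": class 1 points roughly along the first axis.
es_toy = [[1.0, 0.1], [0.9, -0.1], [1.0, 0.2], [0.1, 1.0], [-0.1, 1.0], [0.2, 0.9]]
ys_toy = [1, 1, 1, 0, 0, 0]
phi, w, b = fit_relaxed_predicate(es_toy, ys_toy)
preds = [int(sigmoid(w * dot(normalize(e), phi) + b) > 0.5) for e in es_toy]
print(preds, [round(p, 3) for p in phi])
```

Keeping $\tilde{\phi}$ on the unit sphere means its scores stay interpretable as cosines, which is the quantity a discretization step would then try to match with a natural language predicate.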

Paper Structure (35 sections, 15 equations, 10 figures, 5 tables, 1 algorithm)

Figures (10)

  • Figure 1: Our framework can use natural language predicates to parameterize a wide range of statistical models. Left. A clustering model that categorizes user queries. Middle. A time series model that characterizes how discussion changes across time. Right. A classification model that summarizes user traits. Once we define the model, we learn $\phi$ and $w$ based on $x$ (and $y$).
  • Figure 2: Left. The prompt to compute $\llbracket \phi \rrbracket (x)$. Right. The prompt to Discretize $\tilde{\phi}_{k}$, which generates a set of candidate predicates based on samples $x$ from $U$ and their scores $\cos(e_{x}, \tilde{\phi}_{k})$.
  • Figure 3: Left. We generate a taxonomy with sophisticated explanations by recursively applying our clustering model. Right. We cluster with topic models and present the top words for each topic. Although some topics are plausibly related to certain applications, they are still ambiguous.
  • Figure 4: We analyze WildChat queries with our time series model. For each learned predicate, we plot how its frequency evolves and the 99% confidence interval of the average frequency (shaded).
  • Figure 5: The prompt template used to evaluate the surface form similarity between the predicted predicate $\hat{\phi}_{k}$ and the reference predicate $\phi^{*}_{k}$.
  • ...and 5 more figures
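The Figure 2 caption describes discretization as showing an LM samples $x$ from $U$ together with their scores $\cos(e_x, \tilde{\phi}_k)$. A minimal sketch of only the sample-selection part follows, assuming the highest- and lowest-scoring samples are the ones shown. The function name `build_discretize_prompt`, the prompt wording, and the top/bottom-$k$ selection rule are hypothetical and are not the paper's actual template; the LM call that turns the prompt into candidate predicates is omitted.

```python
import math
from typing import List, Sequence, Tuple

def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return sum(x * y for x, y in zip(a, b)) / (na * nb)

def build_discretize_prompt(
    texts: List[str],
    embeddings: List[Sequence[float]],
    phi_tilde: Sequence[float],
    k: int = 2,
) -> Tuple[str, List[Tuple[float, str]]]:
    """Rank samples by cos(e_x, phi_tilde) and place the k highest- and k lowest-scoring
    ones into a (hypothetical) prompt asking an LM to propose candidate predicates."""
    scored = sorted(
        ((cosine(e, phi_tilde), t) for t, e in zip(texts, embeddings)), reverse=True
    )
    shown = scored[:k] + scored[-k:]  # assumes len(texts) >= 2 * k to avoid repeats
    body = "\n".join(f"score={s:+.2f}: {t}" for s, t in shown)
    prompt = (
        "Each text below has a score. Propose short natural language predicates "
        "that are true of high-scoring texts and false of low-scoring ones.\n" + body
    )
    return prompt, shown

texts = [
    "How does COVID spread indoors?",
    "COVID vaccine side effects",
    "Sort a list in Python",
    "Write a poem about rain",
]
embeddings = [[0.9, 0.1], [1.0, 0.0], [0.0, 1.0], [0.1, 0.9]]
prompt, shown = build_discretize_prompt(texts, embeddings, phi_tilde=[1.0, 0.0], k=1)
print(prompt)
```

Showing both ends of the score range gives the LM contrasting examples, so a proposed predicate has to separate them rather than merely describe the top samples.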