Learning Discrete Concepts in Latent Hierarchical Models

Lingjing Kong, Guangyi Chen, Biwei Huang, Eric P. Xing, Yuejie Chi, Kun Zhang

TL;DR

This work formalizes the problem of learning discrete concepts from high-dimensional data by modeling concepts as discrete latent variables organized in a hierarchical causal graph. It develops identifiability guarantees for recovering bottom-level discrete concepts from continuous observations and for identifying the entire discrete latent hierarchy from the observed distribution, under mild, well-motivated conditions that generalize beyond trees and multi-level DAGs. The authors introduce a rank-based approach, including a discrete analog of non-negative rank and a minimal-graph operator, to recover the latent structure and prove identifiability up to permutation and graphical equivalence. They connect these theoretical insights to latent diffusion models, interpreting diffusion denoising objectives and diffusion steps as recovering concept embeddings at corresponding hierarchical levels, and validate the ideas with synthetic and real diffusion-model experiments. The results offer a principled lens on concept extraction and suggest practical implications for diffusion-based representation learning and causal-sparsity strategies to improve interpretability and controllability.

Abstract

Learning concepts from natural high-dimensional data (e.g., images) holds potential in building human-aligned and interpretable machine learning models. Despite its encouraging prospect, formalization and theoretical insights into this crucial task are still lacking. In this work, we formalize concepts as discrete latent causal variables that are related via a hierarchical causal model that encodes different abstraction levels of concepts embedded in high-dimensional data (e.g., a dog breed and its eye shapes in natural images). We formulate conditions to facilitate the identification of the proposed causal model, which reveals when learning such concepts from unsupervised data is possible. Our conditions permit complex causal hierarchical structures beyond latent trees and multi-level directed acyclic graphs in prior work and can handle high-dimensional, continuous observed variables, which is well-suited for unstructured data modalities such as images. We substantiate our theoretical claims with synthetic data experiments. Further, we discuss our theory's implications for understanding the underlying mechanisms of latent diffusion models and provide corresponding empirical evidence for our theoretical insights.
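As a rough illustration of the kind of model the abstract describes, the sketch below samples from a toy two-level hierarchy: one high-level discrete concept, two lower-level discrete concepts that depend on it, and a continuous observation produced from the lower-level concepts together with a continuous exogenous variable. The function name `sample_toy_hierarchy`, the flip probabilities, the dimensions, and the `tanh` mixing map are all assumptions made for illustration; they are not taken from the paper.

```python
import numpy as np

def sample_toy_hierarchy(n, seed=0):
    """Sample n points from a toy discrete hierarchy (illustrative assumptions only)."""
    rng = np.random.default_rng(seed)
    # High-level binary concept z (hypothetical, e.g. a coarse object type).
    z = rng.integers(0, 2, size=n)
    # Two lower-level binary concepts, each a noisy copy of its parent z.
    d1 = np.where(rng.random(n) < 0.1, 1 - z, z)
    d2 = np.where(rng.random(n) < 0.2, 1 - z, z)
    d = np.stack([d1, d2], axis=1)                 # shape (n, 2)
    # Continuous exogenous variable c.
    c = rng.normal(size=(n, 2))                    # shape (n, 2)
    # Assumed nonlinear generating map x = tanh(W [d, c]) into 8 dimensions.
    W = rng.normal(size=(8, 4))
    x = np.tanh(np.concatenate([d, c], axis=1) @ W.T)
    return z, d, c, x

z, d, c, x = sample_toy_hierarchy(1000)
print(d.shape, x.shape)
```

Only `x` would be available to a learner in the unsupervised setting; `z`, `d`, and `c` are returned here so that a recovered representation can be compared against the ground truth.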

Paper Structure

This paper contains 58 sections, 14 theorems, 5 equations, 15 figures, 2 tables, and 5 algorithms.

Key Result

Theorem 4.2

Under the generating process in Equation eq:discrete_generating and Condition cond:discrete_component_conditions-asmp:invertibility, the estimated discrete variable $\hat{\mathbf{d}}$ and the true discrete variable $\mathbf{d}$ are equivalent up to an invertible function, i.e., $\hat{\mathbf{d}} = h \ldots$
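For discrete variables, equivalence up to an invertible function means that the estimated states are a one-to-one relabeling of the true states. The sketch below shows one way this could be checked on paired samples when ground truth is available, as in a synthetic experiment. The helper `find_relabeling` is hypothetical and is not part of the paper's method.

```python
def find_relabeling(d_true, d_hat):
    """Return a dict mapping true states to estimated states if the pairing
    is one-to-one on the observed samples, otherwise None.

    States must be hashable (e.g. ints or tuples of ints).
    """
    forward, backward = {}, {}
    for t, e in zip(d_true, d_hat):
        # A true state must always map to the same estimated state.
        if forward.setdefault(t, e) != e:
            return None
        # Two different true states must not collapse to one estimate.
        if backward.setdefault(e, t) != t:
            return None
    return forward

print(find_relabeling([0, 1, 2, 1], [2, 0, 1, 0]))  # a consistent relabeling
print(find_relabeling([0, 1], [0, 0]))              # collapsed states
```

The check only covers states that actually occur in the samples, so it is evidence of a relabeling rather than a proof.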

Figures (15)

  • Figure 1: Latent hierarchical graphs. The dashed circle in (a) indicates that the continuous variable $\mathbf{c}$ can be viewed as an exogenous variable. Dashed edges in (b) indicate potential statistical dependence.
  • Figure 2: Graphical comparison. Tree Structures permit one undirected path between any two variables. Multi-level DAGs require partitioning variables into levels with edges only between adjacent levels. Our conditions allow multiple paths between variables across levels and include non-leaf observed variables.
  • Figure 3: Diffusion models estimate the latent hierarchical model. Different noise levels correspond to different concept levels. To avoid clutter, we omit vector quantization.
  • Figure 4: Recovering concepts and their relationships from LD. (a) The final recovered concept graph among concepts "dog", "tree", "eyes", "ears", "branch", and "leaf". (b) Identifying causal links through "interventions". For example, we compare two prompts that vary in "dog": "a dog with wide eyes and a wilting tree with short branches, in a cartoon style" and "a big dog with wide eyes and a wilting tree with short branches, in a cartoon style". We observe significant changes in "eyes" but not in "branch", indicating a causal link between "dog" and "eyes" but not between "dog" and "branch". (c) Identifying concept levels by the last effective diffusion step. For example, we use the base prompt "a tree with long branches, in a cartoon style" and prepend "dog" at steps 0, 5, and 15. Only injecting "dog" at step 0 works. Similarly, injecting "wide eyes" works at both steps 0 and 5, indicating that "dog" is a higher-level concept than "eyes".
  • Figure 5: Semantic latent space. We modify the diffusion model's representation (UNet encoder's output) along principal directions at steps $T$ and $0.6T$. Structure changes indicate the semantics of the representation, and manipulation at the early time $T$ induces global shifts. See more examples in Figure \ref{fig:semantic_unet_more}.
  • ...and 10 more figures

Theorems & Definitions (34)

  • Definition 4.1: Component-wise Identifiability
  • Theorem 4.2: Discrete Component Identification
  • Definition 4.2: Non-negative Rank
  • Definition 4.2: Treks
  • Definition 4.2: T-separation
  • Theorem 4.3: Implication of Rank Information on Latent Discrete Graphs
  • Definition 4.4: Atomic Covers
  • Definition 4.4: Minimal-graph Operator [huang2022latent, dong2023versatile]
  • Theorem 4.5: Discrete Hierarchical Identification
  • Theorem A1: Discrete Component Identification
  • ...and 24 more
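Several of the listed results concern what rank information reveals about discrete latent structure. A standard fact behind this style of argument is that if two discrete variables are conditionally independent given a latent variable with k states, their joint probability table has rank at most k. The sketch below checks this numerically on a hand-picked example. It uses ordinary matrix rank as a simpler stand-in for the non-negative rank notion named in the definitions, and the probability tables are made-up values, not figures from the paper.

```python
import numpy as np

def joint_table(p_z, p_x_given_z, p_y_given_z):
    """Joint table P(x, y) = sum_z P(z) P(x | z) P(y | z).

    p_z:          shape (k,)
    p_x_given_z:  shape (k, m), each row sums to 1
    p_y_given_z:  shape (k, n), each row sums to 1
    """
    return p_x_given_z.T @ np.diag(p_z) @ p_y_given_z

# Latent z with k = 2 states; x has 4 states, y has 3 states (illustrative values).
p_z = np.array([0.4, 0.6])
p_x_given_z = np.array([[0.7, 0.1, 0.1, 0.1],
                        [0.1, 0.1, 0.1, 0.7]])
p_y_given_z = np.array([[0.6, 0.2, 0.2],
                        [0.1, 0.3, 0.6]])

P = joint_table(p_z, p_x_given_z, p_y_given_z)
# An explicit tolerance keeps floating-point noise from inflating the rank.
rank = np.linalg.matrix_rank(P, tol=1e-10)
print(P.shape, rank)
```

Although the table has shape 4 by 3, its rank is bounded by the number of latent states, which is why rank tests can hint at how many states a hidden variable must have.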