CoLiDR: Concept Learning using Aggregated Disentangled Representations

Sanchit Sinha, Guangzhi Xiong, Aidong Zhang

TL;DR

CoLiDR addresses the need to connect human-understandable concepts with underlying generative factors in data by unifying disentangled representation learning with concept-based explanations. It introduces a three-module framework: a disentangled representations learner ($\beta$-VAE or VAE), an Aggregation/Decomposition module that forms and inverts concepts from latent factors, and a task predictor that uses concepts to predict labels, with a consistency loss and sparsity regularization to maintain interpretability. The approach supports both annotated and unannotated concepts and demonstrates competitive task performance while enabling test-time interventions for debugging, across synthetic and real datasets. This fusion of variational inference with concept learning yields a flexible, end-to-end trainable framework that enhances explainability without sacrificing accuracy, and it generalizes to arbitrary numbers of factors and concepts, benefiting model debugging and downstream reasoning.
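
As background for the first module, the standard $\beta$-VAE objective (to be maximized) can be written as follows. This is the textbook form, given here as an assumption about what the disentangled representations learner optimizes; the summary above does not state the exact variant or weighting used in CoLiDR.

$$
\mathcal{L}_{\beta\text{-VAE}}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \beta \, D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
$$

Setting $\beta = 1$ recovers the plain VAE evidence lower bound, while $\beta > 1$ puts more pressure on the posterior to match the factorized prior $p(z) = \mathcal{N}(0, I)$, which encourages the $k$ latent dimensions to behave as independent generative factors. For a diagonal Gaussian posterior with means $\mu_j$ and variances $\sigma_j^2$, the KL term has the closed form $\frac{1}{2}\sum_{j=1}^{k}\left(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\right)$.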

Abstract

Interpretability of Deep Neural Networks using concept-based models offers a promising way to explain model behavior through human-understandable concepts. A parallel line of research focuses on disentangling the data distribution into its underlying generative factors, in turn explaining the data generation process. While both directions have received extensive attention, little work has been done on explaining concepts in terms of generative factors to unify mathematically disentangled representations and human-understandable concepts as an explanation for downstream tasks. In this paper, we propose a novel method, CoLiDR, which uses a disentangled representation learning setup to learn mutually independent generative factors and subsequently learns to aggregate these representations into human-understandable concepts using a novel aggregation/decomposition module. Experiments are conducted on datasets with both known and unknown latent generative factors. Our method successfully aggregates disentangled generative factors into concepts while maintaining parity with state-of-the-art concept-based approaches. Quantitative and visual analysis of the learned aggregation procedure demonstrates the advantages of our work compared to commonly used concept-based models over four challenging datasets. Lastly, our work generalizes to an arbitrary number of concepts and generative factors, making it flexible enough to suit various types of data.
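
The aggregation/decomposition idea described above can be illustrated with a small, dependency-free sketch. Everything in it is an assumption made for illustration rather than the authors' implementation: the linear maps with a sigmoid, the mean-squared-error consistency term, the L1 sparsity penalty, the 0.01 weight, and all function and variable names are hypothetical, and the random vector `z` merely stands in for a latent sample from a trained VAE.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def aggregate(W_agg, z):
    """Map k generative factors to m concept scores in (0, 1)."""
    return [sigmoid(s) for s in matvec(W_agg, z)]


def decompose(W_dec, c):
    """Map m concept scores back to k reconstructed generative factors."""
    return matvec(W_dec, c)


def consistency_loss(z, z_hat):
    """Mean squared error between original and reconstructed factors."""
    return sum((a - b) ** 2 for a, b in zip(z, z_hat)) / len(z)


def sparsity_penalty(W):
    """L1 norm of the aggregation weights (few factors per concept)."""
    return sum(abs(w) for row in W for w in row)


def predict_task(W_task, c):
    """Linear task head over concepts; returns the index of the top logit."""
    logits = matvec(W_task, c)
    return max(range(len(logits)), key=logits.__getitem__)


random.seed(0)
k, m, n_classes = 6, 3, 2  # hypothetical sizes: factors, concepts, labels

# Stand-in for a latent sample from the disentangled representation learner.
z = [random.gauss(0.0, 1.0) for _ in range(k)]

# Randomly initialised (untrained) weights, for illustration only.
W_agg = [[random.uniform(-1, 1) for _ in range(k)] for _ in range(m)]
W_dec = [[random.uniform(-1, 1) for _ in range(m)] for _ in range(k)]
W_task = [[random.uniform(-1, 1) for _ in range(m)] for _ in range(n_classes)]

c = aggregate(W_agg, z)        # factors -> concepts
z_hat = decompose(W_dec, c)    # concepts -> reconstructed factors
loss_aux = consistency_loss(z, z_hat) + 0.01 * sparsity_penalty(W_agg)
y_pred = predict_task(W_task, c)

# Test-time intervention: overwrite one concept with a human-provided value
# and re-run only the task head.
c_fixed = list(c)
c_fixed[0] = 1.0
y_after = predict_task(W_task, c_fixed)
```

Because the task head reads only the concept vector, an intervention needs no change to the encoder: correcting a concept score and re-running `predict_task` is enough to see how the label would change.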

Paper Structure (41 sections, 12 equations, 9 figures, 5 tables)

Figures (9)

  • Figure 1: Schematic overview of the proposed CoLiDR approach. The input data distribution $\mathbf{X}$ is first disentangled into $k$ mutually independent generative factors (GFs). Subsequently, the GFs are aggregated into concepts. Note that the concepts are modeled as two sets: annotated concepts, which carry human annotations, and unannotated concepts, which are useful for prediction but have no annotations. Finally, the concepts are used to predict the task label $\mathbf{Y}$.
  • Figure 2: A schematic view of the underlying assumptions made by SOTA concept-based models (CBM/CEM, CLAP, GlanceNet) and by CoLiDR. The blue circles represent directly observable attributes: the input sample $X$, the task label $Y$, and the representative concepts $C$. The red circles represent learned representations.
  • Figure 3: Architecture of the proposed CoLiDR approach. CoLiDR consists of three modules: the Disentangled Representations Learning (DRL) module, which learns disentangled generative factors (top); the Aggregation/Decomposition module, which learns to aggregate the generative factors into concepts and subsequently decompose them back into generative factors (bottom-left); and the Task Learning module, which uses the concepts to predict the task label (bottom-right).
  • Figure 4: GradCAM visualizations of the top-2 highest attributed dimensions with respect to correctly predicted concepts "wavy_hair" (top) and "straight_hair" (bottom) on 2 distinct samples from CelebA.
  • Figure 5: GradCAM visualizations of the top-2 highest attributed dimensions for the correctly predicted concept of shape on dSprites (top) and Shapes3D (bottom). The interpolation along the highest contributing dimension shows that the dimension effectively captures the shape of the object in the image.
  • ...and 4 more figures