CoLiDR: Concept Learning using Aggregated Disentangled Representations
Sanchit Sinha, Guangzhi Xiong, Aidong Zhang
TL;DR
CoLiDR addresses the need to connect human-understandable concepts with underlying generative factors in data by unifying disentangled representation learning with concept-based explanations. It introduces a three-module framework: a disentangled representations learner ($\beta$-VAE or VAE), an Aggregation/Decomposition module that forms and inverts concepts from latent factors, and a task predictor that uses concepts to predict labels, with a consistency loss and sparsity regularization to maintain interpretability. The approach supports both annotated and unannotated concepts and demonstrates competitive task performance while enabling test-time interventions for debugging, across synthetic and real datasets. This fusion of variational inference with concept learning yields a flexible, end-to-end trainable framework that enhances explainability without sacrificing accuracy, and it generalizes to arbitrary numbers of factors and concepts, benefiting model debugging and downstream reasoning.
Abstract
Interpretability of Deep Neural Networks using concept-based models offers a promising way to explain model behavior through human-understandable concepts. A parallel line of research focuses on disentangling the data distribution into its underlying generative factors, in turn explaining the data generation process. While both directions have received extensive attention, little work has been done on explaining concepts in terms of generative factors, which would unify mathematically disentangled representations and human-understandable concepts as an explanation for downstream tasks. In this paper, we propose a novel method, CoLiDR, which uses a disentangled representation learning setup to learn mutually independent generative factors and then aggregates these representations into human-understandable concepts using a novel aggregation/decomposition module. Experiments are conducted on datasets with both known and unknown latent generative factors. Our method successfully aggregates disentangled generative factors into concepts while maintaining parity with state-of-the-art concept-based approaches. Quantitative and visual analysis of the learned aggregation procedure over four challenging datasets demonstrates the advantages of our work compared to commonly used concept-based models. Lastly, our work generalizes to an arbitrary number of concepts and generative factors, making it flexible enough to suit various types of data.
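
The pipeline summarized above (disentangled factors, then aggregation into concepts, then task prediction) can be illustrated with a minimal forward-pass sketch. This is not the authors' implementation: every name here (`make_params`, `colidr_forward`, the `W_*` keys) is hypothetical, single linear layers stand in for the actual networks, the VAE decoder and reconstruction term are omitted, and the exact forms of the consistency and sparsity terms (squared error between factors and their decomposition, L1 norm on aggregation weights) are assumptions made for illustration.

```python
import math
import random

def linear(W, b, x):
    """Compute W @ x + b, with W given as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) + bias for row, bias in zip(W, b)]

def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-a)) for a in v]

def softmax(v):
    m = max(v)
    e = [math.exp(a - m) for a in v]
    total = sum(e)
    return [a / total for a in e]

def make_params(n_in, n_factors, n_concepts, n_classes, rng):
    """Randomly initialise the toy linear layers (hypothetical layout)."""
    def mat(rows, cols):
        return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]
    return {
        "W_mu": mat(n_factors, n_in), "b_mu": [0.0] * n_factors,
        "W_lv": mat(n_factors, n_in), "b_lv": [0.0] * n_factors,
        "W_agg": mat(n_concepts, n_factors), "b_agg": [0.0] * n_concepts,
        "W_dec": mat(n_factors, n_concepts), "b_dec": [0.0] * n_factors,
        "W_task": mat(n_classes, n_concepts), "b_task": [0.0] * n_classes,
    }

def colidr_forward(x, params, rng, intervene=None):
    # Module 1: VAE-style encoder producing a Gaussian over generative factors.
    mu = linear(params["W_mu"], params["b_mu"], x)
    logvar = linear(params["W_lv"], params["b_lv"], x)
    # Reparameterisation trick: z = mu + sigma * eps.
    z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]
    # KL divergence to a standard normal prior (weighted by beta in a beta-VAE).
    kl = 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, logvar))

    # Module 2a: aggregate factors into concept activations in (0, 1).
    c = sigmoid(linear(params["W_agg"], params["b_agg"], z))
    # Optional test-time intervention: overwrite selected concept values.
    for idx, value in (intervene or {}).items():
        c[idx] = value
    # Module 2b: decompose concepts back into factors.
    z_hat = linear(params["W_dec"], params["b_dec"], c)
    # Assumed consistency term: decomposition should recover the factors.
    consistency = sum((a - b) ** 2 for a, b in zip(z, z_hat)) / len(z)
    # Assumed sparsity term: L1 norm of the aggregation weights.
    sparsity = sum(abs(w) for row in params["W_agg"] for w in row)

    # Module 3: the task predictor sees only the concepts.
    probs = softmax(linear(params["W_task"], params["b_task"], c))
    return {"z": z, "concepts": c, "probs": probs,
            "kl": kl, "consistency": consistency, "sparsity": sparsity}

rng = random.Random(0)
params = make_params(n_in=8, n_factors=4, n_concepts=3, n_classes=2, rng=rng)
x = [rng.random() for _ in range(8)]
out = colidr_forward(x, params, rng)
fixed = colidr_forward(x, params, rng, intervene={0: 1.0})
```

Because the task predictor reads only the concept vector, setting a concept by hand (as in `fixed`) changes the prediction through that concept alone, which is the kind of test-time intervention the summary refers to. In a trained model the loss terms returned here would be combined with the task loss and, where annotations exist, a concept supervision loss; the weighting is left out of this sketch.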
