Learning Interpretable Concepts: Unifying Causal Representation Learning and Foundation Models
Goutham Rajendran, Simon Buchholz, Bryon Aragam, Bernhard Schölkopf, Pradeep Ravikumar
TL;DR
The paper addresses how to learn human-interpretable concepts from complex data by unifying causal representation learning (CRL) with foundation-model interpretability. It defines concepts as affine subspaces of a latent representation and proves that a subset of concepts is identifiable from a small number of environments: recovering $n$ atomic concepts up to linear transformations requires only $m = n+1$ concept-conditioned datasets (i.e., $n+2$ environments). The authors validate the theory with end-to-end contrastive learning on synthetic data. They also extend the framework to large-language-model alignment by introducing steering matrices that guide truthfulness in Inference-Time Intervention (ITI), and they demonstrate improvements on TruthfulQA with LLaMA. The work thus provides a principled partial-identifiability framework for interpretable representations of high-dimensional data, together with practical mechanisms for controllable generation and mechanistic interpretability of foundation models.
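To make "concepts as affine subspaces" concrete, here is a minimal NumPy sketch. It assumes that a concept is encoded by a full-row-rank matrix `A` and a valuation vector `b`, so that a latent vector `z` satisfies the concept when `A z = b`. The names `A`, `b`, `satisfies_concept`, and `project_onto_concept` are illustrative assumptions, not the paper's notation or code. The sketch also shows why recovery can only hold "up to linear transformations": under an invertible change of latent coordinates `z' = T z`, the same set of points is described by the concept `(A T^{-1}, b)`. It does not model the paper's noisy conditioning or its contrastive estimation procedure.

```python
import numpy as np


def satisfies_concept(z, A, b, tol=1e-8):
    """Hypothetical helper: True if latent z lies on the affine subspace {z : A z = b}."""
    return np.allclose(A @ z, b, atol=tol)


def project_onto_concept(z, A, b):
    """Hypothetical helper: orthogonal projection of z onto {z : A z = b}.

    Assumes A has full row rank, so A A^T is invertible.
    """
    residual = A @ z - b
    # Solve (A A^T) y = residual rather than forming an explicit inverse.
    y = np.linalg.solve(A @ A.T, residual)
    return z - A.T @ y


rng = np.random.default_rng(0)
d_z, k = 5, 2  # illustrative latent dimension and concept dimension
A = rng.standard_normal((k, d_z))
b = np.array([1.0, -0.5])

z = rng.standard_normal(d_z)
z_on = project_onto_concept(z, A, b)
print(satisfies_concept(z, A, b))     # generically False: a random point is off the subspace
print(satisfies_concept(z_on, A, b))  # True: the projection lands on the subspace

# An invertible linear reparametrization of the latent space, z' = T z, maps the
# concept's subspace to {z' : (A T^{-1}) z' = b}: same set, different (A, b) description.
T = rng.standard_normal((d_z, d_z)) + d_z * np.eye(d_z)  # shifted to be well-conditioned
A_prime = A @ np.linalg.inv(T)
print(satisfies_concept(T @ z_on, A_prime, b))  # True
```

The projection uses the closed form $z - A^\top (A A^\top)^{-1}(Az - b)$, which satisfies $A(\cdot) = b$ exactly in exact arithmetic. Solving the linear system instead of inverting $A A^\top$ is the numerically preferred choice.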
Abstract
To build intelligent machine learning systems, there are two broad approaches. One approach is to build inherently interpretable models, as endeavored by the growing field of causal representation learning. The other approach is to build highly performant foundation models and then invest efforts into understanding how they work. In this work, we relate these two approaches and study how to learn human-interpretable concepts from data. Weaving together ideas from both fields, we formally define a notion of concepts and show that they can be provably recovered from diverse data. Experiments on synthetic data and large language models show the utility of our unified approach.
