
Towards Causal Representation Learning

Bernhard Schölkopf, Francesco Locatello, Stefan Bauer, Nan Rosemary Ke, Nal Kalchbrenner, Anirudh Goyal, Yoshua Bengio

TL;DR

This paper argues that bridging machine learning with causal inference through causal representation learning can address core limitations of current AI, notably robustness and transfer across distribution shifts. It develops a unifying view centered on Independent Causal Mechanisms, the level of causal modeling, and the pursuit of learning causal variables from high-dimensional data using modular, intervention-friendly representations. It outlines concrete research directions in learning disentangled causal representations, transferable mechanisms, interventional world models, and systematic application to SSL, RL, and science. The proposed framework aims to enable more robust generalization, faster transfer, and safer, more interpretable AI by aligning learning with the underlying causal structure of the world.

Abstract

The two fields of machine learning and graphical causality arose and developed separately. However, there is now cross-pollination and increasing interest in both fields to benefit from the advances of the other. In the present paper, we review fundamental concepts of causal inference and relate them to crucial open problems of machine learning, including transfer and generalization, thereby assaying how causality can contribute to modern machine learning research. This also applies in the opposite direction: we note that most work in causality starts from the premise that the causal variables are given. A central problem for AI and causality is, thus, causal representation learning, the discovery of high-level causal variables from low-level observations. Finally, we delineate some implications of causality for machine learning and propose key research areas at the intersection of both communities.

Paper Structure

This paper contains 42 sections, 16 equations, 3 figures, 1 table.

Figures (3)

  • Figure 1: Difference between statistical (left) and causal models (right) on a given set of three variables. While a statistical model specifies a single probability distribution, a causal model represents a set of distributions, one for each possible intervention (each marked in the figure with an intervention symbol).
  • Figure 2: Illustration of the causal representation learning problem setting. Perceptual data, such as images or other high-dimensional sensor measurements, can be thought of as entangled views of the state of an unknown causal system, as formalized in the paper's causal representation learning equation. With the exception of possible task labels, none of the causal variables generating the system may be known. The goal of causal representation learning is to learn a representation that (partially) exposes this unknown causal structure (e.g., which variables describe the system, and how they relate). As full recovery may often be unreasonable, neural networks may instead map the low-level features to high-level variables that support causal statements relevant to a set of downstream tasks of interest. For example, if the task is to detect the manipulable objects in a scene, the representation may separate intrinsic object properties from their pose and appearance to achieve robustness to distribution shifts on the latter variables. Usually, we do not get labels for the high-level variables, but the properties of causal models can serve as useful inductive biases for learning (e.g., the Sparse Mechanism Shift (SMS) hypothesis).
  • Figure 3: Example of the Sparse Mechanism Shift (SMS) hypothesis, where an intervention (which may or may not be intentional or observed) changes the position of one finger and, as a consequence, the object falls. The change in pixel space is entangled (or distributed), in contrast to the change in the causal model.
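The two ideas in Figures 1 and 3 (one causal model inducing a whole family of interventional distributions, and an intervention that is sparse in the causal model but distributed in pixel space) can be illustrated with a toy structural causal model. This is a minimal sketch, not the authors' code: the chain graph X1 → X2 → X3, its linear-Gaussian mechanisms, the helper names `sample`, `mean_of` and `render`, and the 8×3 mixing matrix standing in for an image renderer are all hypothetical choices made for illustration.

```python
import random
from typing import Callable, Dict, List, Optional

# Hypothetical SCM over three variables with the chain graph X1 -> X2 -> X3.
# Each structural assignment maps (already-computed parent values, noise) to a value.
Mechanism = Callable[[Dict[str, float], float], float]

MECHANISMS: Dict[str, Mechanism] = {
    "X1": lambda pa, n: n,                   # X1 := N1
    "X2": lambda pa, n: 2.0 * pa["X1"] + n,  # X2 := 2*X1 + N2
    "X3": lambda pa, n: -pa["X2"] + n,       # X3 := -X2 + N3
}
ORDER = ["X1", "X2", "X3"]  # a topological order of the graph

# Hypothetical dense linear "renderer": every one of 8 "pixels" mixes all 3 latents.
MIXING: List[List[float]] = [
    [0.9, 0.3, -0.5], [0.2, -0.7, 0.4], [-0.6, 0.5, 0.8], [0.4, 0.1, -0.9],
    [0.7, -0.2, 0.3], [-0.3, 0.8, 0.6], [0.5, 0.6, -0.1], [-0.8, -0.4, 0.2],
]


def sample(rng: random.Random,
           do: Optional[Dict[str, float]] = None) -> Dict[str, float]:
    """Draw one joint sample; do={"X2": v} replaces X2's mechanism by the constant v."""
    do = do or {}
    values: Dict[str, float] = {}
    for name in ORDER:
        # Noise is drawn even for intervened variables so that, for a fixed seed,
        # every other variable sees exactly the same exogenous noise.
        noise = rng.gauss(0.0, 1.0)
        values[name] = do[name] if name in do else MECHANISMS[name](values, noise)
    return values


def mean_of(var: str, n: int, seed: int,
            do: Optional[Dict[str, float]] = None) -> float:
    """Monte Carlo estimate of E[var] under the observational or intervened model."""
    rng = random.Random(seed)
    return sum(sample(rng, do)[var] for _ in range(n)) / n


def render(values: Dict[str, float]) -> List[float]:
    """Entangle the latent causal variables into pixel space."""
    latent = [values[name] for name in ORDER]
    return [sum(a * z for a, z in zip(row, latent)) for row in MIXING]


if __name__ == "__main__":
    intervention = {"X2": 3.0}

    # Figure 1: one causal model, several distributions.
    print("E[X3]            ~", round(mean_of("X3", 20000, seed=0), 2))
    print("E[X3 | do(X2=3)] ~", round(mean_of("X3", 20000, seed=0, do=intervention), 2))

    # Figure 3: same exogenous noise, a single mechanism replaced.
    before = sample(random.Random(1))
    after = sample(random.Random(1), do=intervention)
    changed_latents = [k for k in ORDER if abs(before[k] - after[k]) > 1e-9]
    changed_pixels = sum(abs(p - q) > 1e-9 for p, q in zip(render(before), render(after)))
    print("mechanisms replaced:", len(intervention),
          "| latents changed:", changed_latents,
          "| pixels changed:", changed_pixels, "of", len(MIXING))
```

Running it should print an observational mean of X3 close to 0 and an interventional mean close to -3 (the exact digits depend on the random draws). It should also report that replacing one mechanism changes two latent values (X2 and its child X3, the analogue of the falling object) while every pixel changes. In real perceptual data the mixing is unknown and typically nonlinear, which is what turns recovering the causal variables into a learning problem.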