
Intervening to Learn and Compose Causally Disentangled Representations

Alex Markham, Isaac Hirsch, Jeri A. Chang, Liam Solus, Bryon Aragam

Abstract

In designing generative models, it is commonly believed that in order to learn useful latent structure, we face a fundamental tension between expressivity and structure. In this paper we challenge this view by proposing a new approach to training arbitrarily expressive generative models that simultaneously learn causally disentangled concepts. This is accomplished by adding a simple context module to an arbitrarily complex black-box model, which learns to process concept information by implicitly inverting linear representations from the model's encoder. Inspired by the notion of intervention in a causal model, our module selectively modifies its architecture during training, allowing it to learn a compact joint model over different contexts. We show how adding this module leads to causally disentangled representations that can be composed for out-of-distribution generation on both real and simulated data. The resulting models can be trained end-to-end or fine-tuned from pre-trained models. To further validate our proposed approach, we prove a new identifiability result that extends existing work on identifying structured representations.
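The abstract describes a context module that sits between the encoder and the decoder and "selectively modifies its architecture" according to the intervention context. The sketch below shows one way such a mechanism could look. The three layers (a linear map from embedding to exogenous noise, a linear structural causal layer, and a linear readout), the function and parameter names, and the parameterization are all assumptions for illustration, not the authors' architecture. An intervention on a set of concepts is modelled by zeroing the incoming edges of those nodes, so one set of weights serves every context.

```python
import numpy as np


def context_module(e, W_in, A, W_out, intervened=(), u_int=None):
    """Hypothetical three-layer context module (illustrative only).

    e          : (d,) embedding produced by the encoder
    W_in       : (k, d) layer 1, maps the embedding to exogenous noise u
    A          : (k, k) weighted DAG adjacency, A[i, j] != 0 means c_i -> c_j
    W_out      : (d_out, k) layer 3, maps concepts to the decoder input
    intervened : indices of the concepts intervened on in this context
    u_int      : optional (k,) values used for the intervened nodes' noise
    """
    u = W_in @ e                          # layer 1: embedding -> exogenous noise
    A_ctx = A.copy()                      # never mutate the shared weights
    idx = list(intervened)
    if idx:
        A_ctx[:, idx] = 0.0               # cut every edge into an intervened node
        if u_int is not None:
            u = u.copy()
            u[idx] = u_int[idx]           # intervened nodes take the supplied values
    k = A.shape[0]
    # layer 2: linear SCM c = A^T c + u, solved as (I - A^T) c = u
    c = np.linalg.solve(np.eye(k) - A_ctx.T, u)
    return W_out @ c, c                   # layer 3: concepts -> decoder input
```

For a chain $c_0 \to c_1 \to c_2$, intervening on index 1 changes $c_1$ and its descendant $c_2$ but leaves $c_0$ untouched. Passing several indices in `intervened` composes single-node interventions, in the spirit of the pairwise OOD compositions shown in the figures; whether the paper's module composes interventions exactly this way is not stated here.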

Paper Structure

This paper contains 57 sections, 6 theorems, 20 equations, 10 figures, 15 tables.

Key Result

theorem 1

Assume that the rows of each $C_j$ come from a linearly independent set and $f$ is injective and differentiable. Then, given single-node interventions on each concept $\mathbf{c}_j$, the representations $C_j$ and the latent concept distribution $p(\mathbf{c})$ are identifiable. $\blacktriangleleft$
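Theorem 1 relies on the rows of the concept representations being linearly independent. As a small numerical illustration, assume (our assumption, following the abstract's language of "linear representations") that each concept is read out linearly from an embedding as $\mathbf{c}_j = C_j \mathbf{e}$. Stacking the $C_j$ into one matrix $C$ then gives full row rank, so the Moore-Penrose pseudo-inverse is a right inverse, $C C^{+} = I$. This is one concrete sense in which a linear readout can be "inverted"; it is not the proof of Theorem 1.

```python
import numpy as np

rng = np.random.default_rng(0)

d = 6                                   # embedding dimension (illustrative)
C1 = rng.standard_normal((2, d))        # hypothetical representation of concept 1 (2-dim)
C2 = rng.standard_normal((1, d))        # hypothetical representation of concept 2 (1-dim)
C = np.vstack([C1, C2])                 # stack all rows; Gaussian rows are independent a.s.

assert np.linalg.matrix_rank(C) == C.shape[0]   # rows linearly independent

C_pinv = np.linalg.pinv(C)              # (d, 3) right inverse of C

e = rng.standard_normal(d)              # an encoder embedding
c = C @ e                               # concepts read out linearly: (C1 e, C2 e)

e_rec = C_pinv @ c                      # minimum-norm embedding consistent with c
assert np.allclose(C @ e_rec, c)        # reading the concepts again recovers c

# Intervene on concept 2 only: change c[2] and move e within the row space of C.
c_new = c.copy()
c_new[2] += 1.0
e_new = e + C_pinv @ (c_new - c)
assert np.allclose(C @ e_new, c_new)    # concept 1 unchanged, concept 2 shifted
```

Because $C C^{+} = I$, the update $\mathbf{e} + C^{+}(\mathbf{c}' - \mathbf{c})$ changes exactly the targeted concept coordinates and no others, which is the linear-algebraic reason independent rows matter.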

Figures (10)

  • Figure 1: Overview (notation as in the decoder-layer equations). Given a black-box encoder-decoder architecture (blue), we propose to prepend a context module (red) to the decoder. (left) Instead of passing the output of the encoder directly to the decoder (blue box + blue arrow), the embeddings $\mathbf{e}$ are passed through the context module, consisting of three distinct layers. The output of this module is then passed into the decoder. (right) The model learns to compose different concepts OOD, e.g., object and background colour, to values that never appear together in the training data.
  • Figure 2: (a) Examples of concept learning and composition in MNIST (left) and 3DIdent (right) using CM (end-to-end), our context module augmenting the base NVAE; these images are generated, not reconstructed. The first row shows generated observational samples; the second and third rows show samples of learned concepts, generated by intervening on individual concepts (e.g., 'Concept 1' is 'scaled' for MNIST); the final row shows OOD composition, generated by simultaneously intervening on pairs of learned concepts. (b) Observational samples from the context module show no perceptual loss relative to the base model.
  • Figure 3: Example images obtained from the quad dataset. Each subfigure corresponds to a different context, indicated by the title (obs = observational; the others indicate a single-node intervention). For example, in quad1, the first (top left) quadrant has been manipulated from orange-green hues to blue-red hues.
  • Figure 4: Example images obtained from the quad dataset. Each subfigure corresponds to a different double-concept intervention context, indicated by the title. For example, in quad2_quad3, the second (top-right) and third (bottom-left) quadrants have been manipulated from orange-green hues to blue-red hues.
  • Figure 5: Example generated OOD images from a run of quad.
  • ...and 5 more figures
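Figures 3 and 4 name each quad context by its intervention targets (obs, quad1, quad2_quad3, and so on). A hypothetical helper that maps such a title to the set of intervened quadrants, and enumerates observational, single- and double-intervention contexts, could look like the sketch below. It assumes four quadrants numbered 1 to 4, as the dataset name and the captions suggest, and that every pair of quadrants forms a context; both are assumptions, not details confirmed by the captions.

```python
from itertools import combinations

N_QUADS = 4  # assumption: four quadrants, named 'quad1' ... 'quad4'


def targets(context: str) -> frozenset:
    """Map a context title such as 'obs', 'quad1' or 'quad2_quad3' to its intervened quadrants."""
    if context == "obs":
        return frozenset()  # observational context: no intervention
    parts = context.split("_")
    if not all(p.startswith("quad") and p[4:].isdigit() for p in parts):
        raise ValueError(f"unrecognised context: {context!r}")
    return frozenset(int(p[4:]) for p in parts)


def context_names(max_targets: int = 2) -> list:
    """Enumerate 'obs' plus every context intervening on 1..max_targets quadrants."""
    names = ["obs"]
    for r in range(1, max_targets + 1):
        for combo in combinations(range(1, N_QUADS + 1), r):
            names.append("_".join(f"quad{i}" for i in combo))
    return names
```

Under these assumptions, `targets("quad2_quad3")` gives the quadrants {2, 3}, and `context_names()` lists 1 observational, 4 single-node and 6 double-node contexts.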

Theorems & Definitions (11)

  • definition 1 (bengio2013deep; thomas_disentangling_2017)
  • remark 1
  • remark 2
  • theorem 1: Identifiability
  • theorem 2
  • lemma 1
  • lemma 2
  • proposition 1
  • proof
  • proposition 2
  • ...and 1 more