Leveraging Task Structures for Improved Identifiability in Neural Network Representations

Wenlin Chen, Julien Horwood, Juyeon Heo, José Miguel Hernández-Lobato

Abstract

This work extends the theory of identifiability in supervised learning by considering the consequences of having access to a distribution of tasks. In such cases, we show that linear identifiability is achievable in the general multi-task regression setting. Furthermore, we show that the existence of a task distribution which defines a conditional prior over latent factors reduces the equivalence class for identifiability to permutations and scaling of the true latent factors, a stronger and more useful result than linear identifiability. Crucially, when we further assume a causal structure over these tasks, our approach enables simple maximum marginal likelihood optimization, and suggests potential downstream applications to causal representation learning. Empirically, we find that this straightforward optimization procedure enables our model to outperform more general unsupervised models in recovering canonical representations for both synthetic data and real-world molecular data.
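The two equivalence classes named in the abstract can be illustrated with a small numerical sketch. It is not the paper's evaluation procedure. The function names `linear_r2` and `perm_scale_score`, the synthetic Laplace latents, and the mixing matrix are all assumptions made for illustration. A representation that equals the true latents up to an invertible linear map should be perfectly predictable by a linear fit, but only a representation that is a permuted and rescaled copy should also match the true latents coordinate by coordinate.

```python
import itertools
import numpy as np


def linear_r2(z_true, z_hat):
    """R^2 of the best affine map from z_hat to z_true (hypothetical helper).

    A value near 1 means z_hat recovers z_true up to a linear transformation.
    """
    design = np.column_stack([z_hat, np.ones(len(z_hat))])
    coef, *_ = np.linalg.lstsq(design, z_true, rcond=None)
    residual = z_true - design @ coef
    ss_res = (residual ** 2).sum()
    ss_tot = ((z_true - z_true.mean(axis=0)) ** 2).sum()
    return 1.0 - ss_res / ss_tot


def perm_scale_score(z_true, z_hat):
    """Mean absolute correlation under the best one-to-one coordinate matching.

    A value near 1 means z_hat recovers z_true up to permutation and scaling.
    Brute force over permutations, so only suitable for small dimensions.
    """
    d = z_true.shape[1]
    # Cross-correlation block between true and recovered coordinates.
    corr = np.abs(np.corrcoef(z_true.T, z_hat.T)[:d, d:])
    return max(
        np.mean([corr[i, perm[i]] for i in range(d)])
        for perm in itertools.permutations(range(d))
    )


rng = np.random.default_rng(0)
z = rng.laplace(size=(2000, 3))  # assumed independent ground-truth latents

# Case 1: permuted and rescaled copy of the latents.
z_perm_scaled = z[:, [2, 0, 1]] * np.array([2.0, -0.5, 3.0])

# Case 2: latents mixed by an invertible, non-diagonal matrix (assumed).
mixing = np.array([[1.0, 1.0, 0.0], [0.0, 1.0, 1.0], [1.0, 0.0, 1.0]])
z_mixed = z @ mixing

r2_perm = linear_r2(z, z_perm_scaled)        # close to 1
r2_mixed = linear_r2(z, z_mixed)             # close to 1
ps_perm = perm_scale_score(z, z_perm_scaled)  # close to 1
ps_mixed = perm_scale_score(z, z_mixed)       # clearly below 1
```

Both representations pass the linear test, but only the first passes the coordinate-wise test. This is the sense in which identifiability up to permutation and scaling is the stronger of the two guarantees.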

Paper Structure

This paper contains 33 sections, 3 theorems, 51 equations, 5 figures, 5 tables, 1 algorithm.

Key Result

Theorem 3.2

Let $\boldsymbol{\theta}\coloneqq(\boldsymbol{\phi},\{\mathbf{w}_{t_i}\}_{i=1}^{N_t})$ and $\boldsymbol{\theta}'\coloneqq(\boldsymbol{\phi}',\{\mathbf{w}_{t_i}'\}_{i=1}^{N_t})$ be any two sets of parameters such that […]. Assume that $\text{Span}(\text{Im}(\mathbf{h}_{\boldsymbol{\phi}}))=\mathbb{R}^d$, i.e., the vectors in the image of the feature extractor $\mathbf{h}_{\boldsymbol{\phi}}$ span the w…

Figures (5)

  • Figure 1: Assumed data generating process.
  • Figure 2: The workflow of our proposed method. Shapes are used to track the positions of the ground-truth and recovered latent factors. Colors are used to differentiate between causal and spurious latent factors. We assume that the observed variable is obtained by transforming the ground-truth latent factors with some mixing function. We show that a multi-task regression network (MTRN) can recover the ground-truth latent factors (i.e., data representations) up to a linear transformation, and we further propose a multi-task linear causal model (MTLCM) to reduce the equivalence class for identifiability to permutations and scaling.
  • Figure 3: Identifiability performance for the latent factors learned on the QM9 dataset.
  • Figure 4: Illustration of the causal relationships between latent variables and the observed target $y$ for a given task that are captured by our model (a, b) and not captured by our model (c). The red arrow in (c) indicates the portion of the graph which is not captured by MTLCM. Note that in (b), the existence of learned regression weights encapsulates this case if the learned weight is zero on the arrow $z_i \to y$. This is depicted with the dashed green arrow.
  • Figure 5: Convergence of the model in the case of transformations of the latent factors for identity, orthogonal and arbitrary linear transformations. Scaled means standardizing the features.

Theorems & Definitions (11)

  • Definition 3.1: Multi-task weak identifiability
  • Theorem 3.2
  • Corollary 3.3
  • Remark 3.4
  • Remark 3.5
  • Remark 3.6
  • Definition 3.7: Strictly strong identifiability
  • Theorem 3.8
  • Remark 3.9
  • proof
  • ...and 1 more