Robust Domain Generalisation with Causal Invariant Bayesian Neural Networks

Gaël Gendron, Michael Witbrock, Gillian Dobbie

TL;DR

A Bayesian neural architecture is proposed that disentangles the learning of the data distribution from the learning of the inference mechanisms. It approximates reasoning under causal interventions and is evaluated on out-of-distribution image recognition tasks where the data distribution acts as a strong adversarial confounder.

Abstract

Deep neural networks can obtain impressive performance on various tasks under the assumption that their training domain is identical to their target domain. Performance can drop dramatically when this assumption does not hold. One explanation for this discrepancy is the presence of spurious domain-specific correlations in the training data that the network exploits. Causal mechanisms, on the other hand, can be made invariant under distribution changes as they allow disentangling the factors of distribution underlying the data generation. Yet, learning causal mechanisms to improve out-of-distribution generalisation remains an under-explored area. We propose a Bayesian neural architecture that disentangles the learning of the data distribution from the learning of the inference mechanisms. We show theoretically and experimentally that our model approximates reasoning under causal interventions. We demonstrate that our method outperforms its point-estimate counterparts on out-of-distribution image recognition tasks where the data distribution acts as a strong adversarial confounder.

Paper Structure (28 sections, 15 equations, 8 figures, 1 table)

This paper contains 28 sections, 15 equations, 8 figures, 1 table.

Figures (8)

  • Figure 1: Target causal representation and actual causal graph during training. $X$ is the input image and $Y$ is the output class. $Z$ is the variable representing the factors of distribution generating the input $X$; it is composed of domain-specific factors $Z_S$ and robust domain-invariant factors $Z_R$. $U_{XY}$ represents the shared factors of variation between $X$ and $Y$. The variables in red are unobserved. $S$ is a selection variable determining the current domain. The target causal graph is different from the one used in supervised deep learning as it omits the influence of the datasets $\mathcal{D} = \{\mathcal{D}_X,\mathcal{D}_Y\}$. The datasets depend on the same factors as $X$ and $Y$ and add spurious correlations that are not captured when using the target representation.
  • Figure 2: Architecture of the Causal-Invariant Bayesian (CIB) neural network. At the top (in blue), a variational encoder generates the respective parameters of the distributions of the intermediate representations $R$ and $\{R_i'\}_{i=1}^N$ of the input $X$ and the contextual information $\{X_i'\}_{i=1}^N$. $R$ and $\{R_i'\}_{i=1}^N$ are then provided to the inference module (in orange) that retrieves $Y$. This procedure aims to disentangle the learning of the representation $R$ from the learning of the inference mechanisms and force the inference module to only learn the latter. The weights $W$ of the inference module are sampled from a distribution learned using variational inference (in green). The weight sampling and the variational encoding are regularised using an ELBO loss.
  • Figure 3: Domain transfer results on the OFFICEHOME dataset. Each row represents the category of the training subset and each column represents the category of the test subset. Accuracy with a random guess is 0.015. In the right figure, a cell is shown in green if its value is higher than the baseline on the left. The mean and standard deviation across three runs are shown. Our proposed model (on the right) systematically outperforms the baseline (on the left), except when trained on the Art domain, where it shows little to no improvement. As this domain demonstrates the lowest i.i.d. and o.o.d. accuracy, we attribute this to the lack of exploitable training data.
  • Figure 4: Evolution of the validation loss during training. The mean and standard deviation across three runs are shown. After a period of improvement, the ResNet-18 and ResNet-18-CT baselines overfit as the training progresses. This behaviour is not observed with CIBResNet-18, which also demonstrates a lower standard deviation.
  • Figure 5: Accuracy heatmap of CIBResNet-18 on the CIFAR-10 test set as a function of the number of weight and context samples. The number of weight samples has a negligible effect on performance. Only increasing the context size improves accuracy.
  • ...and 3 more figures
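
The data flow described in the Figure 2 caption can be illustrated with a minimal sketch. It is not the authors' implementation: the identity `encode` function, the linear `infer` read-out, and the mean-pooling of context representations are hypothetical stand-ins chosen only to keep the example self-contained. What it is meant to show is the overall pattern: the representations $R$ and $\{R_i'\}$ and the weights $W$ are sampled from Gaussian distributions, the prediction is averaged over weight samples, and a KL term of the kind that appears in an ELBO regulariser is computed in closed form.

```python
import math
import random


def sample_gaussian(mean, log_var, rng):
    """Reparameterised sample: mean + std * eps, elementwise."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mean, log_var)]


def kl_to_standard_normal(mean, log_var):
    """KL(N(mean, exp(log_var)) || N(0, 1)), summed over dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0
                     for m, lv in zip(mean, log_var))


def softmax(logits):
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def encode(x):
    """Hypothetical encoder: returns (mean, log_var) of the representation."""
    return list(x), [-2.0] * len(x)


def infer(r, context_rs, weights, n_classes):
    """Hypothetical inference module (assumption, not from the paper):
    mean-pool the context representations, add them to r, apply a linear map."""
    dim = len(r)
    pooled = [sum(c[j] for c in context_rs) / len(context_rs)
              for j in range(dim)]
    features = [a + b for a, b in zip(r, pooled)]
    logits = [sum(weights[k * dim + j] * features[j] for j in range(dim))
              for k in range(n_classes)]
    return softmax(logits)


def predict(x, contexts, w_mean, w_log_var, n_classes, n_weight_samples, rng):
    """Monte Carlo average of class probabilities over sampled weights."""
    r_mean, r_log_var = encode(x)
    ctx_params = [encode(c) for c in contexts]
    probs = [0.0] * n_classes
    for _ in range(n_weight_samples):
        w = sample_gaussian(w_mean, w_log_var, rng)
        r = sample_gaussian(r_mean, r_log_var, rng)
        ctx = [sample_gaussian(m, lv, rng) for m, lv in ctx_params]
        p = infer(r, ctx, w, n_classes)
        probs = [a + b / n_weight_samples for a, b in zip(probs, p)]
    return probs


if __name__ == "__main__":
    rng = random.Random(0)
    x = [0.5, -1.0]                       # toy input
    contexts = [[0.1, 0.2], [0.3, -0.4]]  # toy context samples
    w_mean, w_log_var = [0.0] * 6, [-2.0] * 6  # 3 classes x 2 dimensions
    print(predict(x, contexts, w_mean, w_log_var, 3, 4, rng))
    print(kl_to_standard_normal(w_mean, w_log_var))
```

In this sketch the number of context samples is the length of `contexts` and the number of weight samples is `n_weight_samples`, which are the two axes varied in Figure 5.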

Theorems & Definitions (2)

  • proof
  • proof