Predictive Coding beyond Correlations

Tommaso Salvatori, Luca Pinchetti, Amine M'Charrak, Beren Millidge, Thomas Lukasiewicz

TL;DR

The paper investigates how predictive coding (PC) networks, an influential model in computational neuroscience, can be extended beyond correlation-based inference to causal reasoning. It shows that PC graphs can perform interventions without mutilating the graph, by setting the prediction error of the intervened node to zero, which effectively implements the do-operator at inference time. It then extends this approach to learning causal structure from observational data, using continuous adjacency weights with acyclicity and sparsity priors. The authors establish theoretical links to Structural Causal Models and validate them through extensive experiments on synthetic DAGs and image classification tasks, demonstrating associational, interventional, and counterfactual reasoning, with improved MNIST and FashionMNIST performance when interventional queries are used. Overall, the work bridges computational neuroscience and causality, proposing an end-to-end, transparent causal engine that jointly learns graph structure and answers causal queries, while acknowledging limitations related to Markov equivalence and data constraints.
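The TL;DR mentions continuous adjacency weights regularized by acyclicity and sparsity priors. As a minimal sketch, assuming a NOTEARS-style acyclicity function $h(\mathbf{A}) = \mathrm{tr}(e^{\mathbf{A} \circ \mathbf{A}}) - d$ and an L1 sparsity term, the two priors could be computed as below. The exact form and weighting used in the paper may differ; `structure_prior`, `lam_acyc`, and `lam_sparse` are hypothetical names, and adding the result to the PC energy is an assumption.

```python
def matmul(a, b):
    """Product of two square matrices given as lists of rows."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


def trace_expm(m, terms=30):
    """tr(exp(M)) via a truncated Taylor series; adequate for small, well-scaled M."""
    n = len(m)
    term = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]  # M^0 / 0!
    total = float(n)
    for k in range(1, terms):
        term = [[v / k for v in row] for row in matmul(term, m)]  # M^k / k!
        total += sum(term[i][i] for i in range(n))
    return total


def acyclicity_penalty(adj):
    """NOTEARS-style h(A) = tr(exp(A o A)) - d; h(A) is zero iff the weighted graph is acyclic."""
    squared = [[v * v for v in row] for row in adj]  # elementwise (Hadamard) square
    return trace_expm(squared) - len(adj)


def sparsity_penalty(adj):
    """L1 norm of the adjacency weights, pushing unused edges towards zero."""
    return sum(abs(v) for row in adj for v in row)


def structure_prior(adj, lam_acyc=1.0, lam_sparse=0.1):
    """Hypothetical weighted sum of both priors, assumed to be added to the PC energy."""
    return lam_acyc * acyclicity_penalty(adj) + lam_sparse * sparsity_penalty(adj)
```

For the 3-node chain `[[0, 1, 0], [0, 0, 1], [0, 0, 0]]` the acyclicity term is exactly 0, while the 2-cycle `[[0, 1], [1, 0]]` gives $2\cosh(1) - 2 \approx 1.086$, so minimizing the penalty pushes the learned weights towards a DAG.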

Abstract

Recently, there has been extensive research on the capabilities of biologically plausible algorithms. In this work, we show how one such algorithm, predictive coding, is able to perform causal inference tasks. First, we show how a simple change in the inference process of predictive coding makes it possible to compute interventions without the need to mutilate or redefine a causal graph. Then, we explore applications in cases where the graph is unknown and has to be inferred from observational data. Empirically, we show how these findings can be used to improve the performance of predictive coding in image classification tasks, and conclude that such models are able to perform simple end-to-end causal inference tasks.
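The "simple change in the inference process" can be illustrated on a toy linear PC chain $\mathbf{x}_1 \to \mathbf{x}_2 \to \mathbf{x}_3 \to \mathbf{x}_4$ with scalar nodes, unit weights, a zero prior mean for the root, and energy $\mathcal{F} = \tfrac12 \sum_i e_i^2$ with $e_i = x_i - \mu_i$; all of these are assumptions made for the sketch, and `pc_inference` is a hypothetical helper. Conditioning clamps $\mathbf{x}_3$ and relaxes the free nodes by gradient descent on $\mathcal{F}$. The intervention additionally sets the error $e_3$ to zero, following the description in the Figure 2 caption, so no error information flows back to the parents of $\mathbf{x}_3$.

```python
def pc_inference(s3, intervene, w21=1.0, w32=1.0, w43=1.0, lr=0.1, steps=2000):
    """Relax a toy linear PC chain x1 -> x2 -> x3 -> x4 with x3 clamped to s3.

    Energy: F = 0.5 * (e1**2 + e2**2 + e3**2 + e4**2), with e_i = x_i - mu_i.
    Returns the relaxed values (x1, x2, x4) of the free nodes.
    """
    x1 = x2 = x4 = 0.0   # free value nodes, initialised at the prior mean
    x3 = s3              # clamped value node
    for _ in range(steps):
        e1 = x1 - 0.0            # root node: prior mean assumed to be 0
        e2 = x2 - w21 * x1
        e3 = x3 - w32 * x2
        e4 = x4 - w43 * x3
        if intervene:
            e3 = 0.0             # do(x3 = s3): silence the error node of x3
        # dF/dx for each free node; errors of child nodes flow backwards
        g1 = e1 - w21 * e2
        g2 = e2 - w32 * e3
        g4 = e4
        x1 -= lr * g1
        x2 -= lr * g2
        x4 -= lr * g4
    return x1, x2, x4


conditioned = pc_inference(3.0, intervene=False)
intervened = pc_inference(3.0, intervene=True)
```

With $s_3 = 3$, conditioning converges to approximately $(\mathbf{x}_1, \mathbf{x}_2, \mathbf{x}_4) = (1, 2, 3)$: observing $\mathbf{x}_3$ changes the beliefs about its causes. The intervention leaves $\mathbf{x}_1$ and $\mathbf{x}_2$ at their prior means $(0, 0)$, while $\mathbf{x}_4$ still moves to 3, as expected under $do(\mathbf{x}_3 = 3)$.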

Paper Structure

This paper contains 31 sections, 1 theorem, 30 equations, 24 figures, 5 tables, and 1 algorithm.

Key Result

Proposition 3.1

Let $\mathcal{G}$ be a PC graph whose structure is given by a directed acyclic graph $G$ with variables $\{\mathbf{x}_1, \dots, \mathbf{x}_N\}$, as defined in Sec. 2.1. Then the distributions of the variables obtained after the following two operations are equivalent:

Figures (24)

  • Figure 1: Example socio-economic graph and its structure after conditioning and intervening on education level.
  • Figure 2: $(a)$ PC graph with the same causal structure as that in Fig. 1. Every vertex $v_i$ is associated with a value node $\mathbf{x}_i$ and an error node $\mathbf{e}_i$. The arrows show the influence of every node on the others: prediction information follows the direction of the arrows of the original graph, while error information flows backwards. $(b)$ Example of conditioning in PC graphs. We fix the value of $\mathbf{x}_3$, which makes all arrows entering $v_3$ irrelevant, since $\mathbf{x}_3$ is fixed and hence ignores incoming information. This, however, does not apply to the error information going out of $v_3$, which keeps influencing $\mathbf{x}_1$ and $\mathbf{x}_2$. $(c)$ Example of an intervention in PC graphs, which solves this issue. According to Pearl's causal theory, the do-operator on a node ($v_3$ in this case) removes its incoming edges, so that the newly introduced information cannot flow backwards and influence the parent nodes. Since in PC the only information flowing opposite to the causal relations is the error information, an intervention can be performed by removing (or setting to zero) the error node.
  • Figure 3: What would $\mathbf{x}_4$ have been, had $\mathbf{x}_3$ been equal to $\mathbf{s}^*_3$ in situation $U = \mathbf{u}$? This figure gives an example of the three-step process used to compute counterfactuals, using a structural causal model with four exogenous and four endogenous variables. Two kinds of data are given: the original values of $\mathbf{x}_1,\dots, \mathbf{x}_4$, which correspond to past information and are denoted here by $\mathbf{s}_1,\dots, \mathbf{s}_4$, and the intervention information $\mathbf{s}_3^*$, needed to determine what would have happened to $\mathbf{x}_4$ had we changed $\mathbf{s}_3$ to $\mathbf{s}_3^*$. The final answer corresponds to the node $\tilde{\mathbf{x}}_4$ obtained in the prediction step.
  • Figure 4: (a) How to compute a prediction for a data point on a fully connected PC graph using interventional queries. (b) Left to right: causal structure of the SCM; convergence of the PC energy vs. the error metric (MAE) during SCM learning for the butterfly graph; per-node error of interventional query estimates with an intervention on $\mathbf{x}_3$ (yellow node); per-node error of counterfactual query estimates with an intervention on $\mathbf{x}_3$, given factual data (blue nodes). (c) Architecture used to reconstruct counterfactual images. $U_Z$ corresponds to the color of the digit, $T$ to the rotation angle, $X$ to the input, and $Y$ to the colored and rotated image. The architecture is a predictive coding network resembling the one described in the main paper, where clusters of neurons represent nodes and arrows represent transformations implemented by MLPs. The reconstructions on the right show that the model is robust to interventions on (top) the rotation angle and (bottom) the color, transforming each input digit into the desired target. The table shows that model performance does not depend on the number of neurons $d_h$ of the node $\mathbf{u}_z$. The work proposing the experiment (de2022deep) reports an MSE of $0.001$.
  • Figure 5: (a) Experiments on structure learning from synthetic data, generated from Erdős-Rényi and scale-free random graphs with $20$ nodes. Left: the connection strengths of the true graph; right: those learned by a PC graph. (b) Structure learning on the $2$-MNIST dataset: the plot shows the weights of the adjacency matrix $\mathbf{A}$ over the training epochs, and the dotted curve shows the test accuracy. The vertical lines $C_i$ refer to the connectivities discovered by the structure learning algorithm during training. These connectivities are shown on the right side of the plot, which displays the $6$ clusters of neurons and the connections among them. For example, the blue curve (representing the direct connection) immediately reaches 1, stays there until the second vertical line ($C_2$), and then starts decreasing. At epoch 250, two curves lie above it: those of the hierarchical connections. (c) The two energy functions optimized by the PC graph when training on negative and non-negative examples. (d) Test error of all experiments performed on MNIST and FashionMNIST, averaged over three seeds. The best results are obtained when the training process is augmented with both of the proposed structure learning methods.
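The three-step counterfactual procedure summarized in the Figure 3 caption (abduction, action, prediction) can be sketched on a hypothetical linear SCM. This is a generic illustration of Pearl's counterfactual procedure, not the paper's PC-based implementation; the chain structure, the weight `W`, the factual values, and the helper names are all assumptions.

```python
# Hypothetical linear SCM, assumed only for illustration (a chain x1 -> x2 -> x3 -> x4):
#   x1 = u1,  x_i = W * x_{parent(i)} + u_i  for i = 2, 3, 4
W = 2.0
PARENT = {1: None, 2: 1, 3: 2, 4: 3}
ORDER = (1, 2, 3, 4)  # a topological order of the chain


def mechanism(i, x, w=W):
    """Deterministic part of x_i given already-computed parent values."""
    p = PARENT[i]
    return w * x[p] if p is not None else 0.0


def counterfactual(s, node, value, w=W):
    """Answer 'what would x have been, had x_node been `value`?' given factual values s."""
    # Step 1 (abduction): recover the exogenous noise u_i consistent with the facts.
    u = {i: s[i] - mechanism(i, s, w) for i in ORDER}
    x = {}
    for i in ORDER:
        if i == node:
            # Step 2 (action): do(x_node = value) replaces the node's mechanism.
            x[i] = value
        else:
            # Step 3 (prediction): recompute with the inferred noise held fixed.
            x[i] = mechanism(i, x, w) + u[i]
    return x


facts = {1: 1.0, 2: 2.5, 3: 5.5, 4: 12.0}
result = counterfactual(facts, node=3, value=1.0)
```

Here the facts imply $u_4 = 12 - 2 \cdot 5.5 = 1$, so the counterfactual value of $\mathbf{x}_4$ is $2 \cdot 1 + 1 = 3$, whereas the factual value was 12; the ancestors $\mathbf{x}_1$ and $\mathbf{x}_2$ keep their factual values.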

Theorems & Definitions (2)

  • Proposition 3.1
  • Proof of Theorem 1
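Proposition 3.1 asserts that two operations on a PC graph yield the same distribution. As a hedged numerical illustration in the spirit of the Figure 2 caption, which describes an intervention either as removing the incoming edges or as setting the error node to zero, the sketch below compares both on a hypothetical linear PC graph with edges x1 → x2, x1 → x3, x3 → x4, an observation x2 = 2, and an intervention do(x3 = 5). Assuming the standard PC energy $\mathcal{F} = \sum_i \tfrac12 \lVert \mathbf{e}_i \rVert^2$ with $\mathbf{e}_i = \mathbf{x}_i - \boldsymbol{\mu}_i$, the gradient with respect to $\mathbf{x}_j$ collects $\mathbf{e}_j$ and the errors of the children of $v_j$, so zeroing $\mathbf{e}_3$ removes exactly the backward messages along the edges entering $v_3$. The helper `pc_infer` and its parameters are hypothetical, and the check compares relaxed (MAP) values rather than full distributions; it is not the paper's proof.

```python
def pc_infer(edges, clamped, zero_error=(), n=4, lr=0.1, steps=3000):
    """Gradient descent on F = 0.5 * sum_i e_i**2 over the free value nodes.

    edges: dict mapping (parent, child) -> linear weight (hypothetical linear PC graph)
    clamped: dict node -> fixed value (observations and interventions)
    zero_error: nodes whose error is forced to 0 (the PC-style do-operation)
    """
    x = {i: clamped.get(i, 0.0) for i in range(1, n + 1)}
    for _ in range(steps):
        # Errors e_i = x_i - mu_i, with mu_i a weighted sum of parents (0 for roots).
        e = {}
        for i in x:
            mu = sum(w * x[p] for (p, c), w in edges.items() if c == i)
            e[i] = 0.0 if i in zero_error else x[i] - mu
        # dF/dx_j = e_j - sum over children c of w_jc * e_c; clamped nodes stay fixed.
        for j in x:
            if j in clamped:
                continue
            grad = e[j] - sum(w * e[c] for (p, c), w in edges.items() if p == j)
            x[j] -= lr * grad
    return x


edges = {(1, 2): 1.0, (1, 3): 1.0, (3, 4): 1.0}
clamped = {2: 2.0, 3: 5.0}  # observe x2 = 2, intervene do(x3 = 5)
full = pc_infer(edges, clamped, zero_error={3})                          # zeroed error node
mutilated = pc_infer({k: w for k, w in edges.items() if k[1] != 3}, clamped)  # edges into x3 removed
```

Both runs settle at $\mathbf{x}_1 \approx 1$ (informed only by the observation of $\mathbf{x}_2$) and $\mathbf{x}_4 \approx 5$. Calling `pc_infer(edges, clamped)` without `zero_error` turns the intervention into plain conditioning on $\mathbf{x}_3 = 5$, and $\mathbf{x}_1$ then moves towards $7/3$ instead.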