Counterfactual reasoning: an analysis of in-context emergence

Moritz Miller, Bernhard Schölkopf, Siyuan Guo

TL;DR

This work investigates in-context counterfactual reasoning in language models using a synthetic linear-regression setup with unobserved noise and a latent parameter $\theta$. It shows that Transformer models can perform counterfactual reasoning, aided by a designated noise abduction mechanism and a linear representation of $\theta$ in the residual stream; both self-attention and model depth are key drivers, and pre-training data diversity strongly influences emergence and generalization. The authors provide theoretical identifiability results under exchangeability and demonstrate the approach on cyclic sequential dynamics via Lotka-Volterra SDEs, suggesting potential for counterfactual story generation. The study combines mechanistic analysis (noise abduction heads, residual probes) with extensive experiments, offering a principled view of how in-context counterfactual reasoning can arise in large language-model architectures and pointing toward safe, exploratory AI capabilities.

Abstract

Large-scale neural language models exhibit remarkable performance in in-context learning: the ability to learn and reason about the input context on the fly. This work studies in-context counterfactual reasoning in language models, that is, the ability to predict the consequences of a hypothetical scenario. We focus on a well-defined, synthetic linear regression task that requires noise abduction. Accurate prediction is based on (1) inferring an unobserved latent concept and (2) copying contextual noise from factual observations. We show that language models are capable of counterfactual reasoning. Further, we enhance existing identifiability results and reduce counterfactual reasoning for a broad class of functions to a transformation on in-context observations. In Transformers, we find that self-attention, model depth and pre-training data diversity drive performance. Moreover, we provide mechanistic evidence that the latent concept is linearly represented in the residual stream, and we introduce designated noise abduction heads central to performing counterfactual reasoning. Lastly, our findings extend to counterfactual reasoning under SDE dynamics and indicate that Transformers can perform noise abduction on sequential data, providing preliminary evidence on the potential for counterfactual story generation. Our code is available at https://github.com/mrtzmllr/iccr.
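The two-step recipe above (infer the latent concept, then copy the factual noise) can be illustrated with a closed-form baseline. The sketch below assumes a generative process that this summary does not spell out: $\theta \sim \mathcal{N}(0, I_d)$, $x_i \sim \mathcal{N}(0, I_d)$, $u_i \sim \mathcal{N}(0, \sigma^2)$ and $y_i = \theta^\top x_i + u_i$, with $d = 5$ chosen arbitrarily. The helpers `sample_prompt`, `two_step_predict` and `log_mse` are hypothetical, and the least-squares estimator stands in for a trained model; it is not the authors' method. The final loop mirrors the shape of Figure 2 (log-transformed MSE averaged over 6400 sequences against the number of in-context examples).

```python
import numpy as np

def sample_prompt(rng, n, d=5, sigma=1.0):
    """Sample one hypothetical prompt (x_1, y_1, ..., x_n, y_n, z, x_cf) and its target.

    Assumed process (illustrative only): theta ~ N(0, I_d), x_i ~ N(0, I_d),
    u_i ~ N(0, sigma^2), y_i = theta^T x_i + u_i. The index token z = j picks a
    factual example, and the counterfactual reuses its noise: y_cf = theta^T x_cf + u_j.
    """
    theta = rng.standard_normal(d)
    X = rng.standard_normal((n, d))
    u = sigma * rng.standard_normal(n)
    y = X @ theta + u
    j = int(rng.integers(n))          # index token z = j
    x_cf = rng.standard_normal(d)     # hypothetical new information x^CF
    y_cf = x_cf @ theta + u[j]        # same noise as the factual observation (x_j, y_j)
    return X, y, j, x_cf, y_cf

def two_step_predict(X, y, j, x_cf):
    """(1) Estimate theta by least squares, (2) abduct u_j and reuse it."""
    theta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
    u_hat = y[j] - X[j] @ theta_hat   # noise abduction from the factual pair
    return x_cf @ theta_hat + u_hat

def log_mse(n, num_prompts=6400, seed=0):
    """Log of the counterfactual MSE averaged over num_prompts sampled prompts."""
    rng = np.random.default_rng(seed)
    errs = []
    for _ in range(num_prompts):
        X, y, j, x_cf, y_cf = sample_prompt(rng, n)
        errs.append((two_step_predict(X, y, j, x_cf) - y_cf) ** 2)
    return float(np.log(np.mean(errs)))

for n in (10, 15, 20, 40):
    print(n, round(log_mse(n), 3))
```

Because the prediction equals $y_j + \hat\theta^\top (x^\text{CF} - x_j)$, its error is $(\hat\theta - \theta)^\top (x^\text{CF} - x_j)$, which shrinks as more in-context examples sharpen $\hat\theta$; with $\sigma = 0$ and $n \ge d$ the prediction is exact.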

Paper Structure

This paper contains 38 sections, 6 theorems, 46 equations, 12 figures, 1 table.

Key Result

Lemma 1

Suppose $y = T(f(x), u)$ for some function $T: \mathcal{Y} \times \mathcal{U} \longrightarrow \mathcal{Y}$. Assume that for any fixed $f(x) \in \mathcal{Y}$, the inverse $T^{-1}(f(x), \cdot)$ exists for all $y$, i.e., $u = T^{-1}(f(x), y)$. Then, counterfactual reasoning reduces to learning a transformation $h$, where $h(f(x^\text{CF}), f(x), y) = T(f(x^\text{CF}), T^{-1}(f(x), y))$ operates on elements of $\mathcal{Y}$.
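For additive noise, $T(f, u) = f + u$ and $T^{-1}(f, y) = y - f$, so $h(f(x^\text{CF}), f(x), y) = y + f(x^\text{CF}) - f(x)$: the counterfactual is the factual outcome shifted by the change in $f$. The sketch below builds $h$ generically from any pair $(T, T^{-1})$ and checks it against the ground truth for the additive case and for a hypothetical location-scale model $T(f, u) = f + e^{f} u$, which is invertible in $u$ for every fixed $f$. Both noise models, the linear $f$, and the helper `make_h` are illustrative assumptions, not taken from the paper.

```python
import numpy as np

def make_h(T, T_inv):
    """Return h(f_cf, f_fact, y_fact) = T(f_cf, T^{-1}(f_fact, y_fact)) as in Lemma 1."""
    def h(f_cf, f_fact, y_fact):
        u = T_inv(f_fact, y_fact)   # abduction: recover the noise from the factual pair
        return T(f_cf, u)           # prediction: same noise, counterfactual f(x^CF)
    return h

# Additive noise: y = f + u, so u = y - f.
def T_add(f, u):
    return f + u

def T_add_inv(f, y):
    return y - f

# Hypothetical location-scale noise: y = f + exp(f) * u, so u = (y - f) * exp(-f).
def T_ls(f, u):
    return f + np.exp(f) * u

def T_ls_inv(f, y):
    return (y - f) * np.exp(-f)

h_add = make_h(T_add, T_add_inv)
h_ls = make_h(T_ls, T_ls_inv)

# Hypothetical linear f(x) = theta^T x and one shared noise draw u.
rng = np.random.default_rng(0)
theta = rng.standard_normal(3)

def f(x):
    return x @ theta

x, x_cf = rng.standard_normal(3), rng.standard_normal(3)
u = rng.standard_normal()

for name, T, h in [("additive", T_add, h_add), ("location-scale", T_ls, h_ls)]:
    y = T(f(x), u)              # factual outcome
    y_cf_true = T(f(x_cf), u)   # ground-truth counterfactual under the same noise
    print(name, bool(np.isclose(h(f(x_cf), f(x), y), y_cf_true)))
```

Both lines print `True`: once $T$ is invertible in its noise argument, $h$ needs only $f(x^\text{CF})$, $f(x)$ and the observed $y$, never $u$ itself.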

Figures (12)

  • Figure 1: In-context counterfactual reasoning. Models are trained on a corpus of sequences drawn from a mixture of distributions (each $\bullet$ on the far left represents a single sequence from a distinct distribution parameterized by $\theta$). Suppose each observation satisfies $y = f(x, u_y)$ for some noise $u_y$. An in-context sequence $\bullet$ consists of $n$ examples. It is concatenated with an index token $z$, which refers back to the factual observation $(x_j, y_j)$ when $z=j$, and with the hypothetical new information $x^\text{CF}$: $(x_1, y_1, \ldots, x_n, y_n, z, x^\text{CF})$. In-context counterfactual reasoning is measured by the accuracy of the prediction of $y^\text{CF}$. Accurate prediction requires noise abduction from the factual observation, that is, inferring the $u_y$ consistent with $(x_j, y_j)$, followed by prediction based on the intervention $x^\text{CF}$ and the inferred $u_y$.
  • Figure 2: Model comparisons for in-context counterfactual reasoning. In-context counterfactual prediction accuracy, measured via log-transformed $\mathrm{MSE}$ averaged over $6400$ sequences, versus the number of in-context examples in a prompt. We compare GPT-2 (standard), LSTM, GRU, and Elman RNN. GPT-2 achieves the lowest error and the fastest convergence for a small number of in-context examples, though the difference is not significant. For more than $14$ in-context examples, the Elman RNN obtains significantly higher in-context $\mathrm{MSE}$ than the other three architectures (efron1979basicbootstrap).
  • Figure 3: Attention and model depth matter.
  • Figure 4: The Transformer linearly encodes the latent parameter. After every layer, we train a linear probe on $6400$ prompts from a fresh evaluation set and evaluate it on $1280$ sequences. All layers after the first encode information relevant for predicting $\theta$ from the residual stream alone.
  • Figure 5: Data diversity and emergence. In-context counterfactual reasoning emerges during training. To generalize to unseen $\theta$, the model must be trained on a sufficiently diverse pre-training corpus.
  • ...and 7 more figures
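The probing step behind Figure 4 amounts to a linear regression from residual-stream activations to $\theta$. The sketch below fits a ridge-regularized probe in closed form and reports held-out $R^2$ using the 6400/1280 split from the caption. Because no model is run here, the activations are a hypothetical stand-in (a random linear embedding of $\theta$ plus noise); in practice they would be read out of the residual stream after each layer, at a token position this summary does not specify. The helper names, the ridge strength, and `d_model = 64` are arbitrary assumptions.

```python
import numpy as np

def fit_linear_probe(H, Theta, ridge=1e-3):
    """Closed-form ridge probe: Theta ≈ [H, 1] @ W (last row of W is the bias)."""
    H1 = np.hstack([H, np.ones((H.shape[0], 1))])   # append a bias column
    reg = ridge * np.eye(H1.shape[1])
    reg[-1, -1] = 0.0                                # leave the bias unpenalized
    return np.linalg.solve(H1.T @ H1 + reg, H1.T @ Theta)

def probe_r2(W, H, Theta):
    """Coefficient of determination pooled over all coordinates of theta."""
    H1 = np.hstack([H, np.ones((H.shape[0], 1))])
    ss_res = np.sum((Theta - H1 @ W) ** 2)
    ss_tot = np.sum((Theta - Theta.mean(axis=0)) ** 2)
    return float(1.0 - ss_res / ss_tot)

# Hypothetical stand-in for residual-stream activations: a random linear
# embedding of theta plus isotropic noise (NOT real model activations).
rng = np.random.default_rng(0)
d_theta, d_model = 5, 64
A = rng.standard_normal((d_theta, d_model))

def fake_activations(Theta, noise=0.5):
    return Theta @ A + noise * rng.standard_normal((Theta.shape[0], d_model))

Theta_train = rng.standard_normal((6400, d_theta))
Theta_eval = rng.standard_normal((1280, d_theta))
W = fit_linear_probe(fake_activations(Theta_train), Theta_train)
print("held-out R^2:", round(probe_r2(W, fake_activations(Theta_eval), Theta_eval), 3))
```

Repeating the fit on activations from each layer and comparing held-out $R^2$ gives a per-layer picture of the kind Figure 4 describes; a probe fit on features that carry no information about $\theta$ should score near zero, which makes a useful control.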

Theorems & Definitions (12)

  • Definition 1: Exchangeable sequence
  • Lemma 1: Counterfactual reasoning as transformation on observed values
  • Theorem 1: Counterfactual identifiability under exchangeability
  • Proof: Posterior predictive distribution
  • Proof: Mean and variance in the linear additive framework
  • Proof: Lemma \ref{lmm:retrieval}
  • Definition 2: Definition 6.1, Equivalence, nasr2023counterfactualidentifiability
  • Proposition 1: Proposition 6.2, nasr2023counterfactualidentifiability
  • Lemma 2: Lemma B.2, nasr2023counterfactualidentifiability
  • Theorem 2: Theorem 5.1, nasr2023counterfactualidentifiability
  • ...and 2 more