Can Large Language Models Learn Independent Causal Mechanisms?

Gaël Gendron, Bao Trung Nguyen, Alex Yuxuan Peng, Michael Witbrock, Gillian Dobbie

TL;DR

The paper investigates whether Large Language Models can learn Independent Causal Mechanisms by introducing Independent Causal Language Models (ICLM), which route each input to a domain-specific module while sharing a domain-invariant module across all inputs. It combines unsupervised routing via vector quantisation with Mutual Information minimisation to encourage modular independence and abstraction. The approach is theoretically motivated and empirically evaluated on abstract and causal reasoning tasks (ACRE and RAVEN), showing improved out-of-distribution generalisation and partial independence between modules: domain-invariant knowledge contributes broadly while domain-specific modules specialise. The findings suggest that principled modularity can improve the robustness of LLMs to distribution shifts, though complete independence is not achieved and the approach incurs a substantial computational cost, motivating future work on richer causal graphs and more scalable routing strategies.
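
The routing step can be pictured with a minimal sketch, assuming the router keeps one codebook centroid per domain-specific module, sends each input embedding to its nearest centroid, and nudges that centroid toward the input with an exponential moving average. The function names, the `decay` value and the EMA update are hypothetical illustrations, not the paper's method (Figure 2 lists a dedicated router loss $\mathcal{L}_R$, which this sketch replaces with the simpler EMA rule). The domain-invariant module needs no routing decision because it is always active.

```python
from math import dist  # Euclidean distance between two points (Python 3.8+)

def route(embedding, centroids):
    """Index of the nearest centroid, i.e. which domain-specific module to activate."""
    return min(range(len(centroids)), key=lambda k: dist(embedding, centroids[k]))

def update_centroid(centroid, embedding, decay=0.99):
    """Exponential moving average step pulling the chosen centroid toward the input.

    Hypothetical stand-in for the paper's router training.
    """
    return [decay * c + (1.0 - decay) * e for c, e in zip(centroid, embedding)]

def router_step(embedding, centroids, decay=0.99):
    """Route one input embedding and update the selected centroid in place."""
    k = route(embedding, centroids)
    centroids[k] = update_centroid(centroids[k], embedding, decay)
    # The domain-invariant module is always active; the router only picks module k.
    return k, ["invariant", f"specific_{k}"]

# Usage: two domains, one centroid each.
centroids = [[0.0, 0.0], [10.0, 10.0]]
k, active = router_step([9.0, 11.0], centroids, decay=0.5)
print(k, active, centroids[1])  # 1 ['invariant', 'specific_1'] [9.5, 10.5]
```

An EMA codebook update is a common way to train vector-quantisation centroids without backpropagating through the discrete nearest-centroid choice; a learned router loss is an equally valid alternative.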

Abstract

Despite impressive performance on language modelling and complex reasoning tasks, Large Language Models (LLMs) fall short on the same tasks in uncommon settings or under distribution shifts, exhibiting a lack of generalisation ability. By contrast, systems such as causal models, which learn abstract variables and causal relationships, can demonstrate increased robustness against changes in the distribution. One reason for this success is the existence and use of Independent Causal Mechanisms (ICMs) representing high-level concepts that only sparsely interact. In this work, we apply two concepts from causality to learn ICMs within LLMs. We develop a new LLM architecture composed of multiple sparsely interacting language modelling modules. We show that such causal constraints can improve out-of-distribution performance on abstract and causal reasoning tasks. We also investigate the level of independence and domain specialisation and show that LLMs rely on pre-trained, partially domain-invariant mechanisms that are resilient to fine-tuning.

Paper Structure

This paper contains 36 sections, 8 equations, 17 figures and 7 tables.

Figures (17)

  • Figure 1: Proposed Independent Causal Language Models (ICLM) architecture for language-modelling tasks. The input text (on the left, in blue) is fed to multiple pretrained LLM modules (in red). A router uses clustering on input text embeddings (in purple) to activate a domain-specific module for this input. The domain-invariant module is always activated. The latent representations generated by the activated modules are combined using an aggregation scheme (in orange) and converted into a probability distribution for the next word (on the right, in blue). An additional loss (in green) minimises the Mutual Information between the domain-invariant and the domain-specific representations. The router ensures that the domain-specific modules only gain in-domain knowledge while the Mutual Information loss regularises the domain-invariant module towards learning abstract representations.
  • Figure 2: Simplified temporal causal graph $\mathcal{G}$ during training before adding Mutual Information minimisation. C is the input context. $H_R$, $H_I$, $H_{S_n}$, $H_S$ are the latent states of the router, domain-invariant, domain-specific and activated domain-specific (after router weighting) modules. For simplicity, we only show the state $H_{S_n}$ of the activated domain-specific module $n$. $Y$ and $Y_{true}$ are the output and true distributions. $W_R$, $W_{S_n}$ and $W_I$ are the trainable parameters of the modules. $\mathcal{L}_Y = \mathcal{L}_o + \alpha \cdot \mathcal{L}_{inv} + \beta \cdot \mathcal{L}_{dom}$ and $\mathcal{L}_R$ are the output and router losses. Black edges show the forward pass at step $\tau$. Blue dashed edges show the backward pass at step $\tau$. Red dotted edges illustrate the causal links between the forward and backward passes.
  • Figure 3: Evolution of independence measures between modules during fine-tuning on ACRE and RAVEN. We measure independence on the last hidden states of the modules. Correlation and MI are substantially reduced, but the modules remain correlated.
  • Figure 4: Correlation between the last hidden states of the modules during inference at test time. Module states are more correlated than during training.
  • Figure 5: 2D projection of the hidden states of LLaMA2 on the ACRE and RAVEN i.i.d. and o.o.d. sets. Ground truth samples are labelled as in the abstract reasoning results table (text/symbolic, i.i.d./o.o.d. sets). Text and symbolic inputs are always clustered separately. The i.i.d. and o.o.d. sets are clustered together in ACRE but separated in RAVEN. The router's assignments follow the text/symbolic division.
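
The correlation tracked in Figures 3 and 4 can be approximated with a minimal sketch, assuming each module yields one last-hidden-state vector per input and that dependence is scored as the mean absolute Pearson correlation between matching hidden dimensions across a batch. This estimator and the names `pearson` and `mean_abs_correlation` are hypothetical choices for illustration; the paper's exact correlation and MI estimators are not reproduced here.

```python
from math import fsum, sqrt

def pearson(x, y):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mx, my = fsum(x) / n, fsum(y) / n
    sxy = fsum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = fsum((a - mx) ** 2 for a in x)
    syy = fsum((b - my) ** 2 for b in y)
    if sxx == 0 or syy == 0:
        raise ValueError("correlation is undefined for a constant dimension")
    return sxy / sqrt(sxx * syy)

def mean_abs_correlation(states_a, states_b):
    """Mean |Pearson r| between matching hidden dimensions of two modules.

    states_a, states_b: per-sample last-hidden-state vectors from two modules,
    aligned so that states_a[i] and states_b[i] come from the same input.
    Returns a value in [0, 1] (up to rounding); 0 means no linear dependence.
    """
    if len(states_a) != len(states_b) or len(states_a) < 2:
        raise ValueError("need two aligned batches with at least two samples")
    dims = len(states_a[0])
    scores = [
        abs(pearson([s[d] for s in states_a], [s[d] for s in states_b]))
        for d in range(dims)
    ]
    return fsum(scores) / dims

# Usage: module B's states are a linear function of module A's, so the score is 1.
a = [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]]
b = [[2.0, -1.0], [4.0, -2.0], [6.0, -3.0]]
print(mean_abs_correlation(a, b))  # 1.0
```

Per-dimension Pearson correlation only captures linear dependence between matching dimensions; the Mutual Information measure in Figure 3 can also reflect non-linear dependence, which this sketch does not measure.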