Teaching Transformers Causal Reasoning through Axiomatic Training

Aniket Vashishtha, Abhinav Kumar, Atharva Pandey, Abbavaram Gowtham Reddy, Kabir Ahuja, Vineeth N Balasubramanian, Amit Sharma

TL;DR

This work introduces axiomatic training as a principled method for teaching transformers causal reasoning from symbolic demonstrations of causal axioms and rules such as transitivity and d-separation. A 67M-parameter model trained from scratch on synthetic <premise, hypothesis, conclusion> triplets generalizes to longer and branched graphs and to edge-order perturbations, while rotary positional encodings further enhance length generalization. Extending the approach to finetune Llama-3-8B-Instruct yields substantial gains on causal benchmarks such as CLEAR and Corr2Cause, sometimes surpassing GPT-4. The results establish axiomatic training as a scalable paradigm for enhancing causal reasoning in both small and large language models, with broad implications for knowledge-driven reasoning and verification tasks in real-world applications.

Abstract

For text-based AI systems to interact in the real world, causal reasoning is an essential skill. Since active interventions are costly, we study to what extent a system can learn causal reasoning from symbolic demonstrations of causal axioms. Specifically, we present an axiomatic training method in which the system learns from multiple demonstrations of a causal axiom (or rule), rather than incorporating the axiom as an inductive bias or inferring it from data values. A key question is whether the system learns to generalize from the axiom demonstrations to more complex scenarios. Our results, based on applying axiomatic training to learn the transitivity axiom and the d-separation rule, indicate that such generalization is possible. To avoid data contamination issues, we start with a 67-million-parameter transformer model and train it from scratch. On both tasks, we find that a model trained on linear causal chains (along with some noisy variations) generalizes well to complex graphs, including longer causal chains, causal chains with reversed order, and graphs with branching. To handle diverse text inputs, the same method is extended to finetune language models. Finetuning the Llama-3-8B-Instruct model on our axiomatic data leads to significant gains on causal benchmarks such as Corr2Cause and CLEAR, in some cases providing state-of-the-art performance that surpasses GPT-4.
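As a rough illustration of how transitivity demonstrations could be generated, the following is a minimal sketch, not the authors' data pipeline. It assumes a hypothetical text format ("X0 causes X1.", "Does X0 cause X2?") and hypothetical helper names (`make_chain`, `causes`, `make_triplet`); the "noisy variations" are modeled here as random edge flipping with an assumed probability `flip_prob`. The label follows the transitivity axiom: the answer is Yes exactly when a directed path leads from the queried cause to the queried effect.

```python
import random

def make_chain(num_nodes, flip_prob=0.0, rng=None):
    """Hypothetical helper: linear chain X0 -> X1 -> ... where each edge is
    independently flipped with probability flip_prob (assumed noise model)."""
    rng = rng or random.Random()
    nodes = [f"X{i}" for i in range(num_nodes)]
    edges = []
    for a, b in zip(nodes, nodes[1:]):
        # Flip the edge direction to create a noisy variation of the chain.
        edges.append((b, a) if rng.random() < flip_prob else (a, b))
    return nodes, edges

def causes(edges, src, dst):
    """Transitivity: src causes dst iff a directed path src -> ... -> dst exists."""
    children = {}
    for a, b in edges:
        children.setdefault(a, []).append(b)
    stack, seen = [src], {src}
    while stack:
        node = stack.pop()
        for child in children.get(node, []):
            if child == dst:
                return True
            if child not in seen:
                seen.add(child)
                stack.append(child)
    return False

def make_triplet(nodes, edges, rng):
    """Hypothetical text format for one <premise, hypothesis, conclusion> triplet."""
    premise = " ".join(f"{a} causes {b}." for a, b in edges)
    src, dst = rng.sample(nodes, 2)  # two distinct nodes to query
    hypothesis = f"Does {src} cause {dst}?"
    conclusion = "Yes" if causes(edges, src, dst) else "No"
    return premise, hypothesis, conclusion
```

For example, with `rng = random.Random(0)`, the call `make_triplet(*make_chain(rng.randint(3, 6), 0.3, rng), rng)` yields one instance from a 3-to-6-node chain. Computing the label by graph search rather than by position in the chain keeps it correct when edges are flipped.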

Paper Structure

This paper contains 35 sections, 4 equations, 5 figures, 12 tables.

Figures (5)

  • Figure 1: Axiomatic training for imparting causal reasoning to language models. Given an axiom, we construct a training dataset of <premise, hypothesis, conclusion> triplets based on simple chain-like graphs of 3 to 6 nodes. A transformer model trained from scratch on such instances generalizes to much more complex graphs, including longer causal chains with more than 6 nodes, branched networks with higher average in-degree and out-degree, complete reversals, shuffled statements, and longer node names. Moreover, when an existing model such as Llama-3-8B-Instruct is finetuned on the same dataset, its accuracy on causal reasoning benchmarks such as CLEAR and Corr2Cause improves significantly (by up to 20 percentage points (p.p.)).
  • Figure 2: Evaluating generalization on causal sequences (without random flipping) whose node names are longer than those in the training sequences. The TS2 training set with no positional encoding (NoPE) leads to the best performance. See table tab:nodename_length_gen for complete results.
  • Figure 3: Generalizing to longer unseen causal sequences (>6 nodes) with random flipping, for models trained on the TS2 and OCC (with NoPE) train sets. OCC-trained models struggle due to limited edge-level variability, while TS2 NoPE consistently performs well. See table tab:simple_complex_linear for complete results.
  • Figure A1: Example instance of the Multiple Choice (MC) question type from the CLEAR dataset (chen2024clearlanguagemodelsreally), describing a d-separation problem posed with a different hypothesis type and semantic structure from the ones our models are finetuned on.
  • Figure A2: Example instance from the Corr2Cause dataset, where the model must infer the presence of a collider between variables given only correlational and conditional independence statements.
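The evaluation perturbations listed in Figure 1 can be sketched as simple graph transformations. In this minimal sketch, which is an illustration under assumptions and not the authors' code, "complete reversal" is assumed to mean flipping every edge, longer node names are random alphanumeric strings of an assumed 4 to 8 characters, and branched graphs are random DAGs over a fixed node order with an assumed edge probability. All function names are hypothetical. Labels are recomputed by reachability on the transformed graph, so each perturbed instance stays consistent with transitivity.

```python
import random
import string

def reachable(edges, src, dst):
    """Directed reachability, used to relabel every perturbed instance."""
    children = {}
    for a, b in edges:
        children.setdefault(a, []).append(b)
    stack, seen = [src], {src}
    while stack:
        node = stack.pop()
        if node == dst:
            return True
        for child in children.get(node, []):
            if child not in seen:
                seen.add(child)
                stack.append(child)
    return False

def shuffled(edges, rng):
    """Shuffled statements: same graph, statements in random order."""
    out = list(edges)
    rng.shuffle(out)
    return out

def reversed_chain(edges):
    """Complete reversal (assumed meaning): flip the direction of every edge."""
    return [(b, a) for a, b in edges]

def renamed(edges, rng, min_len=4, max_len=8):
    """Longer node names: map each node to a unique random alphanumeric string."""
    alphabet = string.ascii_letters + string.digits
    mapping, used = {}, set()
    for node in sorted({n for edge in edges for n in edge}):
        name = ""
        while not name or name in used:
            name = "".join(rng.choices(alphabet, k=rng.randint(min_len, max_len)))
        used.add(name)
        mapping[node] = name
    return [(mapping[a], mapping[b]) for a, b in edges], mapping

def branched_dag(num_nodes, edge_prob, rng):
    """Branched graph: random DAG over the fixed order X0 < X1 < ..., so nodes
    can have in-degree and out-degree above 1, unlike a chain."""
    nodes = [f"X{i}" for i in range(num_nodes)]
    return [(nodes[i], nodes[j])
            for i in range(num_nodes)
            for j in range(i + 1, num_nodes)
            if rng.random() < edge_prob]
```

For instance, `shuffled(branched_dag(8, 0.4, rng), rng)` produces a branched graph whose statements appear in random order; each queried node pair is then labeled with `reachable`.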

Theorems & Definitions (2)

  • Definition 3.1: Causal Irrelevance, adapted from Defn. 7 in GALLES19979
  • Definition 3.3: d-separation, Definition 1.2.3 in pearlbook
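The d-separation rule that the models are trained on can be checked mechanically. The sketch below is an illustration, not the authors' implementation: instead of enumerating paths as in a path-blocking definition, it uses the equivalent moralized-ancestral-graph criterion (restrict to the ancestors of X ∪ Y ∪ Z, connect co-parents, drop edge directions, delete Z, and test whether X can still reach Y). The edge-list representation and the function names are hypothetical, and X, Y, Z are assumed to be disjoint node sets.

```python
from itertools import combinations

def ancestors(edges, nodes):
    """All nodes in `nodes` plus every node with a directed path into them."""
    parents = {}
    for a, b in edges:
        parents.setdefault(b, set()).add(a)
    result, stack = set(nodes), list(nodes)
    while stack:
        node = stack.pop()
        for parent in parents.get(node, ()):
            if parent not in result:
                result.add(parent)
                stack.append(parent)
    return result

def d_separated(edges, xs, ys, zs):
    """True iff xs and ys are d-separated by zs in the DAG given by `edges`
    (moralized ancestral graph criterion; xs, ys, zs assumed disjoint)."""
    xs, ys, zs = set(xs), set(ys), set(zs)
    keep = ancestors(edges, xs | ys | zs)
    # Undirected skeleton of the ancestral subgraph.
    adj = {node: set() for node in keep}
    parents = {}
    for a, b in edges:
        if a in keep and b in keep:
            adj[a].add(b)
            adj[b].add(a)
            parents.setdefault(b, set()).add(a)
    # Moralize: connect ("marry") every pair of parents of a common child.
    for ps in parents.values():
        for p, q in combinations(ps, 2):
            adj[p].add(q)
            adj[q].add(p)
    # Delete the conditioning set, then search for any x-to-y path.
    stack = [x for x in xs if x not in zs]
    seen = set(stack)
    while stack:
        node = stack.pop()
        if node in ys:
            return False
        for nbr in adj[node]:
            if nbr not in zs and nbr not in seen:
                seen.add(nbr)
                stack.append(nbr)
    return True
```

For a collider A → C ← B, `d_separated(edges, {"A"}, {"B"}, set())` returns True, while conditioning on C (or on a descendant of C) returns False, which is the collider pattern that Corr2Cause-style questions such as Figure A2 probe.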