C$^2$DLM: Causal Concept-Guided Diffusion Large Language Models

Kairong Han, Nuanqiao Shan, Ziyu Zhao, Zijing Hu, Xinpeng Dong, Junjian Ye, Lujia Pan, Fei Wu, Kun Kuang

TL;DR

The paper introduces C$^2$DLM, a diffusion-language-model framework that injects causal priors by constructing concept-level causal graphs with a teacher LLM and enforcing alignment with these priors through a V-aware re-attention mechanism. This approach targets the reasoning bottlenecks of the AR and DLM paradigms, delivering notable gains on synthetic and real-world reasoning tasks along with substantial training speedups. The proposed causal-knowledge extraction and attention alignment improve robustness to reasoning perturbations and demonstrate the value of two-dimensional supervision signals for guiding language models toward the data-generating process. Limitations include scalability and reliance on pretraining stages, pointing to future work on scaling causal priors and pretraining with causal supervision.
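The V-aware re-attention mechanism is described here only at the level of Figure 2(b): attention maps are weighted by the norms of the corresponding Value matrix. The sketch below is one possible reading, not the authors' implementation. It assumes each attention weight A[i, j] is scaled by the L2 norm of the value vector v_j, with rows optionally renormalized; the function name `v_aware_reattention` and the `renormalize` flag are hypothetical.

```python
import numpy as np


def v_aware_reattention(attn: np.ndarray, values: np.ndarray,
                        renormalize: bool = True) -> np.ndarray:
    """Weight each attention column by the L2 norm of its value vector.

    attn:   (n, n) attention map; attn[i, j] is how much query token i
            attends to key token j.
    values: (n, d) value vectors, one row per key token.
    Assumption (hypothetical reading of Figure 2(b)): weight = A_ij * ||v_j||_2.
    """
    v_norms = np.linalg.norm(values, axis=-1)   # (n,) one norm per key token
    weighted = attn * v_norms[None, :]          # scale column j by ||v_j||
    if renormalize:
        row_sums = weighted.sum(axis=-1, keepdims=True)
        # Rescale so each query row sums to 1 again; guard against zero rows.
        weighted = weighted / np.maximum(row_sums, 1e-12)
    return weighted


if __name__ == "__main__":
    attn = np.array([[0.5, 0.5], [0.5, 0.5]])
    values = np.array([[3.0, 4.0], [0.0, 1.0]])  # value norms are 5 and 1
    print(v_aware_reattention(attn, values))
```

Under this assumption, a token whose value vector is nearly zero contributes little to the output even when it receives high raw attention, so norm weighting gives a closer picture of which tokens actually influence each position before any causal supervision is applied.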

Abstract

Autoregressive (AR) language models and Diffusion Language Models (DLMs) constitute the two principal paradigms of large language models. However, both paradigms suffer from insufficient reasoning capabilities. Human reasoning inherently relies on causal knowledge and thought, which are reflected in natural language. In the AR paradigm, however, language is modeled as next-token prediction (a strictly left-to-right, token-by-token order), whereas natural language itself exhibits more flexible causal structures. In the DLM paradigm, the attention mechanism is fully connected, which entirely disregards causal order. To fill this gap, we propose a Causal Concept-Guided Diffusion Language Model (C$^2$DLM). Starting from DLM's fully connected attention, C$^2$DLM first obtains a concept-level causal graph from a teacher model and then explicitly guides attention to learn the causal relationships between concepts. By focusing on causal relationships and avoiding interference from difficult subgoals involving causal inversion, C$^2$DLM achieves a 12% improvement with about 3.2 times faster training on the COT-OrderPerturb task, and an average gain of 1.31% across six downstream reasoning tasks. More details are available in the repository: https://github.com/Kairong-Han/C-2-DLM
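The abstract states that C$^2$DLM turns a teacher-provided concept-level causal graph into guidance on attention, but the exact loss is not given in this summary. The following is a hypothetical sketch of that idea, not the paper's objective. It assumes each concept maps to a contiguous token span, that each edge (cause, effect) means effect tokens should attend to cause tokens, and that the loss is the negative log of the attention mass each effect token places on its causes. The names `causal_target_mask` and `causal_alignment_loss` and the span format are illustrative assumptions.

```python
import numpy as np


def causal_target_mask(n_tokens, concept_spans, edges):
    """Build a token-level target mask from a concept-level causal graph.

    concept_spans: dict concept -> (start, end) token indices, end exclusive
                   (hypothetical format).
    edges:         list of (cause, effect) concept pairs from the teacher.
    mask[i, j] = 1 if token i lies in an effect concept and token j lies
    in one of its cause concepts.
    """
    mask = np.zeros((n_tokens, n_tokens))
    for cause, effect in edges:
        cs, ce = concept_spans[cause]
        es, ee = concept_spans[effect]
        mask[es:ee, cs:ce] = 1.0
    return mask


def causal_alignment_loss(attn, mask, eps=1e-12):
    """Mean negative log attention mass placed on causal parents.

    Only rows (query tokens) with at least one causal parent are supervised;
    all other rows are left free.
    """
    rows = mask.sum(axis=-1) > 0
    if not rows.any():
        return 0.0
    mass = (attn * mask).sum(axis=-1)[rows]  # attention each effect token puts on its causes
    return float(-np.log(mass + eps).mean())


if __name__ == "__main__":
    spans = {"A": (0, 2), "B": (2, 4)}       # concept A causes concept B
    mask = causal_target_mask(4, spans, [("A", "B")])
    uniform = np.full((4, 4), 0.25)          # attention ignoring causal structure
    print(causal_alignment_loss(uniform, mask))
```

With uniform attention, each effect token places half its mass on its causes, so the loss is about ln 2; it falls toward zero as attention concentrates on cause tokens. In practice such a term would be added to the diffusion training loss with some weight, which is likewise an assumption here.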

Paper Structure

This paper contains 26 sections, 18 equations, 9 figures, 7 tables, 1 algorithm.

Figures (9)

  • Figure 1: Differences between AR, DLM, and C$^2$DLM. AR models struggle to capture global information, while the flexibility of natural language is not bound to a strict left-to-right, token-by-token causal order. DLMs discard causal priors entirely. C$^2$DLM explicitly guides the model to learn causal relations between concepts, capturing the underlying causal priors of natural language generation.
  • Figure 2: (a) Leveraging the in-context learning capability of a strong model, the causal teacher model uses prompts to automatically extract concept-level information from CoTs and generate causal meta-knowledge links between concepts as supervisory signals. (b) During training, the V-aware Re-attention mechanism weights the internal attention maps obtained from CoTs by the norms of the corresponding Value matrix. (c) The tokenizer maps the textual supervisory signals from step (a) onto the weighted attention maps, and a loss-based intervention is applied to guide C$^2$DLM’s decision-making process.
  • Figure 3: A normal CoT follows the causal topological order of the data-generating process to construct reasoning steps, whereas the Shuffle setting simulates cases where the CoT exhibits causal misordering.
  • Figure 4: Accuracy curve as training progresses in the COT-OrderPerturb task.
  • Figure 5: Performance across training epochs on the STG_H dataset.
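The contrast in Figure 3 between a normal CoT (steps in the causal topological order of the data-generating process) and the Shuffle setting can be illustrated with a small sketch. It assumes reasoning steps form a directed acyclic graph given as a parent map, orders them with Kahn's algorithm, and models Shuffle as a uniform random permutation; the paper's actual construction of COT-OrderPerturb may differ, and all function names here are hypothetical.

```python
import random
from collections import deque


def topological_order(steps, parents):
    """Kahn's algorithm: order steps so each step follows all its causal parents."""
    indegree = {s: len(parents.get(s, [])) for s in steps}
    children = {s: [] for s in steps}
    for child, ps in parents.items():
        for p in ps:
            children[p].append(child)
    queue = deque(s for s in steps if indegree[s] == 0)
    order = []
    while queue:
        s = queue.popleft()
        order.append(s)
        for c in children[s]:
            indegree[c] -= 1
            if indegree[c] == 0:
                queue.append(c)
    if len(order) != len(steps):
        raise ValueError("reasoning graph contains a cycle")
    return order


def respects_causal_order(order, parents):
    """True if every cause appears before each of its effects."""
    pos = {s: i for i, s in enumerate(order)}
    return all(pos[p] < pos[c] for c, ps in parents.items() for p in ps)


def shuffled_cot(steps, seed=0):
    """Shuffle setting (assumed): a random permutation that may put effects before causes."""
    rng = random.Random(seed)
    order = list(steps)
    rng.shuffle(order)
    return order


if __name__ == "__main__":
    steps = ["x=2", "y=x+3", "z=y*x"]
    parents = {"y=x+3": ["x=2"], "z=y*x": ["y=x+3", "x=2"]}
    normal = topological_order(steps, parents)
    print(normal, respects_causal_order(normal, parents))
    shuffled = shuffled_cot(steps, seed=1)
    print(shuffled, respects_causal_order(shuffled, parents))
```

A model trained left to right on shuffled steps must sometimes predict an effect before its cause has appeared, which is one way to read the "difficult subgoals involving causal inversion" mentioned in the abstract.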