Understanding Reasoning Ability of Language Models From the Perspective of Reasoning Paths Aggregation

Xinyi Wang, Alfonso Amayuelas, Kexun Zhang, Liangming Pan, Wenhu Chen, William Yang Wang

TL;DR

This work proposes viewing an LM as deriving new conclusions by aggregating indirect reasoning paths seen at pre-training time, and formalizes these reasoning paths as random walk paths on knowledge/reasoning graphs.

Abstract

Pre-trained language models (LMs) are able to perform complex reasoning without explicit fine-tuning. To understand how pre-training with a next-token prediction objective contributes to the emergence of such reasoning capability, we propose viewing an LM as deriving new conclusions by aggregating indirect reasoning paths seen at pre-training time. We found this perspective effective in two important cases of reasoning: logic reasoning with knowledge graphs (KGs) and chain-of-thought (CoT) reasoning. More specifically, we formalize the reasoning paths as random walk paths on the knowledge/reasoning graphs. Analyses of learned LM distributions suggest that a weighted sum of relevant random walk path probabilities is a reasonable way to explain how LMs reason. Experiments and analysis on multiple KG and CoT datasets reveal the effect of training on random walk paths and suggest that augmenting training data with unlabeled random walk reasoning paths can improve real-world multi-step reasoning performance. Code: https://github.com/WANGXinyiLinda/LM_random_walk
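
To make the random-walk view of pre-training data concrete, here is a minimal Python sketch of sampling random walk paths on a toy knowledge graph and serializing them as text. It is not the authors' pipeline (the linked repository holds that). The triples, the uniform choice among outgoing edges, drawing each walk's length uniformly from 1 to a maximum (standing in for $L_{max}$), and the space-separated serialization are all illustrative assumptions, and the names `TRIPLES`, `build_adjacency`, `random_walk`, and `sample_corpus` are hypothetical.

```python
import random
from collections import defaultdict

# Toy knowledge graph as (head, relation, tail) triples; contents are illustrative.
TRIPLES = [
    ("paris", "capital_of", "france"),
    ("france", "located_in", "europe"),
    ("lyon", "city_in", "france"),
    ("berlin", "capital_of", "germany"),
    ("germany", "located_in", "europe"),
]


def build_adjacency(triples):
    """Map each entity to its outgoing (relation, tail) edges."""
    adj = defaultdict(list)
    for head, rel, tail in triples:
        adj[head].append((rel, tail))
    return adj


def random_walk(adj, start, length, rng):
    """Take up to `length` steps from `start`, choosing an outgoing edge uniformly
    at random at each step; stop early at an entity with no outgoing edges."""
    tokens = [start]
    node = start
    for _ in range(length):
        edges = adj.get(node)  # .get does not insert keys into the defaultdict
        if not edges:
            break
        rel, node = rng.choice(edges)
        tokens.extend([rel, node])
    return tokens


def sample_corpus(triples, num_walks, max_len, seed=0):
    """Sample `num_walks` walks whose lengths are drawn uniformly from 1..max_len
    (an assumed stand-in for L_max) and serialize each as one line of text."""
    rng = random.Random(seed)
    adj = build_adjacency(triples)
    starts = sorted(adj)  # entities that have at least one outgoing edge
    corpus = []
    for _ in range(num_walks):
        start = rng.choice(starts)
        length = rng.randint(1, max_len)
        corpus.append(" ".join(random_walk(adj, start, length, rng)))
    return corpus


for line in sample_corpus(TRIPLES, num_walks=5, max_len=3):
    print(line)
```

Each output line alternates entity and relation tokens, so longer walks expose a next-token predictor to multi-hop chains such as `paris capital_of france located_in europe`.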

Paper Structure

This paper contains 20 sections, 2 theorems, 17 equations, 7 figures, 2 tables, and 1 algorithm.

Key Result

Proposition 2.1

If the LM has effectively learned the random walk data distribution through pre-training, we have

Figures (7)

  • Figure 1: We hypothesize that the pre-training corpus can be viewed as generated from random walks on a reasoning graph over world knowledge/concepts, where each node $s_i$ represents a concept and each $p_j$ can be viewed as an argument that connects concepts. We further hypothesize that a language model (LM) trained on such a corpus can be viewed as reasoning by a weighted aggregation of the random walk paths that connect the entities of interest. $P_{\text{LM}}$ denotes the LM distribution, and $P_D$ denotes the random walk probability from the pre-training corpus. $w_i^1$ denotes the weight the LM assigns to the first random walk path for argument $p_i$, and $w_i^2$ denotes the weight assigned to the second random walk path.
  • Figure 2: KL divergence between various reference distributions and the LM distribution for different maximum random walk lengths, averaged over the Countries (top) and UMLS (bottom) test sets. Rows correspond to the LM distribution $P_{\text{LM}}(e_2|e_1, r)$ with maximum pre-training random walk path lengths ($L_{max}$) from 1 to 10. From left to right, the columns correspond to the weighted aggregation distribution $P_w(e_2|e_1, r)$ with maximum random walk path lengths ($N_{max}$) from 1 to 10, the unweighted aggregation distribution $P_s(e_2|e_1, r)$ with $N_{max}$ from 1 to 10, the reference distribution $P^*(e_2|e_1, r)$, and the uniform distribution $P_u(e_2)$. Darker colors represent smaller KL values, i.e., closer distributions. $KL[P_w, P_{\text{LM}}]$ is consistently smaller than $KL[P_s, P_{\text{LM}}]$, which implies that the LM learns differences in rule importance. $KL[P^*, P_{\text{LM}}]$ and $KL[P_u, P_{\text{LM}}]$ serve as anchor points that show the scale of the KL values. $KL[P^*, P_{\text{LM}}]$ is generally high because $P^*$ concentrates its probability mass on the correct answers and can therefore differ substantially from the LM distribution; it indicates how peaked the LM distribution is, while $KL[P_u, P_{\text{LM}}]$ indicates how flat it is.
  • Figure 3: Test accuracy for various maximum pre-training random walk lengths ($1 \leq L_{max} \leq 10$) on the Countries (left) and UMLS (right) datasets. On Countries, the LM ($P_{\text{LM}}$) performance converges to that of weighted aggregation ($P_w$), while on UMLS the LM consistently outperforms both weighted ($P_w$) and unweighted ($P_s$) aggregation. This is likely because the LM can learn a better logical rule weighting scheme than weighted aggregation ($P_w$) in more complex KGs.
  • Figure 4: Test accuracy of LMs trained on random walk paths of different lengths. Each line corresponds to a different KG dataset, so the lines are not directly comparable; the common trend to note is that each line peaks at some optimal path length.
  • Figure 5: Test accuracy of continued pre-training with our random walk paths of different lengths $L_{max}$. Each line corresponds to a different math word problem (MWP) dataset, so the lines are not directly comparable; the common trend to note is that each line peaks within some optimal path length range, similar to Figure 4.
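
The aggregation distributions compared in Figures 1 and 2 can be sketched in a few lines of Python. This is a hedged illustration under stated assumptions, not the paper's estimator: each path's probability is taken as the product of uniform transition probabilities along a random walk from $e_1$; the unweighted aggregation $P_s$ sums path probabilities per end entity and normalizes; the weighted aggregation $P_w$ first scales each path by a weight attached to its relation sequence (its "rule"), with weights hand-picked here for a hypothetical query relation rather than estimated as in the paper; $KL[P, Q]$ is assumed to mean $KL(P \,\|\, Q)$; and the comparison target is the uniform anchor $P_u$ because no trained LM is available in the sketch. The toy graph and all function names are hypothetical.

```python
import math
from collections import defaultdict

# Toy KG; entities, relations, and rule weights below are illustrative only.
TRIPLES = [
    ("paris", "capital_of", "france"),
    ("paris", "city_in", "france"),
    ("france", "located_in", "europe"),
    ("berlin", "capital_of", "germany"),
    ("germany", "located_in", "europe"),
]
ENTITIES = sorted({h for h, _, _ in TRIPLES} | {t for _, _, t in TRIPLES})


def build_adjacency(triples):
    """Map each entity to its outgoing (relation, tail) edges."""
    adj = defaultdict(list)
    for h, r, t in triples:
        adj[h].append((r, t))
    return adj


def enumerate_paths(adj, start, n_max):
    """Yield (relation_sequence, end_entity, probability) for every random walk of
    1..n_max steps from start; each step picks an outgoing edge uniformly."""
    frontier = [((), start, 1.0)]
    for _ in range(n_max):
        next_frontier = []
        for rels, node, prob in frontier:
            edges = adj.get(node, [])
            for rel, tail in edges:
                next_frontier.append((rels + (rel,), tail, prob / len(edges)))
        yield from next_frontier
        frontier = next_frontier


def aggregate(adj, e1, n_max, rule_weights=None):
    """Normalized sum of path probabilities per end entity.
    rule_weights=None gives the unweighted version (P_s); otherwise each path is
    scaled by the weight of its relation sequence, unlisted rules weighted 0 (P_w)."""
    scores = dict.fromkeys(ENTITIES, 0.0)
    for rels, tail, prob in enumerate_paths(adj, e1, n_max):
        weight = 1.0 if rule_weights is None else rule_weights.get(rels, 0.0)
        scores[tail] += weight * prob
    total = sum(scores.values())
    if total == 0.0:
        raise ValueError("no weighted path from e1 reaches a candidate entity")
    return {e: s / total for e, s in scores.items()}


def kl(p, q):
    """KL(p || q) = sum_x p(x) log(p(x) / q(x)); terms with p(x) = 0 contribute 0."""
    return sum(px * math.log(px / q[x]) for x, px in p.items() if px > 0)


adj = build_adjacency(TRIPLES)
# Hypothetical rule weights for the query relation "located_in".
weights = {("capital_of", "located_in"): 0.9, ("city_in", "located_in"): 0.9, ("city_in",): 0.2}
p_s = aggregate(adj, "paris", n_max=3)
p_w = aggregate(adj, "paris", n_max=3, rule_weights=weights)
p_u = {e: 1.0 / len(ENTITIES) for e in ENTITIES}
print("P_s:", p_s)
print("P_w:", p_w)
print("KL[P_s, P_u] =", kl(p_s, p_u), " KL[P_w, P_u] =", kl(p_w, p_u))
```

In the paper's setting, the second argument to `kl` would be the trained LM's distribution $P_{\text{LM}}(e_2|e_1, r)$ rather than $P_u$; keeping $N_{max}$ as a parameter mirrors the columns of Figure 2.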

Theorems & Definitions (3)

  • Proposition 2.1
  • Proposition 1.1
  • proof