
Arithmetic Reasoning with LLM: Prolog Generation & Permutation

Xiaocheng Yang, Bingsen Chen, Yik-Cheung Tam

TL;DR

This work addresses arithmetic reasoning in LLMs by offloading calculation to Prolog. The authors fine-tune LLMs to generate Prolog programs that encode a problem's facts and rules and then solve those programs with an external interpreter. This reduces the cascaded calculation errors to which Chain-of-Thought (CoT) methods are prone. They introduce the GSM8K-Prolog dataset and ProPer, a permutation-based data augmentation that exploits the order independence of Prolog predicates to make training more robust. Across three 7B LLMs fine-tuned with LoRA, Prolog generation outperforms CoT, and ProPer yields further gains. This suggests that separating reasoning from computation and training on order-invariant predicate representations both strengthen arithmetic problem solving.
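The ProPer idea can be illustrated with a short Python sketch. It is not the authors' implementation. The Prolog program is a made-up example in the general style of a GSM8K-Prolog solution, and its clpq `{}` constraint syntax is an assumption. The helper `permute_facts` is hypothetical. It assumes one clause per line, keeps directives first and the `solve/1` rule last, and shuffles only the facts, so every permuted copy encodes the same problem.

```python
import random
from typing import List

# Hypothetical GSM8K-style Prolog solution (illustrative, not from the dataset).
PROGRAM = """
:- use_module(library(clpq)).
apples_start(5).
apples_bought(3).
apples_eaten(2).
solve(X) :- apples_start(A), apples_bought(B), apples_eaten(C), {X = A + B - C}.
"""


def permute_facts(program: str, rng: random.Random) -> str:
    """Return `program` with its facts shuffled (hypothetical ProPer-style helper).

    Assumes one clause per line: lines starting with ':-' are directives,
    other lines containing ':-' are rules, and everything else is a fact.
    Directives are emitted first and rules last; only fact order changes.
    """
    lines: List[str] = [ln for ln in program.strip().splitlines() if ln.strip()]
    directives = [ln for ln in lines if ln.lstrip().startswith(":-")]
    rules = [ln for ln in lines if ":-" in ln and not ln.lstrip().startswith(":-")]
    facts = [ln for ln in lines if ":-" not in ln]
    rng.shuffle(facts)  # in-place shuffle of the fact clauses only
    return "\n".join(directives + facts + rules)


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed so the permutations are reproducible
    for _ in range(2):
        print(permute_facts(PROGRAM, rng))
        print("---")
```

Prolog finds the facts that `solve(X)` needs by unification, not by reading the program top to bottom. Each shuffled copy should therefore return the same answer under a clpq-enabled interpreter.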

Abstract

Instructing large language models (LLMs) to solve elementary school math problems has shown great success using Chain of Thought (CoT). However, the CoT approach relies on an LLM to generate a sequence of arithmetic calculations, which can be prone to cascaded calculation errors. We hypothesize that an LLM should focus on extracting predicates and generating symbolic formulas from the math problem description so that the underlying calculation can be done via an external code interpreter. We investigate using LLMs to generate Prolog programs to solve mathematical questions. Experimental results show that our Prolog-based arithmetic problem-solving outperforms CoT generation on the GSM8K benchmark across three distinct LLMs. In addition, given that Prolog is insensitive to the ordering of predicates and symbolic formulas, we propose to permute the ground-truth predicates for more robust LLM training via data augmentation.
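A minimal Python analogue of an order-insensitive interpreter shows how offloading calculation and permuting predicates fit together. It is not Prolog and not the paper's pipeline. The hypothetical `evaluate` function takes symbolic definitions, the kind of formulas the abstract wants an LLM to extract, and computes the answer itself. It resolves each name through its dependencies, so the order of the definitions cannot change the result. It supports only numbers, names, and the operators +, -, * and /.

```python
import ast
import operator
from typing import Dict

# Binary operators the toy evaluator accepts; anything else is rejected.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def evaluate(defs: Dict[str, str], goal: str) -> float:
    """Evaluate `goal` from symbolic definitions given in any order (hypothetical).

    `defs` maps a name to an arithmetic expression over numbers and other
    names. Values are resolved on demand through dependencies, so entry order
    never affects the result. No cycle detection: cyclic definitions raise
    RecursionError.
    """
    cache: Dict[str, float] = {}

    def eval_node(node: ast.AST) -> float:
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.Name):
            return resolve(node.id)  # look up another definition by name
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](eval_node(node.left), eval_node(node.right))
        raise ValueError(f"unsupported expression: {ast.dump(node)}")

    def resolve(name: str) -> float:
        if name not in cache:
            cache[name] = eval_node(ast.parse(defs[name], mode="eval").body)
        return cache[name]

    return resolve(goal)


if __name__ == "__main__":
    # The same hypothetical word problem, with definitions in two different orders.
    forward = {"start": "5", "bought": "3", "eaten": "2",
               "total": "start + bought - eaten"}
    backward = dict(reversed(list(forward.items())))
    print(evaluate(forward, "total"), evaluate(backward, "total"))  # prints "6 6"
```

The two dictionaries hold the same definitions in opposite orders and produce the same total. This is the property that motivates permuting ground-truth predicates.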

Paper Structure

This paper contains 11 sections, 1 equation, 4 figures, 2 tables.

Figures (4)

  • Figure 1: Overview of Prolog generation for arithmetic reasoning with large language models.
  • Figure 2: Prolog and permuted Prolog code samples.
  • Figure 2: Accuracy (%) results on GSM8K with different permutation ratios. For the 1:1 and 1:2 ratios, we report both the best and the average accuracy over three trials, each with different randomly permuted data, in the form max (avg). Note that the 1:0 case essentially means not applying ProPer.
  • Figure 3: Validation loss curves for training Llama2, CodeLlama, and Mistral with different permutation ratios (for permuted data, only the first trial is shown because the loss curves are very similar across trials).
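The permutation ratios in the figure captions can be read as a data-mixing recipe. Under this reading, 1:k pairs each original training program with k randomly permuted copies, 1:0 means no ProPer, and each trial draws different permutations. The Python sketch below illustrates that reading only. `shuffle_facts` and `build_proper_dataset` are hypothetical names, and treating facts (non-empty lines without `:-`) as the only permuted units is a simplifying assumption.

```python
import random
from typing import List, Tuple

Sample = Tuple[str, str]  # (question, Prolog program text)


def shuffle_facts(program: str, rng: random.Random) -> str:
    """Hypothetical permuter: shuffle facts among their own line positions,
    leaving directives, rules, and blank lines where they are."""
    lines = program.strip().splitlines()
    fact_slots = [i for i, ln in enumerate(lines) if ln.strip() and ":-" not in ln]
    facts = [lines[i] for i in fact_slots]
    rng.shuffle(facts)
    for i, fact in zip(fact_slots, facts):
        lines[i] = fact
    return "\n".join(lines)


def build_proper_dataset(dataset: List[Sample], k: int, seed: int) -> List[Sample]:
    """Mix data at a 1:k ratio: each original sample plus k permuted copies.

    k = 0 returns only the original samples (no ProPer). A different `seed`
    per trial yields a different set of permuted copies.
    """
    rng = random.Random(seed)
    mixed: List[Sample] = []
    for question, program in dataset:
        mixed.append((question, program))
        for _ in range(k):
            mixed.append((question, shuffle_facts(program, rng)))
    return mixed


if __name__ == "__main__":
    prog = (":- use_module(library(clpq)).\n"
            "pens(4).\n"
            "price_per_pen(2).\n"
            "solve(X) :- pens(N), price_per_pen(P), {X = N * P}.")
    train = [("Sam buys 4 pens at $2 each. How much does he pay?", prog)]
    for k in (0, 1, 2):
        for trial_seed in (0, 1, 2):
            mixed = build_proper_dataset(train, k, trial_seed)
            print(f"1:{k} trial {trial_seed}: {len(mixed)} samples")
```

Using a separate seed per trial mirrors the three trials with different randomly permuted data. For programs with very few facts, a permuted copy can coincide with the original.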