
Integrating Arithmetic Learning Improves Mathematical Reasoning in Smaller Models

Neeraj Gangwar, Suma P Bhat, Nickvash Kani

TL;DR

This work addresses the difficulty of achieving strong mathematical reasoning in small language models by introducing an explicitly arithmetic-centric training regimen. It investigates two strategies: intermediate fine-tuning on a programmatically generated arithmetic dataset before downstream reasoning fine-tuning, and incorporating the arithmetic data into the instruction-tuning mixture. Across GSM8k and several out-of-domain benchmarks, both approaches improve arithmetic grounding and reasoning performance. Intermediate fine-tuning improves out-of-domain generalization and the accuracy of arithmetic computations within reasoning chains, while arithmetic-inclusive instruction tuning boosts few-shot math reasoning and robustness to numerical perturbations. The findings highlight explicit arithmetic training as a key pathway to strengthening mathematical reasoning in smaller models, and the authors release their datasets and code publicly to enable further exploration.
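The arithmetic dataset is described here only as "programmatically generated." As a rough illustration of what such a generator could look like, the sketch below samples two-operand problems over +, -, *, and /. The operand range, the `a op b =` prompt format, the two-decimal rounding, and the names `make_example` and `make_dataset` are all hypothetical assumptions, not the paper's actual specification.

```python
import operator
import random
from fractions import Fraction

# ASSUMPTION: two-operand problems with these four operators; the paper's
# generator may use other operations, operand ranges, and formats.
OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": operator.truediv}


def format_number(value: Fraction) -> str:
    """Print integers exactly and everything else rounded to two decimals."""
    if value.denominator == 1:
        return str(value.numerator)
    return f"{float(value):.2f}"


def make_example(rng: random.Random, max_operand: int = 1000) -> dict:
    """Sample one hypothetical (input, target) arithmetic pair."""
    a = rng.randint(0, max_operand)
    b = rng.randint(1, max_operand)  # b >= 1 so division is always defined
    symbol = rng.choice(list(OPS))
    # Fraction arithmetic keeps the result exact before formatting.
    result = OPS[symbol](Fraction(a), Fraction(b))
    return {"input": f"{a} {symbol} {b} =", "target": format_number(result)}


def make_dataset(n: int, seed: int = 0) -> list[dict]:
    """Generate n examples reproducibly from a fixed seed."""
    rng = random.Random(seed)
    return [make_example(rng) for _ in range(n)]


if __name__ == "__main__":
    for example in make_dataset(3, seed=42):
        print(example["input"], example["target"])
```

Computing targets with exact `Fraction` arithmetic and rounding only when rendering keeps floating-point drift out of the labels.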

Abstract

While large models pre-trained on high-quality data exhibit excellent performance across various reasoning tasks, including mathematical reasoning (e.g., GSM8k, MultiArith), specializing smaller models to excel at mathematical reasoning remains a challenging problem. Common approaches to this challenge include knowledge distillation, where smaller student models learn from large pre-trained teacher models, and data augmentation, such as rephrasing questions. Despite these efforts, smaller models still struggle with arithmetic computations, which leads to errors in mathematical reasoning. In this work, we focus on leveraging a programmatically generated arithmetic dataset to enhance the reasoning capabilities of smaller models. We investigate two approaches to incorporating this dataset: (1) intermediate fine-tuning, where a model is fine-tuned on the arithmetic dataset before being trained on a reasoning dataset, and (2) integrating the arithmetic dataset into the instruction-tuning mixture, allowing the model to learn arithmetic skills alongside general instruction-following abilities. Our experiments on multiple reasoning benchmarks demonstrate that incorporating an arithmetic dataset, whether through targeted fine-tuning or within the instruction-tuning mixture, enhances the models' arithmetic capabilities, which in turn improves their mathematical reasoning performance.
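The two approaches differ mainly in when the model sees the arithmetic data. The sketch below contrasts them as training schedules: approach (1) is two sequential stages, and approach (2) is a single stage over a shuffled union of the instruction data and the arithmetic data. The function names, the `train_step` callback, and the plain concatenate-and-shuffle mixing (with no sampling ratios) are illustrative assumptions; the actual mixture proportions are not stated in this summary.

```python
import random
from typing import Callable

Example = dict
Stage = tuple[str, list[Example]]


def intermediate_finetuning(arithmetic: list[Example],
                            reasoning: list[Example]) -> list[Stage]:
    """Approach (1): a stage on arithmetic data, then a stage on reasoning data."""
    return [("arithmetic", list(arithmetic)), ("reasoning", list(reasoning))]


def arithmetic_in_mixture(instructions: list[Example],
                          arithmetic: list[Example],
                          seed: int = 0) -> list[Stage]:
    """Approach (2): one instruction-tuning stage that also contains arithmetic data."""
    mixture = list(instructions) + list(arithmetic)
    # Shuffle so arithmetic examples are interleaved with instruction examples.
    random.Random(seed).shuffle(mixture)
    return [("instruction_tuning", mixture)]


def run_schedule(schedule: list[Stage],
                 train_step: Callable[[str, Example], None]) -> None:
    """Feed every example to a hypothetical train_step, stage by stage, in order."""
    for stage_name, data in schedule:
        for example in data:
            train_step(stage_name, example)
```

In approach (1) the arithmetic stage can be repeated for several epochs before the reasoning stage begins; in approach (2) arithmetic and general instruction following are learned jointly in one pass over the mixture.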

Paper Structure

This paper contains 42 sections, 6 figures, 6 tables.

Figures (6)

  • Figure 1: An example from the GSM8k test set and its solutions generated by (Top) FlanT5-Large fine-tuned directly on the GSM8k dataset, and (Bottom) FlanT5-Large fine-tuned on an arithmetic dataset before being fine-tuned on the GSM8k dataset.
  • Figure 2: GSM8k arithmetic accuracy (Left) and GSM8k test accuracy (Right) of the models fine-tuned on GSM8k after intermediate fine-tuning for different numbers of epochs. GSM8k arithmetic performance saturates after two epochs of intermediate fine-tuning, so additional epochs yield no further gains in GSM8k test accuracy.
  • Figure 3: GSM8k arithmetic accuracy, i.e., the ability of the models fine-tuned on GSM8k to compute the results of arithmetic operations correctly within reasoning contexts. This evaluation is performed on the GSM8k test set.
  • Figure 4: Performance of the pre-trained and instruction-tuned GPT2-Large models on GSM-Plus for different perturbation types using self-consistency decoding. The model fine-tuned on the Tülu 3 SFT mixture plus the arithmetic dataset performs better across perturbation types. The percentages above the bars give the performance drop relative to the original GSM8k dataset, as reported in the instruction-tuning results table.
  • Figure 5: Performance of the pre-trained and instruction-tuned GPT2-Large models on GSM-Plus for different perturbation types using greedy decoding. The model fine-tuned on the Tülu 3 SFT mixture plus the arithmetic dataset performs better across perturbation types. The percentages above the bars give the performance drop relative to the original GSM8k dataset, as reported in the instruction-tuning results table.
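Figure 3 defines GSM8k arithmetic accuracy as whether a model computes the results of arithmetic operations correctly inside its reasoning. One way to measure this, sketched below, assumes that generated solutions use GSM8k's `<<expression=result>>` calculator annotations and scores each annotation by re-evaluating its left-hand side exactly. The annotation assumption, the `arithmetic_accuracy` name, and the 0.01 tolerance are hypothetical choices, not the authors' evaluation code.

```python
import ast
import operator
import re
from fractions import Fraction

# GSM8k reference solutions annotate each computation as <<expression=result>>.
# ASSUMPTION: the model's generated solutions follow the same convention.
CALC_PATTERN = re.compile(r"<<([^=<>]+)=([^<>]+)>>")

BIN_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
           ast.Mult: operator.mul, ast.Div: operator.truediv}


def evaluate(node: ast.AST) -> Fraction:
    """Exactly evaluate a parsed expression limited to numbers and + - * /."""
    if isinstance(node, ast.Expression):
        return evaluate(node.body)
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return Fraction(str(node.value))
    if isinstance(node, ast.BinOp) and type(node.op) in BIN_OPS:
        return BIN_OPS[type(node.op)](evaluate(node.left), evaluate(node.right))
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        return -evaluate(node.operand)
    raise ValueError(f"unsupported expression: {ast.dump(node)}")


def arithmetic_accuracy(solutions: list[str], tol: Fraction = Fraction(1, 100)) -> float:
    """Fraction of annotated computations whose stated result matches the expression."""
    correct = total = 0
    for text in solutions:
        for expr, claimed in CALC_PATTERN.findall(text):
            total += 1
            try:
                tree = ast.parse(expr.replace(",", "").strip(), mode="eval")
                value = evaluate(tree)
                if abs(value - Fraction(claimed.replace(",", "").strip())) <= tol:
                    correct += 1
            except (SyntaxError, ValueError, ZeroDivisionError):
                pass  # unparseable or undefined computations count as errors
    return correct / total if total else 0.0
```

The small tolerance avoids penalizing results that the model rounded to two decimals, such as `<<10/3=3.33>>`, while a wrong product such as `<<7*8=54>>` is still counted as an arithmetic error.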