Masked Thought: Simply Masking Partial Reasoning Steps Can Improve Mathematical Reasoning Learning of Language Models

Changyu Chen, Xiting Wang, Ting-En Lin, Ang Lv, Yuchuan Wu, Xin Gao, Ji-Rong Wen, Rui Yan, Yongbin Li

TL;DR

Problem: LLMs struggle with multi-step mathematical reasoning; existing improvements rely on costly data or larger models. Approach: Masked Thought Fine-Tuning, a simple token-masking regularization during SFT that perturbs chain-of-thought tokens to encourage reliance on the problem statement and earlier steps. Contributions: demonstrated improvements on GSM8K and GSM-IC across multiple base models, provided a general dependency-regularization perspective, and showed compatibility with data augmentation methods; analyzed dependency shifts and case studies. Significance: offers a lightweight, scalable method to boost reasoning in LLMs and informs regularization design for structured reasoning tasks.

Abstract

In reasoning tasks, even a minor error can cascade into inaccurate results, leading to suboptimal performance of large language models in such domains. Earlier fine-tuning approaches sought to mitigate this by leveraging more precise supervisory signals from human labeling, larger models, or self-sampling, albeit at a high cost. In contrast, we develop a method that avoids external resources, relying instead on introducing perturbations to the input. Our training approach randomly masks certain tokens within the chain of thought, a technique we found to be particularly effective for reasoning tasks. When applied to fine-tuning Llama-2-7B on GSM8K, this method achieved a 5% improvement in GSM8K accuracy and a 10% improvement in GSM-IC accuracy over standard supervised fine-tuning, with only a few lines of code modified. Furthermore, it is complementary to existing methods. When integrated with related explicit data augmentation methods, it leads to improvements across five datasets built with various augmentation methods, as well as across two different base models. We further investigate the mechanisms behind this improvement through case studies and quantitative analysis, which suggest that our approach may better support the model in capturing long-distance dependencies, especially those related to the question. This enhancement could deepen the model's understanding of the premises in questions and prior steps. Our code is available on GitHub.
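The masking step described in the abstract can be illustrated with a minimal sketch. This is not the authors' released code: the function name `mask_chain_of_thought`, the parameters `cot_start` and `mask_token_id`, and the choice to keep the original tokens as training targets are all assumptions made for illustration.

```python
import random


def mask_chain_of_thought(input_ids, cot_start, mask_ratio, mask_token_id, rng=None):
    """Return (masked_inputs, labels) for one training example.

    input_ids     : question tokens followed by chain-of-thought tokens.
    cot_start     : index of the first chain-of-thought token (hypothetical
                    parameter); tokens before it are never masked.
    mask_ratio    : probability of replacing each chain-of-thought token.
    mask_token_id : id of the placeholder token (hypothetical parameter).
    """
    rng = rng or random.Random()
    masked = list(input_ids)
    for i in range(cot_start, len(masked)):
        # Perturb only the reasoning part of the input sequence.
        if rng.random() < mask_ratio:
            masked[i] = mask_token_id
    # Assumption: the prediction targets remain the original, unmasked tokens.
    labels = list(input_ids)
    return masked, labels


# Usage: three question tokens followed by four chain-of-thought tokens.
tokens = [11, 12, 13, 21, 22, 23, 24]
masked, labels = mask_chain_of_thought(
    tokens, cot_start=3, mask_ratio=0.4, mask_token_id=0, rng=random.Random(0)
)
```

Because only the input side is perturbed, a change like this would fit inside an existing supervised fine-tuning data pipeline, which is consistent with the abstract's remark that only a few lines of code need to be modified.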

Paper Structure (30 sections, 5 equations, 14 figures, 16 tables)

Figures (14)

  • Figure 1: We find that MFT has a higher long-distance dependency than SFT. A higher bar indicates greater dependency at the corresponding distance. Specifically, we investigate the dependency between two numerical tokens within a given sequence. We fine-tune Llama-2-7B models on the GSM8K dataset and then measure the dependency on 300 samples from the test set.
  • Figure 2: Accuracy on GSM8K when training with Mistral-7B and the MathInstruct dataset of MAmmoTH. MFT shows a higher sample efficiency.
  • Figure 3: The impact of the mask ratio when training Llama-2-7B on GSM8K. We compare MFT under two settings: a fixed mask ratio without warm-up, and a linear warm-up of the mask ratio starting from 0.
  • Figure 4: The first step of this problem should be correctly solved as 16 - 3 = 13. We alter the prefix to "16 eggs/day - n eggs/day" to observe the impact of changing 3 to n on the responses of SFT and MFT. We give the example of n=1. The orange part marks the other premises supporting the prediction of the current step.
  • Figure 5: We alter the premise in the question to investigate the impact of changing the word "three" to "five".
  • ...and 9 more figures
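
The linear warm-up of the mask ratio mentioned in the caption of Figure 3 could be sketched as follows. The function name `warmup_mask_ratio`, the `warmup_steps` parameter, and the target ratio of 0.4 used in the example are hypothetical; the source does not give the schedule's exact form or values.

```python
def warmup_mask_ratio(step, warmup_steps, target_ratio):
    """Linearly increase the mask ratio from 0 to target_ratio.

    step         : current training step, starting at 0.
    warmup_steps : number of steps over which the ratio ramps up
                   (hypothetical parameter).
    target_ratio : mask ratio used once warm-up has finished.
    """
    if warmup_steps <= 0:
        # No warm-up: equivalent to the fixed-ratio setting.
        return target_ratio
    return target_ratio * min(1.0, step / warmup_steps)


# Usage: ratio at a few points of a 100-step warm-up toward 0.4.
schedule = [warmup_mask_ratio(s, 100, 0.4) for s in (0, 50, 100, 200)]
```

Setting `warmup_steps` to 0 recovers the fixed mask ratio setting that the figure compares against.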