
DAST: Difficulty-Aware Self-Training on Large Language Models

Boyang Xue, Qi Zhu, Hongru Wang, Rui Wang, Sheng Wang, Hongling Xu, Fei Mi, Yasheng Wang, Lifeng Shang, Qun Liu, Kam-Fai Wong

TL;DR

DAST addresses the under-sampling of difficult queries in LLM self-training with a difficulty-aware loop. The loop estimates each query's difficulty $d_i$, augments the training data accordingly, and refines the model via SFT and DPO. Difficulty is estimated by sampling from the initial policy $\mathcal{M}_0$, and queries are partitioned into levels $E$, $M$, $H$, and $U$. Data proportion control and difficulty-matched prompting then adjust how much data each level receives and how long the responses to challenging queries are. Empirical results on GSM8K, MATH, TAL-SCQ, College, and TheoremQA show improved math reasoning and generalization, particularly on out-of-domain tasks, with DAST-S and DAST-D outperforming baselines. The findings highlight the importance of explicitly incorporating task difficulty into self-training to achieve data-efficient gains in large language models.
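The sampling-based estimation step can be sketched as follows. This is a minimal sketch under several assumptions. First, $d_i$ is taken to be the fraction of $k$ sampled answers from $\mathcal{M}_0$ that disagree with the ground truth. Second, $k = 8$ and the cut-offs 0.25 and 0.75 are illustrative choices, not the paper's settings. Third, $U$ is assumed to mark queries for which no sample is correct. Finally, `estimate_difficulty`, `difficulty_level`, and `fake_sampler` are hypothetical helpers, not code from the paper.

```python
import itertools
from typing import Callable, Iterable


def estimate_difficulty(
    query: str,
    gold: str,
    sample_answer: Callable[[str], str],  # hypothetical: one sampled final answer from M_0
    k: int = 8,                           # hypothetical number of samples per query
) -> float:
    """Estimate d_i as the fraction of k sampled answers that disagree with the gold answer."""
    wrong = sum(sample_answer(query) != gold for _ in range(k))
    return wrong / k


def difficulty_level(d: float) -> str:
    """Map d_i to a level; the cut-offs 0.25 / 0.75 are illustrative, not the paper's."""
    if d == 1.0:
        return "U"  # assumed: never solved by M_0 within k samples
    if d < 0.25:
        return "E"
    if d < 0.75:
        return "M"
    return "H"


def fake_sampler(answers: Iterable[str]) -> Callable[[str], str]:
    """Deterministic stand-in for M_0 that cycles through fixed answers."""
    it = itertools.cycle(answers)
    return lambda _query: next(it)


d = estimate_difficulty("2 + 2 = ?", "4", fake_sampler(["4", "4", "5", "4"]))
print(d, difficulty_level(d))
```

The stand-in sampler is wrong on 2 of 8 draws, so $d_i = 0.25$, which falls in the illustrative $M$ band. Because the levels are computed once with $\mathcal{M}_0$, they can be cached and reused across self-training rounds.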

Abstract

Current large language model (LLM) self-training methods consistently under-sample challenging queries. This leads to inadequate learning on difficult problems and limits LLMs' capabilities. This work therefore proposes a difficulty-aware self-training (DAST) framework that aims to improve both the quantity and quality of self-generated responses to challenging queries during self-training. DAST consists of three components: 1) sampling-based difficulty level estimation, 2) difficulty-aware data augmentation, and 3) self-training with SFT and DPO. Experiments on mathematical tasks demonstrate the effectiveness and generalization of DAST, highlighting the critical role of difficulty-aware strategies in advancing LLM self-training.
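Difficulty-aware augmentation, together with the construction of SFT and DPO training data, might look roughly like the following. The sketch rests on several assumptions. Data proportion control is taken to mean keeping more correct self-generated responses for harder queries, and the per-level targets in `TARGET_CORRECT` are invented. DPO data is assumed to pair a correct (chosen) response with an incorrect (rejected) one for the same query. `augment`, `Example`, and the `generate` callable standing in for the current policy are hypothetical. Difficulty-matched prompting for response length is not modeled.

```python
import itertools
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

# Hypothetical per-level targets: harder queries keep more correct responses.
# U-level queries were never solved by M_0, so they may still end up with none.
TARGET_CORRECT: Dict[str, int] = {"E": 1, "M": 2, "H": 4, "U": 4}


@dataclass
class Example:
    query: str
    gold: str   # ground-truth final answer from D_o
    level: str  # E / M / H / U, fixed from the initial policy M_0


def augment(
    examples: List[Example],
    generate: Callable[[str], Tuple[str, str]],  # hypothetical: returns (response, final answer)
    max_tries: int = 32,
) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str, str]]]:
    """Build SFT pairs (query, response) and DPO triples (query, chosen, rejected)."""
    sft: List[Tuple[str, str]] = []
    dpo: List[Tuple[str, str, str]] = []
    for ex in examples:
        correct: List[str] = []
        wrong: List[str] = []
        for _ in range(max_tries):
            if len(correct) >= TARGET_CORRECT[ex.level]:
                break  # enough correct responses for this difficulty level
            response, answer = generate(ex.query)
            (correct if answer == ex.gold else wrong).append(response)
        sft.extend((ex.query, r) for r in correct)
        # Pair correct (chosen) and incorrect (rejected) responses one-to-one.
        dpo.extend((ex.query, c, w) for c, w in zip(correct, wrong))
    return sft, dpo


# Deterministic stand-in generator that alternates a right and a wrong answer.
outputs = itertools.cycle([("2 + 2 = 4", "4"), ("2 + 2 = 5", "5")])
sft, dpo = augment(
    [Example("easy query", "4", "E"), Example("hard query", "4", "H")],
    lambda _q: next(outputs),
)
print(len(sft), len(dpo))
```

With the alternating stand-in, the easy query stops after one correct response. The hard query keeps sampling until it has four, so it contributes most of the SFT pairs and all of the DPO triples.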

Paper Structure

This paper contains 38 sections, 3 equations, 3 figures, 4 tables, 1 algorithm.

Figures (3)

  • Figure 1: Changes in data proportion and response-length distribution across difficulty levels during a three-round self-training process. (a) Vanilla rejection sampling for constructing training data, which is widely employed (Singh et al., 2024; Gulcehre et al., 2023; Sordoni et al., 2023; Zelikman et al., 2022). (b) and (c) The proposed DAST strategies, which aim to control data proportion and response lengths for challenging queries. In iteration 0, the training data $\mathcal{D}_{\mathrm u}$ is the original dataset $\mathcal{D}_{\mathrm o}$ with ground-truth labels. In iterations 1, 2, and 3, it combines self-generated data $\mathcal{D}_{\mathrm a}$ with the original dataset $\mathcal{D}_{\mathrm o}$. All difficulty levels are measured with the initial policy $\mathcal{M}_0$ on the GSM8K test set and remain fixed during self-training.
  • Figure 2: Performance of DAST compared with various baselines on in-domain (ID) and out-of-domain (OOD) mathematical test sets using Llama-3.1. Baseline names are shown in lowercase.
  • Figure 3: Results of data proportion control.
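Figures 1 and 3 track each difficulty level's share of the training data. The toy sketch below shows how such proportions can be computed and why keeping every correct sample tilts the mix toward easy queries. All counts are invented for illustration and are not results from the paper. `level_proportions` is a hypothetical helper. The DAST-style allocation shown is only an assumed example of giving harder queries more data. Following the Figure 1 caption, iteration 0 uses $\mathcal{D}_{\mathrm o}$ alone, and later rounds combine $\mathcal{D}_{\mathrm o}$ with $\mathcal{D}_{\mathrm a}$.

```python
from collections import Counter
from typing import Dict, List

LEVELS = ("E", "M", "H", "U")


def level_proportions(levels: List[str]) -> Dict[str, float]:
    """Share of training examples at each difficulty level (levels fixed from M_0)."""
    counts = Counter(levels)
    total = len(levels)
    return {lvl: counts[lvl] / total for lvl in LEVELS}


# Toy D_o: one ground-truth example for each of four queries, one query per level.
d_o = ["E", "M", "H", "U"]
# Toy vanilla rejection sampling: every correct sample out of k = 4 is kept,
# so the easy query contributes many examples and the unsolved query none.
d_a_vanilla = ["E"] * 4 + ["M"] * 2 + ["H"] * 1
# Toy difficulty-aware allocation: harder queries are given more examples.
d_a_dast = ["E"] * 1 + ["M"] * 2 + ["H"] * 4

iter0 = level_proportions(d_o)                  # iteration 0: D_u = D_o
vanilla = level_proportions(d_o + d_a_vanilla)  # later rounds: D_u = D_o + D_a
dast = level_proportions(d_o + d_a_dast)
for name, props in [("iter0", iter0), ("vanilla", vanilla), ("dast", dast)]:
    print(name, {lvl: round(p, 2) for lvl, p in props.items()})
```

In this toy setting, the hard level's share is 2/11 under vanilla rejection sampling and 5/11 under the difficulty-aware allocation, while the easy level's share moves the opposite way.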