TapWeight: Reweighting Pretraining Objectives for Task-Adaptive Pretraining

Ruiyi Zhang, Sai Ashish Somayajula, Pengtao Xie

TL;DR

TapWeight is a task-adaptive pretraining framework that automatically determines the importance of each pretraining objective from downstream feedback, reweighting the objectives by solving a multi-level optimization problem.

Abstract

Large-scale general-domain pretraining followed by downstream-specific finetuning has become a predominant paradigm in machine learning. However, discrepancies between the pretraining and target domains can still degrade performance in certain cases, underscoring the need for task-adaptive continued pretraining (TAP). TAP methods typically continue pretraining on task-specific unlabeled datasets or introduce additional unsupervised learning objectives to enhance model capabilities. While many TAP methods perform continued pretraining with multiple pretraining objectives, they often set the tradeoff parameters between objectives manually, resulting in suboptimal outcomes and higher computational costs. In this paper, we propose TapWeight, a task-adaptive pretraining framework that automatically determines the optimal importance of each pretraining objective based on downstream feedback. TapWeight reweights each pretraining objective by solving a multi-level optimization problem. We apply TapWeight to both molecular property prediction and natural language understanding tasks, where it significantly surpasses baseline methods. Experimental results validate the effectiveness and generalizability of TapWeight.

Paper Structure

This paper contains 60 sections, 14 equations, 2 figures, 5 tables.

Figures (2)

  • Figure 1: An overview of TapWeight. In the first stage, the model undergoes multi-objective pretraining with the tradeoff ratios between objectives held fixed. In the second stage, the pretrained model is finetuned on the training split of the downstream dataset. In the third stage, the finetuned model is evaluated on the validation split of the downstream dataset to compute a loss, and the tradeoff parameters that were held fixed in the first stage are updated by minimizing this validation loss.
  • Figure 2: Evolution of the tradeoff parameter $\lambda$ over the training steps of TapWeight on the following downstream datasets: (a) ESOL, (b) Lipo, (c) FreeSolv, (d) Tox21, (e) ToxCast, and (f) ClinTox.
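
The three stages in the Figure 1 caption can be illustrated with a deliberately tiny sketch. This is not the paper's implementation: it assumes a single scalar "model" weight, quadratic losses whose minimizers have closed forms (standing in for iterative training), a softmax parameterization of the tradeoff parameters, and a proximity term `gamma` that ties finetuning to the pretrained weight. All function and variable names are hypothetical.

```python
import math

def softmax(lams):
    # Map unconstrained tradeoff parameters to positive weights summing to 1.
    m = max(lams)
    exps = [math.exp(l - m) for l in lams]
    total = sum(exps)
    return [e / total for e in exps]

def pretrain(lams, targets):
    # Stage 1 (toy): minimize sum_i p_i * (w - a_i)^2 with the tradeoffs fixed.
    # The minimizer is the weighted mean w = sum_i p_i * a_i.
    p = softmax(lams)
    w_pre = sum(pi * a for pi, a in zip(p, targets))
    return w_pre, p

def finetune(w_pre, b_train, gamma):
    # Stage 2 (toy): minimize (w - b_train)^2 + gamma * (w - w_pre)^2.
    # The proximity term is an assumption that keeps w dependent on w_pre.
    return (b_train + gamma * w_pre) / (1.0 + gamma)

def val_loss(w, b_val):
    # Stage 3 (toy): loss of the finetuned weight on the validation split.
    return (w - b_val) ** 2

def lambda_grad(lams, targets, b_train, b_val, gamma):
    # Chain rule through all three stages:
    # dL/dlam_j = dL/dw_ft * dw_ft/dw_pre * dw_pre/dlam_j
    w_pre, p = pretrain(lams, targets)
    w_ft = finetune(w_pre, b_train, gamma)
    dl_dwft = 2.0 * (w_ft - b_val)
    dwft_dwpre = gamma / (1.0 + gamma)
    # Softmax Jacobian applied to the weighted mean: p_j * (a_j - w_pre).
    return [dl_dwft * dwft_dwpre * pj * (aj - w_pre)
            for pj, aj in zip(p, targets)]

def learn_tradeoffs(targets, b_train, b_val, gamma=1.0, lr=0.5, steps=200):
    # Gradient descent on the validation loss with respect to the tradeoffs.
    lams = [0.0] * len(targets)
    for _ in range(steps):
        grads = lambda_grad(lams, targets, b_train, b_val, gamma)
        lams = [l - lr * g for l, g in zip(lams, grads)]
    return lams
```

In this toy setting, calling `learn_tradeoffs([0.0, 4.0], b_train=1.0, b_val=3.0)` shifts weight toward the second objective, because pulling the pretrained weight toward `4.0` moves the finetuned weight closer to the validation target. A real model has no closed-form solutions, so the gradient through the lower levels would have to be approximated rather than computed exactly as it is here.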