Model Predictive Task Sampling for Efficient and Robust Adaptation
Qi Wang, Zehao Xiao, Yixiu Mao, Yun Qu, Jiayi Shen, Yiqin Lv, Xiangyang Ji
TL;DR
Model Predictive Task Sampling (MPTS) is a risk-aware, model-based framework that efficiently prioritizes difficult tasks for robust adaptation in zero-shot, few-shot, and supervised finetuning (SFT) settings. A lightweight risk predictive model (RPM), based on a variational autoencoder and streaming variational inference, forecasts task-specific adaptation risk and enables amortized, uncertainty-aware task selection via a UCB-like acquisition. Theoretical guarantees show approximate invariance of the task difficulty ranking across iterations and convergence under diminishing rank-flipping. Experiments across sinusoid regression, few-shot image classification with foundation models, Meta-RL, and robotic domain randomization demonstrate improved tail-risk robustness and learning efficiency versus state-of-the-art baselines. The framework is plug-and-play across backbones and tasks, offering scalable robustness for foundation-model adaptation and large-scale decision-making with reduced environment interactions and annotation costs.
Abstract
Foundation models have revolutionized general-purpose problem-solving, offering rapid task adaptation through pretraining, meta-training, and finetuning. Recent advances in these paradigms reveal the importance of prioritizing challenging tasks during sampling to enhance adaptation robustness under distribution shifts. However, ranking task difficulties across iterations as a preliminary step typically requires exhaustive task evaluation, which is practically unaffordable in both computation and data annotation. This study offers a new perspective on pursuing adaptation robustness and learning efficiency together, particularly in scenarios where task evaluation is risky or costly, such as iterative agent-environment interactions for robotic policy evaluation or computationally intensive inference steps for finetuning foundation models. To this end, we introduce Model Predictive Task Sampling (MPTS), a framework that bridges the task space and adaptation risk distributions, providing a theoretical foundation for robust active task sampling. MPTS employs a generative model to characterize the episodic optimization process and predicts task-specific adaptation risk via posterior inference. The resulting risk predictive model amortizes the costly evaluation of task adaptation performance and provably approximates task difficulty rankings. MPTS integrates seamlessly into zero-shot, few-shot, and supervised finetuning settings. Empirically, we conduct extensive experiments in pattern recognition using foundation models and in sequential decision-making. Our results demonstrate that MPTS significantly enhances adaptation robustness on tail-risk or out-of-distribution (OOD) tasks and improves learning efficiency compared to state-of-the-art (SoTA) methods. The code is available at https://github.com/thu-rllab/MPTS.
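To make the UCB-like acquisition mentioned in the TL;DR concrete, the following is a minimal sketch of uncertainty-aware task selection and not the authors' implementation. It assumes a risk predictive model has already produced a predicted mean and standard deviation of the adaptation risk for each candidate task. The function names `ucb_scores` and `select_tasks`, the trade-off weight `gamma`, the additive score form, and the numeric values are all hypothetical and chosen for illustration only.

```python
from typing import List, Sequence


def ucb_scores(pred_mean: Sequence[float],
               pred_std: Sequence[float],
               gamma: float = 1.0) -> List[float]:
    """Score each candidate task by predicted risk plus an uncertainty bonus.

    The additive form mean + gamma * std is an assumed UCB-style rule;
    the acquisition used in the paper may differ in detail.
    """
    if len(pred_mean) != len(pred_std):
        raise ValueError("pred_mean and pred_std must have the same length")
    return [m + gamma * s for m, s in zip(pred_mean, pred_std)]


def select_tasks(pred_mean: Sequence[float],
                 pred_std: Sequence[float],
                 batch_size: int,
                 gamma: float = 1.0) -> List[int]:
    """Return indices of the batch_size candidates with the highest scores."""
    scores = ucb_scores(pred_mean, pred_std, gamma)
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:batch_size]


# Placeholder predictions for four candidate tasks (illustrative values,
# not outputs of a trained risk predictive model).
pred_mean = [0.2, 1.5, 0.9, 0.4]
pred_std = [0.1, 0.1, 0.8, 0.05]

# With the uncertainty bonus, the uncertain task 2 is ranked first.
chosen_with_bonus = select_tasks(pred_mean, pred_std, batch_size=2, gamma=1.0)
# With gamma = 0, selection reduces to ranking by predicted risk alone.
chosen_mean_only = select_tasks(pred_mean, pred_std, batch_size=2, gamma=0.0)
```

In this sketch, only the selected tasks would then be evaluated exactly, which is the sense in which the predictive model amortizes the cost of ranking every candidate task.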
