Model Predictive Task Sampling for Efficient and Robust Adaptation

Qi Wang, Zehao Xiao, Yixiu Mao, Yun Qu, Jiayi Shen, Yiqin Lv, Xiangyang Ji

TL;DR

Model Predictive Task Sampling (MPTS) introduces a risk-aware, model-based framework that efficiently prioritizes difficult tasks for robust adaptation in zero-shot, few-shot, and supervised finetuning (SFT) settings. A lightweight risk predictive model (RPM), based on a variational autoencoder trained with streaming variational inference, forecasts task-specific adaptation risk and enables amortized, uncertainty-aware task selection via a UCB-like acquisition function. Theoretical guarantees show that task difficulty rankings remain approximately invariant across iterations and that training converges under diminishing rank flipping. Experiments on sinusoid regression, few-shot image classification with foundation models, Meta-RL, and robotic domain randomization demonstrate improved tail-risk robustness and learning efficiency over state-of-the-art baselines. The framework is plug-and-play across backbones and tasks, offering scalable robustness for foundation-model adaptation and large-scale decision-making with fewer environment interactions and lower annotation costs.
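The uncertainty-aware selection step can be sketched as follows. This is a minimal sketch, not the authors' implementation: it assumes the UCB-like acquisition scores each candidate as `gamma0 * mean + gamma1 * std` of the RPM's predicted risk (Figure 3d varies a $\gamma_1/\gamma_0$ ratio, but the exact functional form used here is an assumption), and the function names, identifiers, and predictions are hypothetical. Only the Top-$\mathcal{B}$ candidates chosen this way would then be evaluated exactly.

```python
import heapq
from typing import List, Sequence, Tuple

TaskId = Tuple[float, ...]  # a task identifier, e.g. (amplitude, phase)


def ucb_scores(means: Sequence[float], stds: Sequence[float],
               gamma0: float = 1.0, gamma1: float = 1.0) -> List[float]:
    """Assumed UCB-like score: gamma0 weights predicted risk, gamma1 weights uncertainty."""
    if len(means) != len(stds):
        raise ValueError("means and stds must have the same length")
    return [gamma0 * m + gamma1 * s for m, s in zip(means, stds)]


def select_top_b(candidates: Sequence[TaskId], means: Sequence[float],
                 stds: Sequence[float], b: int,
                 gamma0: float = 1.0, gamma1: float = 1.0) -> List[TaskId]:
    """Keep the b candidates with the highest score; no exact task evaluation needed."""
    scores = ucb_scores(means, stds, gamma0, gamma1)
    best = heapq.nlargest(b, range(len(candidates)), key=lambda i: scores[i])
    return [candidates[i] for i in best]


# Hypothetical pseudo batch of 4 candidates with RPM predictions; keep B = 2.
candidates = [(0.5, 0.1), (4.0, 2.0), (2.5, 1.0), (1.0, 3.0)]
means = [0.2, 1.5, 0.9, 0.4]
stds = [0.1, 0.1, 0.6, 1.0]
print(select_top_b(candidates, means, stds, b=2))  # [(4.0, 2.0), (2.5, 1.0)]
```

With `gamma0=0.0` the same call ranks purely by uncertainty and returns `[(1.0, 3.0), (2.5, 1.0)]`, so the ratio between the two weights sets the balance between exploiting predicted difficulty and exploring poorly modeled regions of the task space.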

Abstract

Foundation models have revolutionized general-purpose problem-solving, offering rapid task adaptation through pretraining, meta-training, and finetuning. Recent advances in these paradigms reveal the importance of prioritizing challenging tasks during sampling to enhance adaptation robustness under distribution shifts. However, ranking task difficulties at every iteration typically requires exhaustive task evaluation, which is prohibitively expensive in both computation and data annotation. This study explores how adaptation robustness and learning efficiency can be achieved together, particularly in scenarios where task evaluation is risky or costly, such as iterative agent-environment interactions for robotic policy evaluation or computationally intensive inference steps for finetuning foundation models. We introduce Model Predictive Task Sampling (MPTS), a framework that bridges the task space and adaptation risk distributions, providing a theoretical foundation for robust active task sampling. MPTS employs a generative model to characterize the episodic optimization process and predicts task-specific adaptation risk via posterior inference. The resulting risk predictive model amortizes the costly evaluation of task adaptation performance and provably approximates task difficulty rankings. MPTS integrates seamlessly into zero-shot, few-shot, and supervised finetuning settings. Empirically, we conduct extensive experiments in pattern recognition using foundation models and in sequential decision-making. Our results demonstrate that MPTS significantly enhances adaptation robustness on tail-risk and out-of-distribution (OOD) tasks and improves learning efficiency compared to state-of-the-art (SoTA) methods. The code is available at the project site https://github.com/thu-rllab/MPTS.
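The "predict risk, then update on the risks actually observed" cycle can be illustrated with a far simpler model than the paper's. The RPM described in the TL;DR is a variational autoencoder trained with streaming variational inference; the sketch below swaps in Bayesian linear regression with exact rank-one (Kalman-style) posterior updates. It shares the interface that matters here: a predictive mean and uncertainty for each task identifier, refined online after each costly evaluation without revisiting past data. The class, the feature map, and the prior and noise values are all hypothetical.

```python
import math
from typing import List, Sequence, Tuple


class StreamingBayesRiskModel:
    """Hypothetical stand-in for a risk predictive model (not the paper's VAE).

    Bayesian linear regression of risk on task features with prior N(0, prior_var * I)
    and Gaussian noise; each observation triggers an exact rank-one posterior update.
    """

    def __init__(self, dim: int, prior_var: float = 10.0, noise_var: float = 0.1):
        self.noise_var = noise_var
        self.mean = [0.0] * dim
        self.cov = [[prior_var if i == j else 0.0 for j in range(dim)] for i in range(dim)]

    def _cov_dot(self, x: Sequence[float]) -> List[float]:
        return [sum(c * xi for c, xi in zip(row, x)) for row in self.cov]

    def predict(self, x: Sequence[float]) -> Tuple[float, float]:
        """Predictive mean and standard deviation of the risk for features x."""
        cx = self._cov_dot(x)
        mu = sum(m * xi for m, xi in zip(self.mean, x))
        var = self.noise_var + sum(xi * ci for xi, ci in zip(x, cx))
        return mu, math.sqrt(var)

    def update(self, x: Sequence[float], risk: float) -> None:
        """Condition the posterior on one exactly evaluated task risk."""
        cx = self._cov_dot(x)
        denom = self.noise_var + sum(xi * ci for xi, ci in zip(x, cx))
        gain = [ci / denom for ci in cx]
        resid = risk - sum(m * xi for m, xi in zip(self.mean, x))
        self.mean = [m + g * resid for m, g in zip(self.mean, gain)]
        d = len(x)
        self.cov = [[self.cov[i][j] - gain[i] * cx[j] for j in range(d)] for i in range(d)]


def features(tau: float) -> List[float]:
    """Hypothetical feature map for a scalar task identifier."""
    return [1.0, tau]


model = StreamingBayesRiskModel(dim=2)
_, std_before = model.predict(features(5.0))
for tau in [0.0, 1.0, 2.0, 3.0, 4.0]:
    model.update(features(tau), risk=2.0 * tau + 1.0)  # toy ground-truth risk
mean_after, std_after = model.predict(features(5.0))
print(round(mean_after, 2), std_after < std_before)  # 11.0 True
```

Because each update is exact for this conjugate model, processing observations one at a time yields the same posterior as refitting on the full history; the paper's streaming VI plays the analogous role for a non-conjugate generative model.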

Paper Structure (111 sections, 5 theorems, 67 equations, 9 figures, 5 tables, 7 algorithms)

Key Result

Theorem 1

Given arbitrary $K$ data points $\{(\bm\tau_i,\ell(\mathcal{D}_{\tau_i}^{Q},\mathcal{D}_{\tau_i}^{S};\bm\theta_t))\}_{i=1}^{K}$, treat the adaptation gradient $\nabla_{\bm\theta}\mathcal{\bm L}(\bm\theta_t)$ as a $\sigma$-sub-Gaussian random variable and let $\bm\theta_{t+1}=\bm\theta_{t}-\eta_{t}\nabla_{\bm\theta}\mathcal{\bm L}(\bm\theta_t)$. Then, when $\eta_{t}\leq\frac{\delta_t}{2G_{t}M_{t}+\sqrt{8\sigma^{2}G_{t}^{2}\ln\left(\frac{K(K-1)}{2\xi}\right)}}$, …
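As a sanity check on the step-size condition, the sketch below evaluates its right-hand side, reading the bound as $\eta_t \le \delta_t / (2G_tM_t + \sqrt{8\sigma^2G_t^2\ln(K(K-1)/(2\xi))})$. That reading of the truncated expression is an assumption, and the theorem's definitions of $\delta_t$, $G_t$ and $M_t$ are not reproduced here, so the inputs are purely illustrative. The $K(K-1)/2$ factor plausibly stems from a union bound over all task pairs whose order could flip, with $\xi$ as the failure probability.

```python
import math


def max_step_size(delta: float, G: float, M: float, sigma: float,
                  K: int, xi: float) -> float:
    """Right-hand side of the step-size condition in Theorem 1 (assumed closed form)."""
    if K < 2 or not 0.0 < xi < 1.0:
        raise ValueError("need K >= 2 and 0 < xi < 1")
    n_pairs = K * (K - 1) / 2          # unordered pairs among the K tasks
    log_term = math.log(n_pairs / xi)  # equals ln(K(K-1) / (2 xi))
    return delta / (2 * G * M + math.sqrt(8 * sigma**2 * G**2 * log_term))


# Purely illustrative values for delta, G, M, sigma, K and xi.
eta_max = max_step_size(delta=0.1, G=1.0, M=1.0, sigma=0.5, K=10, xi=0.05)
print(f"{eta_max:.5f}")  # 0.01758
```

Tightening $\xi$ or ranking more tasks enters only through the logarithm, so the admissible step size shrinks slowly as $K$ grows.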

Figures (9)

  • Figure 1: Framework of MPTS in Adaptation Learning. a. The left shows standard random sampling for generating candidate tasks. The middle in blue denotes the costly evaluation of $\hat{\mathcal{B}}$ tasks (e.g., agent–environment interaction or a foundation model forward pass) in DRM to select the Top-$\mathcal{B}$ worst ones. The middle in green depicts MPTS, which predicts task difficulty via a lightweight generative model, avoiding expensive evaluation. The right illustrates the standard optimization pipeline in Meta-Learning, DR, or SFT (Snow: frozen models; Fire: updated models). b. MPTS samples candidate tasks, ranks their difficulty by the risk predictive model's predictions for subset selection, and updates the learner, approximating the Monte Carlo optimization of $\text{CVaR}_{\alpha}$ in a predictive manner. The gathered risk signals further update the RPM online. c. The RPM is trained on the risk history $H_{1:t}$ under a streaming VI framework. d. The RPM simulates adaptation outcomes $p(\ell\vert\bm\tau,H_{1:t};\bm\theta_{t})$ for the $\hat{\mathcal{B}}$ candidate identifiers, computes acquisition scores, and selects the Top-$\mathcal{B}$ identifiers for the $(t+1)$-th iteration.
  • Figure 2: Fundamental Concepts: Task Identifiers, Episodic Learning and Probabilistic Graphical Models. a. The task distribution is uniform and defined over meaningful identifiers $\bm\tau$. For example, the amplitude and phase $[a,b]$ specify a sinusoid curve to be fitted from $K$-shot observed data points. Robots like Half-Cheetahs are trained to accomplish different locomotion tasks with varying masses and velocities. Some multimodal pattern recognition tasks have implicit identifiers, which can be described by a reference model, e.g., the text encoder in CLIP radford2021learning. b. Tail task generalization corresponds to $\text{CVaR}_{\alpha}$, i.e., the integral of the tail task risk values in red. For OOD generalization, this work prompt-tunes CLIP on ImageNet russakovsky2015imagenet and tests it on ImageNet-S wang2019learning. c. The generative model includes grey units as observed variables and white ones as unobserved variables. The solid directed lines describe the generative model tomczak2024deep. Dashed directed lines indicate the recognition model and approximate inference within autoencoding variational Bayes kingma2013auto.
  • Figure 3: K-shot Sinusoid Regression Results (10 Runs). Panels a-i report results with MAML as the backbone, while j-n report results with Reptile as the backbone. a. Shown are curves of averaged MSEs on the validation task set during meta-training for all methods. b. Curves illustrate the $\text{CVaR}_{0.9}$ MSEs on the validation task set during meta-training. c. The meta-trained machine learners are tested on a fixed task set, reporting the average MSEs and CVaR values. d. Displayed are meta-testing results of MPTS machine learners trained with various $\gamma_1/\gamma_0$ ratios. e. Displayed are meta-testing results of MPTS machine learners trained with various pseudo batch sizes, i.e., $\hat{\mathcal{B}}\in\{1\mathcal{B},2\mathcal{B},4\mathcal{B},8\mathcal{B}\}$. f. The Pearson correlation coefficient (PCC) values are tracked during meta-training. g. At a specific iteration, the statistical correlation between predicted and exact adaptation risk values of the task batch is visualized, with an overall $\rho_{\bar{\ell},\ell}=0.669$. h. The relative run-time of all methods during meta-training is computed with ERM as the anchor. i. At a given meta-training step, we visualize the subset selected from the pseudo batch by the RPM. j. We illustrate the curves of averaged MSEs on the validation task set during Reptile meta-training for all relevant baselines. k. We track the corresponding $\text{CVaR}_{0.9}$ MSEs on the validation task set throughout Reptile meta-training. l. Reported are the tested average MSEs and CVaR values of the meta-trained machine learners on a fixed task set. m. The PCC values are tracked for MPTS during Reptile meta-training. n. We compute the relative run-time of all methods during meta-training with ERM as the anchor.
  • Figure 4: Meta-RL Results on MuJoCo Environments (10 Runs). a. The cumulative returns, with standard errors of the mean (SEMs), on the $\text{CVaR}_{0.9}$ validation MDPs are displayed during meta-training. b. We compute the average cumulative returns with SEMs on validation MDPs during meta-training. c. Tracked are the RPM's PCC values with SEMs over training iterations. d. The relative wall-clock time quantifies the computational cost of all methods on Walker2dVel, with ERM's runtime serving as the anchor. e. We report $\text{CVaR}_{\alpha}$ returns on meta-testing MDPs. f. The box plot reports results averaged over meta-testing MDPs. g. With PEARL rakelly2019efficient as the Meta-RL backbone, we illustrate the learning curves and meta-testing results on HalfCheetahBody and HalfCheetahMass from the RoML baseline greenberg2023train.
  • Figure 5: DR Results on Ergo-Reacher and Lunar-Lander (10 Runs). a. In Ergo-Reacher, the $\text{CVaR}_{0.9}$, $\text{CVaR}_{0.7}$, $\text{CVaR}_{0.5}$ and average cumulative returns on validation MDPs are reported together with the RPM's PCC curve during DR training. b. In Lunar-Lander, the cumulative returns on validation MDPs are illustrated together with the RPM's PCC curve during DR training. c. We test the DR-trained policies on the fixed MDP set and report the $\text{CVaR}_{\alpha}$ cumulative returns. d. The returns averaged over DR-testing MDPs are illustrated. e. The required runtime is computed for all methods on Lunar-Lander. f. In Lunar-Lander, shown are the frequencies of identifiers sampled by MPTS during DR training. g. In Lunar-Lander, we test the trained policies in both in-distribution (ID) and out-of-distribution (OOD) domains and report each task's average return.
  • ...and 4 more figures
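Several panels above track the RPM's PCC. Assuming PCC denotes the Pearson correlation coefficient between the RPM's predicted risks and the exactly evaluated risks of a task batch (as with the $\rho_{\bar{\ell},\ell}$ reported in Figure 3g), it can be computed as in this sketch; the function name and sample values are hypothetical.

```python
import math
from typing import Sequence


def pearson(predicted: Sequence[float], exact: Sequence[float]) -> float:
    """Pearson correlation between predicted and exactly evaluated task risks."""
    n = len(predicted)
    if n != len(exact) or n < 2:
        raise ValueError("need two equally long sequences with at least 2 items")
    mp = sum(predicted) / n
    me = sum(exact) / n
    cov = sum((p - mp) * (e - me) for p, e in zip(predicted, exact))
    sp = math.sqrt(sum((p - mp) ** 2 for p in predicted))
    se = math.sqrt(sum((e - me) ** 2 for e in exact))
    if sp == 0.0 or se == 0.0:
        raise ValueError("correlation is undefined for a constant sequence")
    return cov / (sp * se)


predicted = [1.0, 2.0, 3.0, 4.0]  # hypothetical RPM predictions
exact = [1.0, 3.0, 2.0, 4.0]      # hypothetical exact risks
print(round(pearson(predicted, exact), 3))  # 0.8
```

On Python 3.10 and later, `statistics.correlation` computes the same quantity. Pearson correlation measures linear agreement rather than rank agreement, so a high PCC indicates predictions that track the exact risks closely without guaranteeing an identical ordering.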

Theorems & Definitions (9)

  • Definition 1: Conditional Value-at-Risk (CVaR) rockafellar2000optimization
  • Definition 2: Model Predictive Task Sampling
  • Theorem 1: Provably Approximately Invariant Task Difficulties
  • Lemma 1: Misranked Subset Quantity
  • Lemma 2: Rank-Preserving Bound in Expectation
  • Lemma 3: Misranking Acute-Angle
  • Theorem 2: Convergence with Diminishing Rank Flipping
  • Definition 3: Permutation Invariant Function
  • Definition 4: Stochastic Risk Function
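Definition 1 itself is not reproduced on this page. Assuming it follows the cited Rockafellar-Uryasev form, $\text{CVaR}_{\alpha}(\ell)=\min_{c}\{c+\mathbb{E}[(\ell-c)_+]/(1-\alpha)\}$, with larger $\ell$ meaning a harder task so that $\text{CVaR}_{0.9}$ averages the worst 10% of tasks, the sketch below checks on a toy sample that this variational form matches the plain tail average. The task risks and function names are hypothetical.

```python
from typing import Sequence


def cvar_rockafellar_uryasev(losses: Sequence[float], alpha: float) -> float:
    """Empirical CVaR_alpha via min_c { c + mean((loss - c)_+) / (1 - alpha) }.

    The objective is convex and piecewise linear in c with kinks at the sample
    values, so minimizing over the samples themselves is exact.
    """
    if not losses or not 0.0 <= alpha < 1.0:
        raise ValueError("need non-empty losses and 0 <= alpha < 1")
    n = len(losses)

    def objective(c: float) -> float:
        return c + sum(max(x - c, 0.0) for x in losses) / ((1.0 - alpha) * n)

    return min(objective(c) for c in losses)


def cvar_tail_mean(losses: Sequence[float], alpha: float) -> float:
    """Mean of the worst (1 - alpha) fraction; assumes (1 - alpha) * n is an integer."""
    k = round((1.0 - alpha) * len(losses))
    return sum(sorted(losses, reverse=True)[:k]) / k


task_risks = [0.1, 0.9, 0.4, 0.7, 0.2, 0.8, 0.3, 0.6, 0.5, 1.0]
print(round(cvar_rockafellar_uryasev(task_risks, 0.8), 6),
      round(cvar_tail_mean(task_risks, 0.8), 6))  # 0.95 0.95
```

Under this convention, averaging the worst $\mathcal{B}$ of $\hat{\mathcal{B}}$ evaluated tasks, as in the DRM branch of Figure 1, corresponds to the tail-mean estimate with $\alpha = 1-\mathcal{B}/\hat{\mathcal{B}}$; this correspondence is inferred from the figure caption rather than stated as a result.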