Task-Aware Harmony Multi-Task Decision Transformer for Offline Reinforcement Learning

Ziqing Fan, Shengchao Hu, Yuhang Zhou, Li Shen, Ya Zhang, Yanfeng Wang, Dacheng Tao

TL;DR

This work introduces the Harmony Multi-Task Decision Transformer (HarmoDT), a novel solution designed to identify an optimal harmony subspace of parameters for each task, and designs a group-wise variant (G-HarmoDT) that clusters tasks into coherent groups based on gradient information.

Abstract

The purpose of offline multi-task reinforcement learning (MTRL) is to develop a unified policy applicable to diverse tasks without the need for online environmental interaction. Recent advancements approach this through sequence modeling, leveraging the Transformer architecture's scalability and the benefits of parameter sharing to exploit task similarities. However, variations in task content and complexity pose significant challenges in policy formulation, necessitating judicious parameter sharing and management of conflicting gradients for optimal policy performance. Furthermore, identifying the optimal parameter subspace for each task often requires prior knowledge of the task identifier during inference, limiting applicability in real-world scenarios where task content varies and the current task is unknown. In this work, we introduce the Harmony Multi-Task Decision Transformer (HarmoDT), a novel solution designed to identify an optimal harmony subspace of parameters for each task. We formulate this as a bi-level optimization problem within a meta-learning framework, where the upper level learns masks to define the harmony subspace, while the inner level focuses on updating parameters to improve the overall performance of the unified policy. To eliminate the need for task identifiers, we further design a group-wise variant (G-HarmoDT) that clusters tasks into coherent groups based on gradient information and uses a gating network to determine task identifiers during inference. Empirical evaluations across various benchmarks demonstrate the effectiveness of our approach in the multi-task context, with gains of 8% in task-provided settings, 5% in task-agnostic settings, and 10% in unseen settings.
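
The two ideas the abstract relies on, conflicting task gradients and a per-task mask over shared parameters, can be illustrated with a small sketch. This is not the paper's algorithm: the cosine-similarity conflict measure, the helper names `cosine` and `apply_mask`, the hand-picked masks, and the toy gradient values are all assumptions made for illustration. In HarmoDT the masks are learned, not chosen by hand.

```python
import math

def cosine(g1, g2):
    """Cosine similarity between two flat gradient vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(g1, g2))
    n1 = math.sqrt(sum(a * a for a in g1))
    n2 = math.sqrt(sum(b * b for b in g2))
    if n1 == 0.0 or n2 == 0.0:
        return 0.0
    return dot / (n1 * n2)

def apply_mask(values, mask):
    """Zero out the entries that a binary mask excludes from a task's subspace."""
    return [v * m for v, m in zip(values, mask)]

# Toy gradients of two tasks with respect to the same four shared weights.
grad_task_a = [1.0, -2.0, 0.5, 1.0]
grad_task_b = [1.0, 2.0, 0.5, -1.0]

# Without masks the gradients point in opposing directions (negative cosine).
conflict_before = cosine(grad_task_a, grad_task_b)

# Hypothetical binary masks: each task gives up one weight where it disagrees.
mask_a = [1, 1, 1, 0]
mask_b = [1, 0, 1, 1]

# Inside the masked subspaces the remaining overlap is in agreement.
conflict_after = cosine(apply_mask(grad_task_a, mask_a),
                        apply_mask(grad_task_b, mask_b))

print(conflict_before < 0, conflict_after > 0)
```

A negative cosine before masking and a positive one afterwards is the kind of effect a harmony subspace is meant to achieve: each task updates only the shared weights on which it does not pull against the others.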

Paper Structure

This paper contains 47 sections, 21 equations, 7 figures, 9 tables, 5 algorithms.

Figures (7)

  • Figure 1: Accuracy drops as the number of tasks increases and when task identifiers are absent, shown for near-optimal cases within the Meta-World benchmark. The comparison focuses on our method and the prevalent MTRL algorithms PromptDT and MTDIFF. For more baseline results and analysis, see Section \ref{sec:exp}. Methods including HarmoDT face significant performance drops without task identifiers, whereas our G-HarmoDT effectively overcomes this limitation.
  • Figure 2: Illustration of the averaged harmony score among trainable weights during training for policies with and without randomly initialized masks (left panel), and success rates with varying mask sparsity levels (right panel) in the Meta-World benchmark. For the averaged harmony score, defined in Section \ref{sec:conflict}, higher values indicate better harmony.
  • Figure 3: (a) Accuracy of the gating module in classifying tasks into groups on the 50 tasks of Meta-World (e.g., 5 groups means each group contains 10 tasks). The accuracy of the gating network decreases significantly as the number of groups increases. (b) Performance in task-agnostic settings with different numbers of groups and the corresponding gating module on the Meta-World benchmark. Directly combining the gating network with HarmoDT results in sub-optimal performance. For example, in MT50 with 50 tasks, 10Group (G-HarmoDT) achieves the best performance instead of 50Group (HarmoDT). An in-depth analysis is provided in Section \ref{sec:exp}.
  • Figure 4: Illustration of the gradient-conflict problem and the HarmoDT framework for finding a harmony subspace for each task. The left panel shows the conflict, reflected by divergent task-specific gradients. The middle panel illustrates the procedure for finding a harmony subspace for each task via mask learning. The right panel demonstrates the workflow of HarmoDT, based on the DT architecture with prompts, when handling a task such as ${\mathcal{T}}_3$.
  • Figure 5: Overall framework of G-HarmoDT. In the first stage, we warm up the weights and group tasks based on weight evaluations. In the second stage, we train a gating model using task inputs and ID-group mappings from the grouping module, while simultaneously updating the main model. In the third stage, we gradually update the group subspaces. During inference, the gating model provides the appropriate group ID and mask for the unknown task.
  • ...and 2 more figures
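
The inference path described in the Figure 5 caption (a gating model picks a group ID, and that group's mask selects the active parameters) can be sketched as follows. This is a toy illustration under stated assumptions, not the paper's implementation: the nearest-centroid rule stands in for the learned gating network, and `nearest_group`, `masked_weights`, the feature vectors, and the masks are hypothetical.

```python
def nearest_group(features, centroids):
    """Stand-in gating model: return the index of the closest group centroid."""
    best_gid, best_dist = 0, float("inf")
    for gid, centroid in enumerate(centroids):
        # Squared Euclidean distance between the task features and the centroid.
        dist = sum((f - c) ** 2 for f, c in zip(features, centroid))
        if dist < best_dist:
            best_gid, best_dist = gid, dist
    return best_gid

def masked_weights(weights, mask):
    """Keep only the shared weights that belong to the selected group's subspace."""
    return [w * m for w, m in zip(weights, mask)]

# Hypothetical shared parameters and one binary mask per task group.
shared_weights = [0.4, -1.2, 0.7, 0.3]
group_masks = {0: [1, 1, 0, 1], 1: [0, 1, 1, 1]}

# Hypothetical group centroids in the gating model's feature space.
group_centroids = [[1.0, 0.0], [0.0, 1.0]]

# At inference the task identifier is unknown; only task inputs are available.
unknown_task_features = [0.2, 0.9]
group_id = nearest_group(unknown_task_features, group_centroids)
active_weights = masked_weights(shared_weights, group_masks[group_id])

print(group_id, active_weights)
```

Routing to a group instead of to an individual task keeps the gating problem small, which matches the Figure 3 observation that gating accuracy falls as the number of groups grows.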

Theorems & Definitions (6)

  • Definition 3.1: Harmony Score on a Single Weight
  • Definition 3.2: Averaged Harmony Score
  • Definition 4.1: Agreement Score
  • Definition 4.2: Importance Score
  • Definition 4.3: Group Agreement Score
  • Definition 4.4: Group Importance Score