Make the Pertinent Salient: Task-Relevant Reconstruction for Visual Control with Distractions

Kyungmin Kim, JB Lanier, Pierre Baldi, Charless Fowlkes, Roy Fox

TL;DR

We propose Segmentation Dreamer (SD), a simple yet effective auxiliary task that facilitates representation learning in distracting environments. It greatly reduces the complexity of representation learning by removing the need to encode task-irrelevant objects in the latent representation.

Abstract

Recent advancements in Model-Based Reinforcement Learning (MBRL) have made it a powerful tool for visual control tasks. Despite improved data efficiency, it remains challenging to train MBRL agents with generalizable perception. Training in the presence of visual distractions is particularly difficult due to the high variation they introduce to representation learning. Building on DREAMER, a popular MBRL method, we propose a simple yet effective auxiliary task to facilitate representation learning in distracting environments. Under the assumption that task-relevant components of image observations are straightforward to identify with prior knowledge in a given task, we use a segmentation mask on image observations to only reconstruct task-relevant components. In doing so, we greatly reduce the complexity of representation learning by removing the need to encode task-irrelevant objects in the latent representation. Our method, Segmentation Dreamer (SD), can be used either with ground-truth masks easily accessible in simulation or by leveraging potentially imperfect segmentation foundation models. The latter is further improved by selectively applying the reconstruction loss to avoid providing misleading learning signals due to mask prediction errors. In modified DeepMind Control suite (DMC) and Meta-World tasks with added visual distractions, SD achieves significantly better sample efficiency and greater final performance than prior work. We find that SD is especially helpful in sparse reward tasks otherwise unsolvable by prior work, enabling the training of visually robust agents without the need for extensive reward engineering.
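To make the masked reconstruction target concrete, here is a minimal NumPy sketch. It is not the authors' implementation: the function names `masked_target` and `reconstruction_loss` are hypothetical, and filling task-irrelevant pixels with zeros is an assumption made for illustration.

```python
import numpy as np


def masked_target(obs, mask):
    """Build a reconstruction target that keeps only task-relevant pixels.

    obs:  float array of shape (H, W, 3), an RGB observation.
    mask: bool array of shape (H, W), True where a pixel is task-relevant.
    Assumption: masked-out pixels are filled with zeros.
    """
    return obs * mask[..., None]


def reconstruction_loss(pred, obs, mask):
    """Mean per-pixel squared error between a prediction and the masked target."""
    target = masked_target(obs, mask)
    per_pixel = np.sum((pred - target) ** 2, axis=-1)
    return float(np.mean(per_pixel))


# Toy 2x2 observation in which only the top-left pixel is task-relevant.
obs = np.ones((2, 2, 3))
mask = np.array([[True, False], [False, False]])
target = masked_target(obs, mask)

# A decoder that outputs all zeros is penalized only at the task-relevant pixel.
loss = reconstruction_loss(np.zeros((2, 2, 3)), obs, mask)
```

Because the distracting background never appears in the target, a decoder in this sketch gains nothing from encoding it, which is the intuition behind the reduced representation complexity described above.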

Paper Structure

This paper contains 35 sections, 4 equations, 11 figures, 2 tables.

Figures (11)

  • Figure 1: (Left) Providing mask example(s) and fine-tuning a mask model, or instrumenting a simulator, to obtain masks. (Right) An input observation in a distracting Meta-World with three alternative auxiliary task targets. Moving scenes in the background are considered distractions. (b) Observations including task-irrelevant information, disturbing world-model training. (c) and (d) Segmentation of task-relevant components using, respectively, a ground-truth mask and an approximate mask generated by segmentation models.
  • Figure 2: Filtering $L_2$ loss to avoid training on false negatives in RGB labels. (Left) Estimated pixel locations (f) where the RGB target (c) is likely incorrectly masked out by the segmentation model (e). (Right) A world model equipped with two decoders, one for reconstructing task-relevant masked RGB images and the other for binary masks, the targets for which are generated by a segmentation model. RGB $L_2$ loss is selectively masked by the set difference between (d) and (e). Latent representations ($x_t$) in the world model are subjected to the training signal only from the RGB branch. The binary branch is only utilized for selective $L_2$ loss.
  • Figure 3: (a) Learning curves on six visual control tasks from DMC. Every method but Dreamer* is trained on distracting environments. All curves show the mean over 4 seeds with the standard error of the mean (SEM) shaded. (b) Segmentation quality during training vs. downstream task performance. Best viewed in color.
  • Figure 4: (a)+(b) Qualitative comparison of SD trained with naive and selective $L_2$ loss. Trajectories are taken from each method's train-time replay buffer, selected to have the same background. Frames with PerSAM error are highlighted. The model trained with the selective $L_2$ loss overcomes errors in the target, whereas the one trained with the naive $L_2$ loss memorizes target errors. (c)+(d) show the precision and recall of PerSAM and the SD RGB decoder prediction. SD RGB predictions are binarized using a threshold to compute recall and precision w.r.t. the ground-truth mask. The data points used for plotting are from the same Cheetah Run training experiment as in (a)+(b). The selective $L_2$ loss significantly improves the recall with only a moderate impact on precision.
  • Figure 5: Learning curves on six visual robotic manipulation tasks from Meta-World. All curves show the mean over 4 seeds with the standard error of the mean shaded.
  • ...and 6 more figures
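
The selective $L_2$ loss described in the Figure 2 caption can be sketched as follows. This is a hedged illustration rather than the paper's code: it assumes that (d) is the binary mask predicted by the world model's mask decoder (the caption does not state this explicitly), that (e) is the segmentation model's mask, and that the loss is averaged over the retained pixels. The function name `selective_l2_loss` is hypothetical.

```python
import numpy as np


def selective_l2_loss(rgb_pred, rgb_target, mask_pred, mask_target):
    """L2 reconstruction loss that skips suspected false negatives.

    rgb_pred, rgb_target: float arrays of shape (H, W, 3).
    mask_pred:   bool (H, W), binarized output of the mask decoder (assumed to be (d)).
    mask_target: bool (H, W), mask from the segmentation model (assumed to be (e)).
    """
    # Set difference: pixels the mask decoder calls foreground but the
    # segmentation model left out. The RGB target is likely wrong there.
    suspect = mask_pred & ~mask_target
    keep = ~suspect
    per_pixel = np.sum((rgb_pred - rgb_target) ** 2, axis=-1)
    # Assumption: normalize by the number of retained pixels.
    return float(np.sum(per_pixel * keep) / max(int(keep.sum()), 1))


# Toy case: the segmentation model missed pixel (0, 1), so the RGB target is
# black there, while the RGB decoder correctly predicts the object.
rgb_target = np.zeros((2, 2, 3))
rgb_pred = np.zeros((2, 2, 3))
rgb_pred[0, 1] = 1.0
mask_pred = np.array([[True, True], [False, False]])
mask_target = np.array([[True, False], [False, False]])

selective = selective_l2_loss(rgb_pred, rgb_target, mask_pred, mask_target)
naive = float(np.mean(np.sum((rgb_pred - rgb_target) ** 2, axis=-1)))
```

In this toy case the naive loss penalizes the correct prediction at the missed pixel, whereas the selective loss ignores it. In a real training setup the binary branch would additionally be detached from the latent representation, since the caption states that the latent receives training signal only from the RGB branch.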