Learning What Helps: Task-Aligned Context Selection for Vision Tasks

Jingyu Guo, Emir Konuk, Fredrik Strand, Christos Matsoukas, Kevin Smith

TL;DR

Task-Aligned Context Selection (TACS) enables discriminative vision models to learn which contextual examples truly improve downstream performance by integrating retrieval into the learning objective. It couples a selector with a Downstream Task Network and trains them via a hybrid optimization that combines differentiable relaxation with reinforcement learning. Across 18 diverse datasets, including fine-grained natural images and medical imaging, TACS yields consistent gains over similarity-based retrieval, with notable benefits in data-limited and ambiguous settings, and its selections reveal interpretable patterns of contextual reasoning. This work reframes retrieval as a task-aware, adaptive component of vision systems, with potential implications for data curation, active learning, and domain-specific visual reasoning.
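The Figure 2 caption identifies the differentiable relaxation as a straight-through Gumbel-Softmax over the Selector's utility scores. The stdlib-only sketch below illustrates one such draw. The score values, the temperature `tau`, and the function names are illustrative assumptions, not the authors' code. Plain Python has no autograd, so the straight-through trick itself appears only in a comment, and `soft_jacobian` writes out the gradient that the relaxed sample would pass back to the scores.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax(scores, tau, rng):
    """One Gumbel-Softmax draw over utility scores; returns (y_soft, y_hard).

    y_soft = softmax((s + g) / tau) with g ~ Gumbel(0, 1) is the relaxed sample,
    and y_hard is its one-hot argmax. With autograd, the straight-through
    estimator would use y_st = y_hard + (y_soft - stop_gradient(y_soft)):
    y_hard in the forward pass, the gradient of y_soft in the backward pass.
    """
    gumbels = []
    for _ in scores:
        # Clamp u inside (0, 1) so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbels.append(-math.log(-math.log(u)))
    y_soft = softmax([(s + g) / tau for s, g in zip(scores, gumbels)])
    k = max(range(len(y_soft)), key=y_soft.__getitem__)
    y_hard = [1.0 if i == k else 0.0 for i in range(len(y_soft))]
    return y_soft, y_hard


def soft_jacobian(y_soft, tau):
    """J[i][j] = d y_soft[i] / d s[j] = y_i * (delta_ij - y_j) / tau."""
    n = len(y_soft)
    return [[y_soft[i] * ((1.0 if i == j else 0.0) - y_soft[j]) / tau
             for j in range(n)]
            for i in range(n)]


if __name__ == "__main__":
    rng = random.Random(0)
    scores = [0.5, 2.0, -1.0]  # hypothetical utility scores for 3 candidates
    y_soft, y_hard = gumbel_softmax(scores, tau=0.5, rng=rng)
    # A task network would consume sum_i y_hard[i] * x_c^i, i.e. one candidate.
    print("soft:", [round(y, 3) for y in y_soft])
    print("hard:", y_hard)
```

Lowering `tau` pushes `y_soft` toward the one-hot `y_hard`, which narrows the mismatch between the forward and backward passes but makes the relaxed gradients sharper and noisier.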

Abstract

Humans often resolve visual uncertainty by comparing an image with relevant examples, but vision transformers (ViTs) lack a way to identify which examples would improve their predictions. We present Task-Aligned Context Selection (TACS), a framework that learns to select paired examples that actually improve task performance rather than those that merely look similar. TACS jointly trains a selector network with the task model through a hybrid optimization scheme that combines gradient-based supervision with reinforcement learning, making retrieval part of the learning objective. By aligning selection with task rewards, TACS lets discriminative models discover which contextual examples genuinely help. Across 18 datasets covering fine-grained recognition, medical image classification, and medical image segmentation, TACS consistently outperforms similarity-based retrieval, particularly in challenging or data-limited settings.
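The reinforcement-learning half of this hybrid scheme can be illustrated with a toy REINFORCE (score-function) update. The sketch rests on several simplifying assumptions. The policy is a softmax over a fixed list of candidate scores with no query conditioning. Each candidate's reward is a hypothetical constant standing in for downstream task performance. The baseline is an exponential moving average. None of these choices is taken from the paper. The sketch shows how a reward signal can steer selection away from a candidate that is merely preferred at initialization and toward the one that actually helps.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_step(scores, reward_fn, baseline, lr, rng):
    """One REINFORCE update of the selector's utility scores.

    For pi = softmax(scores), d log pi(a) / d s_j = 1[j == a] - pi_j. The update
    scales this by the advantage (reward - baseline) and takes an ascent step.
    """
    probs = softmax(scores)
    action = rng.choices(range(len(scores)), weights=probs, k=1)[0]
    reward = reward_fn(action)
    advantage = reward - baseline
    new_scores = [s + lr * advantage * ((1.0 if j == action else 0.0) - p)
                  for j, (s, p) in enumerate(zip(scores, probs))]
    # Exponential moving average of observed rewards as a variance-reducing baseline.
    new_baseline = 0.9 * baseline + 0.1 * reward
    return new_scores, new_baseline


# Hypothetical per-candidate rewards standing in for downstream task performance.
# Candidate 0 is favoured at initialization (think "most similar"), but
# candidate 1 is the pairing that actually helps the task.
HELPFULNESS = [0.2, 1.0, 0.0]


def train(steps=500, lr=0.5, seed=0):
    rng = random.Random(seed)
    scores = [1.0, 0.0, 0.0]
    baseline = 0.0
    for _ in range(steps):
        scores, baseline = reinforce_step(
            scores, HELPFULNESS.__getitem__, baseline, lr, rng)
    return scores


if __name__ == "__main__":
    final = train()
    print([round(p, 3) for p in softmax(final)])
```

Because the baseline is subtracted from the reward, picking an unhelpful candidate produces a negative advantage that shifts probability mass toward the others, so learning does not depend on sampling the best candidate alone.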

Paper Structure

This paper contains 28 sections, 10 equations, 4 figures, and 5 tables.

Figures (4)

  • Figure 1: Overview of Task-Aligned Context Selection (TACS). Given a query image, TACS learns to select the most informative example from a candidate pool to form a task-aligned input pair. A selector network predicts which candidate provides the most useful context for the downstream task network (e.g., a classifier or segmentor), which operates on the selected pair. By learning to retrieve helpful rather than merely similar examples, TACS improves decision accuracy and robustness, particularly in ambiguous cases.
  • Figure 2: Architecture and training flow of Task-Aligned Context Selection (TACS). (a) The Selector processes a query image $x_q$ and a candidate pool $\{x_c^i\}$ and selects the most helpful sample. (b) During training, the selection is optimized through two complementary paths. The differentiable sampling path (blue) applies the straight-through Gumbel-Softmax reparameterization to the utility scores $s_c$, making a categorical selection over candidates while still allowing end-to-end gradient flow through the task loss $\mathcal{L}_{\text{grad}}$. The policy path (red) samples a discrete action from $s_c$ and updates the Selector with policy gradients, using a task-aligned reward $r(o,a)$ derived from downstream performance. Gradients and rewards jointly update the shared Selector parameters. (c–d) Examples of downstream tasks: classification and segmentation. The combined loss enables stable yet decisive task-aligned selection.
  • Figure 3: Statistical analysis of selected context pairs. Comparison of cross-class selection rates and mean LPIPS distances for pairs chosen by different retrieval strategies. TACS consistently selects a higher proportion of cross-class and perceptually diverse (higher LPIPS) candidates than a fixed retriever, particularly on complex datasets such as DTD and SUN397. This indicates that the learned policy favors complementary rather than merely similar examples, enabling richer contextual reasoning.
  • Figure 4: For each example, we show (1) the query image, (2) the TACS classifier’s attention on the query image (which also receives the selected context image), (3) the no-context classifier’s attention on the same image, (4) the context image selected by the Selector, and (5) the Selector’s attention on the selected image. TACS consistently focuses the classifier on more discriminative regions than the no-context model. The Selector attends to broad global structures, often choosing complementary rather than visually similar images (e.g., retrieving a piano to help disambiguate headphones), enabling the classifier to refine fine-grained decisions through contrasting contextual cues.
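The two statistics in Figure 3 can be made concrete with a short sketch. It assumes that the cross-class selection rate is the fraction of query/context pairs whose class labels differ. It also assumes that per-pair LPIPS distances have already been computed with a pretrained LPIPS model, since evaluating LPIPS itself requires a neural network. The function names are hypothetical, and the labels in the usage block are toy values, not results from the paper.

```python
def cross_class_rate(query_labels, context_labels):
    """Fraction of (query, selected context) pairs whose class labels differ."""
    if len(query_labels) != len(context_labels) or not query_labels:
        raise ValueError("expected two non-empty label lists of equal length")
    mismatches = sum(q != c for q, c in zip(query_labels, context_labels))
    return mismatches / len(query_labels)


def mean_pair_distance(pair_distances):
    """Mean of precomputed per-pair perceptual distances (e.g. LPIPS scores)."""
    if not pair_distances:
        raise ValueError("expected at least one distance")
    return sum(pair_distances) / len(pair_distances)


if __name__ == "__main__":
    # Hypothetical toy labels, not results from the paper.
    queries = ["cat", "dog", "cat", "bird"]
    fixed_retriever = ["cat", "dog", "cat", "cat"]
    learned_selector = ["dog", "dog", "bird", "cat"]
    print(cross_class_rate(queries, fixed_retriever))   # 0.25
    print(cross_class_rate(queries, learned_selector))  # 0.75
```

Under this reading, a higher rate means the selector pairs queries with other classes more often, which matches the caption's interpretation that the learned policy favors complementary rather than merely similar examples.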