
InterroGate: Learning to Share, Specialize, and Prune Representations for Multi-task Learning

Babak Ehteshami Bejnordi, Gaurav Kumar, Amelie Royer, Christos Louizos, Tijmen Blankevoort, Mohsen Ghafoorian

TL;DR

InterroGate addresses task interference in multi-task learning by learning per-layer gates that decide between shared and task-specific representations for all tasks, with a sparsity-based regularizer to meet a compute budget. Gates are trained end-to-end and then fixed at inference, enabling pruning of unused parameters so all tasks can be predicted in a single forward pass with a static, efficient architecture. The approach achieves state-of-the-art or competitive results on CelebA, NYUD-v2, and PASCAL-Context across CNN and transformer backbones, while offering a controllable trade-off between accuracy and FLOPs. This makes InterroGate practically valuable for edge devices and real-time systems requiring efficient, scalable multi-task inference.
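The budget-driven sparsity regularizer can be illustrated with a small sketch. It assumes that each task's gates output values in [0, 1], where 1 means "use the task-specific channel", and that a per-task budget τ_t (the symbol used in Figure 5) caps the fraction of task-specific channels through a hinge penalty. The gate convention, the weight `lam`, and the function names are hypothetical and not taken from the paper. An L1 penalty, the alternative compared in Figure 3, is included for contrast.

```python
import numpy as np


def hinge_sparsity_loss(gate_probs, tau):
    """Hypothetical budget regularizer for one task.

    gate_probs: (num_layers, num_channels) gate values in [0, 1];
                1 = task-specific channel, 0 = shared channel (assumption).
    tau:        target fraction of task-specific channels (budget tau_t).
    Only usage above the budget is penalized.
    """
    usage = gate_probs.mean()
    return float(max(0.0, usage - tau))


def l1_sparsity_loss(gate_probs):
    """Plain L1 penalty: pushes every gate toward 0 regardless of any budget."""
    return float(np.abs(gate_probs).mean())


def total_loss(task_losses, gate_probs_per_task, taus, lam=1.0):
    """Sum of task losses plus lam times the per-task hinge penalties."""
    reg = sum(hinge_sparsity_loss(g, tau) for g, tau in zip(gate_probs_per_task, taus))
    return sum(task_losses) + lam * reg
```

Under this assumed form, the hinge term stops penalizing a task once its gate usage falls below τ_t, whereas the L1 term keeps pushing all gates toward the shared branch. This is one plausible way a budget can trade accuracy for FLOPs in a controllable manner.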

Abstract

Jointly learning multiple tasks with a unified model can improve accuracy and data efficiency, but it faces the challenge of task interference, where optimizing one task objective may inadvertently compromise the performance of another. A solution to mitigate this issue is to allocate task-specific parameters, free from interference, on top of shared features. However, manually designing such architectures is cumbersome, as practitioners need to balance the overall performance across all tasks against the higher computational cost induced by the newly added parameters. In this work, we propose InterroGate, a novel multi-task learning (MTL) architecture designed to mitigate task interference while optimizing computational efficiency at inference. We employ a learnable gating mechanism to automatically balance the shared and task-specific representations while preserving the performance of all tasks. Crucially, the patterns of parameter sharing and specialization that are dynamically learned during training become fixed at inference, resulting in a static, optimized MTL architecture. Through extensive empirical evaluations, we demonstrate state-of-the-art (SoTA) results on three MTL benchmarks, CelebA, NYUD-v2, and PASCAL-Context, using both convolutional and transformer-based backbones.
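The gating mechanism, which is learned during training and fixed at inference, can be sketched for a single layer following the Figure 1 description. Each task t mixes the shared features ψ and its own features φ_t channel by channel using its gate, and the next shared input is a per-channel linear combination of the task-specific features weighted by β. Several details are assumptions, not the authors' implementation: the plain sigmoid relaxation with a zero threshold at inference (the paper's actual relaxation and gradient estimator may differ), the softmax normalization of β over tasks, and all function names.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def interrogate_layer(shared, task_feats, gate_logits, beta_logits, training=True):
    """Sketch of one gated layer following the Figure 1 description.

    shared:      (C, H, W)    shared representation psi
    task_feats:  (T, C, H, W) task-specific representations phi_t
    gate_logits: (T, C)       per-task, per-channel gate parameters (hypothetical)
    beta_logits: (T, C)       per-channel weights for building the next shared
                              input (assumed softmax-normalized over tasks)
    """
    if training:
        # Soft gates in [0, 1] so the gate parameters receive gradients.
        gates = sigmoid(gate_logits)
    else:
        # Fixed binary gates at inference: 1 = task-specific, 0 = shared.
        gates = (gate_logits > 0).astype(shared.dtype)
    g = gates[:, :, None, None]  # broadcast over spatial dims
    # Channel-mixed features phi'_t, one per task.
    mixed = g * task_feats + (1.0 - g) * shared[None]
    # Normalize beta over the task axis (numerically stable softmax).
    beta = np.exp(beta_logits - beta_logits.max(axis=0, keepdims=True))
    beta = beta / beta.sum(axis=0, keepdims=True)
    # Next shared input: per-channel linear combination of task-specific features.
    next_shared = (beta[:, :, None, None] * task_feats).sum(axis=0)
    return mixed, next_shared


def unused_shared_channels(gate_logits):
    """Shared channels that no task selects once gates are binarized."""
    selects_task_specific = gate_logits > 0
    return np.flatnonzero(selects_task_specific.all(axis=0))
```

For example, with two tasks, three channels, and `gate_logits = [[5, -5, 5], [5, -5, -5]]`, both tasks take channel 0 from their own branch, so under this sketch shared channel 0 is never read and could be pruned from the static inference model. The sketch deliberately counts only shared channels. A task-specific channel that its own task does not select may still feed the shared branch through β, so deciding whether it can be removed depends on details not covered here.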

Paper Structure

This paper contains 32 sections, 7 equations, 6 figures, 11 tables, and 1 algorithm.

Figures (6)

  • Figure 1: Overview of the proposed InterroGate framework. The original encoder layers are replaced with InterroGate layers. The input to each layer consists of $t+1$ feature maps: one shared representation and $t$ task-specific representations. To choose between shared $\psi^\ell$ and task-specific $\varphi_t^\ell$ features, each task relies on its own gating module $G_t^\ell$. The resulting channel-mixed feature map $\varphi_t^{\prime \ell}$ is then fed to the next task-specific layer. The input to the shared branch of the next layer is constructed by linearly combining the task-specific features of all tasks using the learned parameter $\beta_t^{\ell}$. During inference, the parameters (shared or task-specific) that are not chosen by the gates are removed from the model, resulting in a plain neural network architecture.
  • Figure 2: Accuracy vs. floating-point operations (FLOPs) trade-off curves for InterroGate and SoTA MTL methods. (a) CelebA with a ResNet-20 backbone at three different widths (Original, Half, and Quarter). (b) NYUD-v2 with an HRNet-18 backbone. (c) PASCAL-Context with a ResNet-18 backbone. To avoid clutter, only the six highest-performing MTL baselines are shown in (b) and (c). The single-task baseline in (b) requires 65.1 GFLOPs.
  • Figure 3: The task-specific representation selection ratio (top) versus proportions of maximum contributions to the shared branch (bottom) for InterroGate with hinge loss (left), $L_1$ loss with medium pruning (middle) and $L_1$ loss with high pruning (right).
  • Figure 4: Sharing and specialization patterns on the PASCAL-Context dataset with a ResNet-18 backbone.
  • Figure 5: Sweeping over different $\{\tau_t\}$ on the NYUD-v2 experiments with HRNet-18 backbone. We plot the MTL performance $\Delta_{MTL}$ against the total number of FLOPs, then color each scatter point by the value of $\tau_t$ when the task $t$ is (a) segmentation, (b) depth and (c) normals.
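Figure 5 plots the MTL performance Δ_MTL against total FLOPs. Assuming the definition commonly used in prior MTL work, Δ_MTL is the average per-task relative change of the multi-task model over single-task baselines, with the sign flipped for metrics where lower is better. The sketch below computes it under that assumption. Whether the paper uses exactly this form is not stated in this section, and the numbers in the example are illustrative, not results from the paper.

```python
def delta_mtl(mtl_scores, single_task_scores, lower_is_better):
    """Average relative change (in %) of an MTL model w.r.t. single-task baselines.

    A positive value means the MTL model is better on average. For metrics where
    lower is better (e.g. depth RMSE, normal angular error) the sign is flipped.
    """
    terms = []
    for m, b, lower in zip(mtl_scores, single_task_scores, lower_is_better):
        sign = -1.0 if lower else 1.0
        terms.append(sign * (m - b) / b)
    return 100.0 * sum(terms) / len(terms)


# Illustrative numbers only (not results from the paper):
# segmentation mIoU (higher is better), depth RMSE and normals mean error (lower is better).
print(round(delta_mtl([40.0, 0.60, 20.0], [38.0, 0.62, 20.5], [False, True, True]), 2))  # prints 3.64
```

Normalizing each task by its own baseline keeps tasks with differently scaled metrics comparable before averaging.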