InterroGate: Learning to Share, Specialize, and Prune Representations for Multi-task Learning
Babak Ehteshami Bejnordi, Gaurav Kumar, Amelie Royer, Christos Louizos, Tijmen Blankevoort, Mohsen Ghafoorian
TL;DR
InterroGate addresses task interference in multi-task learning by learning per-layer gates that decide, for each task, whether to use shared or task-specific representations, together with a sparsity-based regularizer that keeps the model within a compute budget. Gates are trained end-to-end and then fixed at inference. This allows unused parameters to be pruned, so all tasks can be predicted in a single forward pass with a static, efficient architecture. The approach achieves state-of-the-art or competitive results on CelebA, NYUD-v2, and PASCAL-Context across CNN and transformer backbones, while offering a controllable trade-off between accuracy and FLOPs. This makes InterroGate practically valuable for edge devices and real-time systems requiring efficient, scalable multi-task inference.
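The compute-budget idea can be illustrated with a small NumPy sketch. It assumes per-channel gates whose sigmoid is read as the probability of using a task-specific channel, and a hinge-style penalty on the expected fraction of task-specific FLOPs. The function name `budget_penalty`, the gate convention, and the hinge form are illustrative assumptions, not the paper's exact regularizer.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def budget_penalty(gate_logits, channel_flops, budget):
    """Hinge-style compute-budget penalty (hypothetical formulation).

    gate_logits: dict mapping task name -> list of per-layer arrays, one logit per channel.
                 sigmoid(logit) is treated as the probability of using the
                 task-specific channel instead of the shared one (assumption).
    channel_flops: per-layer FLOPs cost of a single task-specific channel.
    budget: allowed fraction (0..1) of the maximum task-specific FLOPs.
    """
    used, total = 0.0, 0.0
    for layers in gate_logits.values():
        for logits, flops in zip(layers, channel_flops):
            probs = sigmoid(np.asarray(logits, dtype=float))
            used += flops * probs.sum()   # expected task-specific FLOPs
            total += flops * probs.size   # FLOPs if every channel were task-specific
    ratio = used / total
    # Zero once the expected usage is within budget, linear above it.
    return max(0.0, ratio - budget)


# Two hypothetical tasks, two layers, all gates undecided (probability 0.5).
logits = {
    "seg": [np.zeros(4), np.zeros(8)],
    "depth": [np.zeros(4), np.zeros(8)],
}
print(budget_penalty(logits, channel_flops=[10.0, 5.0], budget=0.3))
```

In such a setup the penalty would be added, with some weight, to the sum of task losses; because it vanishes once the budget is met, it stops pushing gates toward the shared branch as soon as the constraint is satisfied.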
Abstract
Jointly learning multiple tasks with a unified model can improve accuracy and data efficiency, but it faces the challenge of task interference, where optimizing one task objective may inadvertently compromise the performance of another. One way to mitigate this issue is to allocate task-specific parameters, free from interference, on top of shared features. However, manually designing such architectures is cumbersome, as practitioners must balance overall performance across all tasks against the higher computational cost induced by the newly added parameters. In this work, we propose InterroGate, a novel multi-task learning (MTL) architecture designed to mitigate task interference while optimizing computational efficiency at inference. We employ a learnable gating mechanism to automatically balance the shared and task-specific representations while preserving the performance of all tasks. Crucially, the patterns of parameter sharing and specialization learned dynamically during training become fixed at inference, resulting in a static, optimized MTL architecture. Through extensive empirical evaluations, we demonstrate state-of-the-art (SoTA) results on three MTL benchmarks, CelebA, NYUD-v2, and PASCAL-Context, using both convolutional and transformer-based backbones.
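The shift from learned, soft sharing during training to a fixed, pruned architecture at inference can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions: gates are per-channel values in [0, 1] where 1 selects the task-specific channel, training mixes the two branches linearly, and inference thresholds the gates at 0.5. All function names are hypothetical, and the paper's actual gate parametrization and training estimator may differ.

```python
import numpy as np


def mix_features(shared, specific, gate_probs):
    """Training-time soft mixing for one task at one layer (hypothetical sketch).

    shared, specific: arrays of shape (channels, ...) from the shared and the
                      task-specific branch.
    gate_probs: per-channel values in [0, 1]; 1 selects the task-specific channel.
    """
    g = np.asarray(gate_probs, dtype=float)
    g = g.reshape(-1, *([1] * (shared.ndim - 1)))  # broadcast over spatial dims
    return g * specific + (1.0 - g) * shared


def freeze_gates(gate_probs, threshold=0.5):
    """Binarize gates once for inference; True keeps the task-specific channel."""
    return np.asarray(gate_probs) > threshold


def prune_specific_weights(specific_weight, keep_mask):
    """Drop output channels (rows) of a task-specific layer whose gate is off."""
    return specific_weight[keep_mask]


def infer_features(shared, specific_kept, keep_mask):
    """Inference with fixed gates: only kept channels come from the pruned branch."""
    out = shared.copy()
    out[keep_mask] = specific_kept
    return out


shared = np.zeros((4, 2))
specific = np.ones((4, 2))
probs = np.array([0.9, 0.1, 0.7, 0.2])
mask = freeze_gates(probs)                        # [True, False, True, False]
w = prune_specific_weights(np.ones((4, 3)), mask)  # shape (2, 3)
print(infer_features(shared, np.ones((2, 2)), mask))
```

Because the mask is fixed after training, the pruned task-specific layer only has to compute the kept channels, which is where the inference savings in such a design would come from.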
