Provable Multi-Task Representation Learning by Two-Layer ReLU Neural Networks
Liam Collins, Hamed Hassani, Mahdi Soltanolkotabi, Aryan Mokhtari, Sanjay Shakkottai
TL;DR
This work proves that nonlinear two-layer ReLU networks pretrained with gradient-based multi-task learning can recover a low-dimensional ground-truth feature subspace shared across many binary classification tasks. The key mechanism is that updating the task-specific heads before updating the representation induces a pseudo-contrastive loss, which encourages points that share labels across tasks to have aligned representations. The authors establish that, with $T$ tasks and $n_1$ and $n_2$ samples per task, the first-layer weights concentrate their energy in the $r$-dimensional feature subspace spanned by $\mathbf{M}$. This yields downstream generalization guarantees whose sample and neuron complexity is independent of the ambient dimension $d$, and the authors show that single-task pretraining or random features do not offer the same benefit. The results give a theoretical explanation for the advantages of multi-task pretraining, highlight the importance of adapting the per-task heads, and provide a foundation for transfer guarantees in nonlinear networks that are robust to the input dimension.
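To see how a head update can turn into a pseudo-contrastive objective, here is a one-step calculation under simplifying assumptions of our own, not the paper's exact setup: the network is $\mathbf{a}_t^\top\sigma(\mathbf{W}\mathbf{x})$ with $\sigma(z)=\max(z,0)$ applied entrywise, the loss is the correlation loss $-y\,\mathbf{a}_t^\top\sigma(\mathbf{W}\mathbf{x})$, and the head $\mathbf{a}_t$ of task $t$ takes a single gradient step of size $\eta$ from $\mathbf{0}$ on samples $(\mathbf{x}_{t,i}, y_{t,i})_{i=1}^{n_1}$. The symbols $\mathbf{a}_t$, $\mathbf{W}$, $\sigma$, and $\eta$ are illustrative notation.

```latex
\begin{aligned}
\mathbf{a}_t
  &= -\eta\,\nabla_{\mathbf{a}}\Big[-\frac{1}{n_1}\sum_{i=1}^{n_1} y_{t,i}\,\mathbf{a}^\top\sigma(\mathbf{W}\mathbf{x}_{t,i})\Big]_{\mathbf{a}=\mathbf{0}}
   = \frac{\eta}{n_1}\sum_{i=1}^{n_1} y_{t,i}\,\sigma(\mathbf{W}\mathbf{x}_{t,i}),\\
\ell_t(\mathbf{W};\mathbf{x},y)
  &= -y\,\mathbf{a}_t^\top\sigma(\mathbf{W}\mathbf{x})
   = -\frac{\eta}{n_1}\sum_{i=1}^{n_1}\big(y\,y_{t,i}\big)\,
     \big\langle \sigma(\mathbf{W}\mathbf{x}_{t,i}),\,\sigma(\mathbf{W}\mathbf{x})\big\rangle .
\end{aligned}
```

Since $y\,y_{t,i}=+1$ exactly when the two points share a label in task $t$, minimizing the task average of $\ell_t$ over $\mathbf{W}$ increases the inner product between representations of points that usually agree across tasks and decreases it for points that usually disagree, which matches the qualitative description of the pseudo-contrastive loss above.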
Abstract
An increasingly popular machine learning paradigm is to pretrain a neural network (NN) on many tasks offline, then adapt it to downstream tasks, often by retraining only the last linear layer of the network. This approach yields strong downstream performance in a variety of contexts, demonstrating that multi-task pretraining leads to effective feature learning. Although several recent theoretical studies have shown that shallow NNs learn meaningful features when either (i) they are trained on a single task or (ii) they are linear, very little is known about the closer-to-practice case of nonlinear NNs trained on multiple tasks. In this work, we present the first results proving that feature learning occurs during training with a nonlinear model on multiple tasks. Our key insight is that multi-task pretraining induces a pseudo-contrastive loss that favors representations aligning points that typically have the same label across tasks. Using this observation, we show that when the tasks are binary classification tasks whose labels depend on the projection of the data onto an $r$-dimensional subspace of the $d$-dimensional input space, with $d \gg r$, a simple gradient-based multi-task learning algorithm on a two-layer ReLU NN recovers this projection, allowing for generalization to downstream tasks with sample and neuron complexity independent of $d$. In contrast, we show that, with high probability over the draw of a single task, training on this single task is not guaranteed to learn all $r$ ground-truth features.
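The setting in the abstract can be simulated in a few lines. The NumPy sketch below is an illustration under our own assumptions, not the authors' algorithm or code. Labels are $y_t=\mathrm{sign}(\mathbf{v}_t^\top\mathbf{M}^\top\mathbf{x})$ with Gaussian inputs, and the loss is the correlation loss $-y\,\mathbf{a}_t^\top\sigma(\mathbf{W}\mathbf{x})$. Each round re-fits every head with one gradient step from zero on $n_1$ fresh samples, then takes a gradient step on $\mathbf{W}$ using $n_2$ fresh samples. The function names (`make_tasks`, `sample`, `subspace_energy`, `multitask_train`) and all hyperparameters are hypothetical. The sketch tracks the fraction of $\|\mathbf{W}\|_F^2$ lying in the span of $\mathbf{M}$.

```python
import numpy as np


def make_tasks(num_tasks, d, r, rng):
    """Hypothetical task model: y_t(x) = sign(v_t^T M^T x) with x ~ N(0, I_d).

    M (d x r) has orthonormal columns, so every label depends on x only
    through the shared projection M^T x.
    """
    M, _ = np.linalg.qr(rng.standard_normal((d, r)))
    V = rng.standard_normal((num_tasks, r))  # one direction per task
    return M, V


def sample(M, v, n, rng):
    """Draw n fresh (x, y) pairs for the task with direction v."""
    X = rng.standard_normal((n, M.shape[0]))
    y = np.sign(X @ M @ v)
    y[y == 0] = 1.0
    return X, y


def subspace_energy(W, M):
    """Fraction of ||W||_F^2 inside span(M); rows of W are the neurons.

    Because M has orthonormal columns, ||W M M^T||_F = ||W M||_F.
    """
    return np.linalg.norm(W @ M) ** 2 / np.linalg.norm(W) ** 2


def multitask_train(M, V, m=20, steps=25, n1=200, n2=200,
                    eta_a=1.0, eta_w=4.0, seed=0):
    rng = np.random.default_rng(seed)
    d = M.shape[0]
    W = rng.standard_normal((m, d)) / np.sqrt(d)  # first-layer weights
    history = [subspace_energy(W, M)]
    for _ in range(steps):
        grad = np.zeros_like(W)
        for v in V:
            # (1) Head step: one gradient step on the correlation loss
            #     -mean(y * a^T relu(W x)) starting from a = 0.
            X1, y1 = sample(M, v, n1, rng)
            a = eta_a * (y1 @ np.maximum(X1 @ W.T, 0.0)) / n1  # shape (m,)
            # (2) Gradient of the same loss w.r.t. W with this head fixed.
            X2, y2 = sample(M, v, n2, rng)
            active = (X2 @ W.T > 0).astype(float)  # ReLU derivative, (n2, m)
            grad -= a[:, None] * ((y2[:, None] * active).T @ X2) / n2
        W -= eta_w * grad / len(V)  # gradient descent, averaged over tasks
        history.append(subspace_energy(W, M))
    return W, history


rng = np.random.default_rng(1)
M, V = make_tasks(num_tasks=100, d=30, r=2, rng=rng)
W, history = multitask_train(M, V)
print(f"energy in span(M): {history[0]:.3f} -> {history[-1]:.3f}")
```

Under these assumptions the printed energy starts near $r/d$ (about 0.07 here) and should climb toward 1. For this symmetric label model, the head-then-representation update in expectation multiplies the component of each neuron inside the span of $\mathbf{M}$ and leaves the orthogonal component unchanged, so only finite-sample noise leaks outside the subspace.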
