Provable Multi-Task Representation Learning by Two-Layer ReLU Neural Networks

Liam Collins, Hamed Hassani, Mahdi Soltanolkotabi, Aryan Mokhtari, Sanjay Shakkottai

TL;DR

This work proves that nonlinear two-layer ReLU networks pretrained with gradient-based multitask learning can recover a low-dimensional ground-truth feature subspace shared across many binary classification tasks. The key mechanism is that updating the task-specific heads before the representation induces a pseudo-contrastive loss, which encourages aligned representations for points that tend to share labels across tasks. With $T$ tasks and $n_1$ and $n_2$ samples per task for the head and representation updates, respectively, the authors show that the first-layer weights concentrate their energy in the $r$-dimensional feature space spanned by the rows of $\mathbf{M}$. This yields downstream generalization guarantees whose sample and neuron complexity do not depend on the ambient dimension $d$. They also show that single-task pretraining and random features do not provide the same benefit. The results support multi-task pretraining with per-task head adaptation as a route to dimension-independent transfer in nonlinear networks.
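
The head-then-representation mechanism can be made concrete with a short derivation. It rests on a simplifying assumption of ours rather than the paper's exact update: each head $\mathbf{a}_i$ starts at $\mathbf{0}$ and takes one gradient step of size $\eta$ on the correlation loss $-f_i(\mathbf{x})\,\mathbf{a}_i^\top\sigma(\mathbf{Wx})$ over $n_1$ samples $\mathbf{x}_{i,k}$, where $\sigma$ is the entrywise ReLU and $f_i$ is the labeling function of task $i$. The representation is then scored on the same loss with the head held fixed, and $\mathbf{x},\mathbf{x}'$ denote independent draws from the input distribution.

```latex
\begin{aligned}
\mathbf{a}_i &= \frac{\eta}{n_1}\sum_{k=1}^{n_1} f_i(\mathbf{x}_{i,k})\,\sigma(\mathbf{W}\mathbf{x}_{i,k}),\\
\mathcal{L}_i(\mathbf{W}) &= -\mathbb{E}_{\mathbf{x}}\big[f_i(\mathbf{x})\,\mathbf{a}_i^\top\sigma(\mathbf{W}\mathbf{x})\big]
  = -\frac{\eta}{n_1}\sum_{k=1}^{n_1}\mathbb{E}_{\mathbf{x}}\big[f_i(\mathbf{x})\,f_i(\mathbf{x}_{i,k})\,\sigma(\mathbf{W}\mathbf{x}_{i,k})^\top\sigma(\mathbf{W}\mathbf{x})\big]\\
  &\approx -\eta\,\mathbb{E}_{\mathbf{x},\mathbf{x}'}\big[f_i(\mathbf{x})\,f_i(\mathbf{x}')\,\sigma(\mathbf{W}\mathbf{x})^\top\sigma(\mathbf{W}\mathbf{x}')\big]
  \qquad (n_1 \text{ large}),\\
\frac{1}{T}\sum_{i=1}^{T}\mathcal{L}_i(\mathbf{W}) &\approx -\eta\,\mathbb{E}_{\mathbf{x},\mathbf{x}'}\big[\beta(\mathbf{x},\mathbf{x}')\,\sigma(\mathbf{W}\mathbf{x})^\top\sigma(\mathbf{W}\mathbf{x}')\big],
  \qquad \beta(\mathbf{x},\mathbf{x}') = \mathbb{E}_i\big[f_i(\mathbf{x})\,f_i(\mathbf{x}')\big] \quad (T \text{ large}).
\end{aligned}
```

Up to the factor $\eta$, the last line is the approximate loss stated in the Figure 4 caption: pairs with positive $\beta(\mathbf{x},\mathbf{x}')$ are rewarded for aligned representations, and pairs with negative $\beta$ are penalized for them.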

Abstract

An increasingly popular machine learning paradigm is to pretrain a neural network (NN) on many tasks offline, then adapt it to downstream tasks, often by re-training only the last linear layer of the network. This approach yields strong downstream performance in a variety of contexts, demonstrating that multitask pretraining leads to effective feature learning. Although several recent theoretical studies have shown that shallow NNs learn meaningful features when either (i) they are trained on a single task or (ii) they are linear, very little is known about the closer-to-practice case of nonlinear NNs trained on multiple tasks. In this work, we present the first results proving that feature learning occurs during training with a nonlinear model on multiple tasks. Our key insight is that multi-task pretraining induces a pseudo-contrastive loss that favors representations that align points that typically have the same label across tasks. Using this observation, we show that when the tasks are binary classification tasks with labels depending on the projection of the data onto an $r$-dimensional subspace within the $d\gg r$-dimensional input space, a simple gradient-based multitask learning algorithm on a two-layer ReLU NN recovers this projection, allowing for generalization to downstream tasks with sample and neuron complexity independent of $d$. In contrast, we show that with high probability over the draw of a single task, training on this single task is not guaranteed to learn all $r$ ground-truth features.
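
To make the setting concrete, here is a minimal NumPy sketch of the data model and the network. It assumes Gaussian inputs and sparse parity tasks on the hidden bits $\mathop{\mathrm{sign}}(\mathbf{Mx})$, as in the Figure 1 simulations; the helper names (`row_orthogonal`, `sparse_parity_task`, `task_labels`, `two_layer_relu`) and the uniform-over-nonempty-subsets task sampler are our own assumptions, not the authors' code.

```python
import numpy as np

def row_orthogonal(r, d, rng):
    """Random r x d matrix M with orthonormal rows (M M^T = I_r)."""
    q, _ = np.linalg.qr(rng.standard_normal((d, r)))  # q: d x r, orthonormal columns
    return q.T

def sparse_parity_task(r, rng):
    """Hypothetical task sampler: a uniformly random nonempty subset of the r hidden bits."""
    while True:
        mask = rng.integers(0, 2, size=r).astype(bool)
        if mask.any():
            return mask

def task_labels(X, M, mask):
    """Binary labels that depend on x only through its projection M x."""
    bits = np.sign(X @ M.T)                  # (n, r) hidden binary coordinates sign(Mx)
    return np.prod(bits[:, mask], axis=1)    # parity of the selected bits, in {-1, +1}

def two_layer_relu(X, W, a):
    """a^T relu(W x): W (m x d) is the shared representation, a (m,) is one task's head."""
    return np.maximum(X @ W.T, 0.0) @ a

rng = np.random.default_rng(0)
d, r, m, n = 32, 3, 16, 8
M = row_orthogonal(r, d, rng)
X = rng.standard_normal((n, d))              # assumption: Gaussian inputs
y = task_labels(X, M, sparse_parity_task(r, rng))
scores = two_layer_relu(X, rng.standard_normal((m, d)) / np.sqrt(d), rng.standard_normal(m))
```

Because each label depends on $\mathbf{x}$ only through $\mathbf{Mx}$, adding any component orthogonal to the rows of $\mathbf{M}$ leaves every task's label unchanged; these are the $d-r$ directions a good representation should ignore.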

Paper Structure (20 sections, 33 theorems, 165 equations, 5 figures, 4 tables)

Key Result

Proposition 3.1

Consider the gradient-based multi-task algorithm described in the paper's algorithm section, which uses $T$ tasks and $(n_1,n_2)$ samples per task to update the (head, representation), respectively, and suppose the task-diversity assumption holds. Further assume …

Footnote: The $m=O(d)$ condition in Proposition 3.1 and Theorem …
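
The (head, representation) ordering in Proposition 3.1 can be illustrated with a simplified, hypothetical training round. This is not the algorithm from the paper, whose details are not reproduced here. The sketch assumes the correlation loss $-y\,\mathbf{a}^\top\sigma(\mathbf{Wx})$ for both steps, zero-initialized heads that take a single gradient step on $n_1$ samples, and one gradient step on $\mathbf{W}$ over $n_2$ fresh samples per task with the heads held fixed; `multitask_round`, `sample_x`, and `label_fn` are illustrative names.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def multitask_round(W, tasks, sample_x, label_fn, n1, n2, eta_a, eta_w, rng):
    """One hypothetical (head, representation) round over T tasks.

    W: (m, d) shared first-layer weights; tasks: list of task descriptors for label_fn;
    sample_x(n, rng) -> (n, d) inputs; label_fn(X, task) -> (n,) labels in {-1, +1}.
    Both steps use the correlation loss -y * a^T relu(W x) (an assumption).
    """
    grad_W = np.zeros_like(W)
    for task in tasks:
        # Head step on n1 samples: starting from a = 0, one gradient step
        # gives a = eta_a * mean_k y_k relu(W x_k).
        X1 = sample_x(n1, rng)
        y1 = label_fn(X1, task)
        a = eta_a * (y1[:, None] * relu(X1 @ W.T)).mean(axis=0)      # (m,)
        # Representation step on n2 fresh samples, head held fixed:
        # d/dW of mean_k -y_k a^T relu(W x_k) = -mean_k y_k (a * 1{W x_k > 0}) x_k^T.
        X2 = sample_x(n2, rng)
        y2 = label_fn(X2, task)
        active = (X2 @ W.T > 0).astype(float)                          # (n2, m) ReLU derivative
        grad_W -= ((y2[:, None] * active) * a[None, :]).T @ X2 / n2    # (m, d)
    # Average the representation gradient over tasks and take one step.
    return W - eta_w * grad_W / len(tasks)
```

Holding each head fixed during the $\mathbf{W}$ step mirrors the head-first ordering; the head computed from the first $n_1$ samples is exactly what turns the representation objective into a sum over pairs of labeled points.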

Figures (5)

  • Figure 1: Representation learning error vs. training iterations for varying numbers of tasks $T$. Tasks are sampled from the uniform distribution over sparse parity tasks on the $r$ binary coordinates $\mathop{\mathrm{sign}}\nolimits(\mathbf{Mx})$, for a row-orthogonal matrix $\mathbf{M}$ and $d$-dimensional input $\mathbf{x}$. Here $d=32$, $r=3$, and the learning model is a two-layer, $m$-neuron ReLU network with first-layer weights $\mathbf{W}\in\mathbb{R}^{m\times d}$ (the representation). All cases use the same total number of training samples across tasks, so the number of samples per task is inversely proportional to $T$. Still, as $T$ increases, the row space of $\mathbf{W}$ approaches that of $\mathbf{M}$ (smaller representation learning error). See the simulations appendix for details.
  • Figure 2: Representation learning error. (Left) Version of Figure 1 that shows the benefit of pretraining with additional tasks and includes the standard deviation of each statistic around the plotted mean over 10 trials (shaded regions). Here $d=32$, $r=3$, and all cases use the same total number of samples. (Right) Representation learning error vs. number of training iterations when tasks are sampled either from $\mathcal{T}_{\text{s.p.}}$, the uniform distribution over sparse parity tasks ('Uniform Distribution'), or from a skewed distribution over the support of $\mathcal{T}_{\text{s.p.}}$ ('Skewed Distribution'). Here $d=32$, $r=4$, and $T=32$.
  • Figure 3: Downstream task performance. (Left) Downstream performance of multi-task pretrained, single-task pretrained, and random ('No pretrained') representations $\mathbf{W}$ as the dimension $d$ varies. Unlike the single-task pretrained and non-pretrained representations, the representation trained on multiple tasks does not degrade in downstream performance as $d$ grows. For multi-task training, $T=16+d$ and $n_1=n_2=16$; for single-task training, $n_1=n_2=16\times(16+d)$, so both use the same total number of samples; $r=4$ and $m=16$ in all cases. (Right) Downstream performance of the multi-task-trained representation with $T=32$, $d=32$, $r=3$, $n_1=n_2=16$, $m=16$, and $\hat{m}=32$ for downstream linear probing, as the number of downstream training samples $N$ varies.
  • Figure 4: More tasks isolate important features. As discussed in the proof-sketch section, the loss that multi-task training with task-specific heads induces on the representation is approximately $\mathcal{L}(\mathbf{W})\approx -\mathbb{E}_{\mathbf{x},\mathbf{x}'}[\beta(\mathbf{x},\mathbf{x}') \sigma(\mathbf{Wx})^\top \sigma(\mathbf{Wx}') ]$, where $\beta(\mathbf{x},\mathbf{x}')=\mathbb{E}_i[f_i(\mathbf{x})f_i(\mathbf{x}')]$ is the average product of the labels of two points across tasks. This loss is pseudo-contrastive: it encourages the representations of two points to be similar if and only if they share the same label on most tasks ($\beta(\mathbf{x},\mathbf{x}')\approx 1$), which is equivalent to saying that they share important features. We consider the gradient of $\mathcal{L}(\mathbf{W})$ with respect to a single neuron weight $\mathbf{w}_j$; it takes the form $-\mathbf{A}\mathbf{w}_j$, and we plot finite-task and finite-sample approximations of $\mathbf{A}$. We set $d=16$ and take the ground-truth features to be the first $r=4$ coordinates of the data, i.e. $\mathbf{M} = [\mathbf{I}_4, \mathbf{0}_{4\times 12}]$. Roughly speaking, if the finite-task approximation $\frac{1}{T}\sum_{i=1}^T f_i(\mathbf{x})f_i(\mathbf{x}')$, like $\mathbb{E}_i[f_i(\mathbf{x})f_i(\mathbf{x}')]$, serves as a proxy for whether $\mathbf{x}$ and $\mathbf{x}'$ share ground-truth features, then pairs with the same ground-truth features dominate the loss, and those features in turn dominate $\mathbf{A}$. The plots confirm this: as $T$ increases and $\frac{1}{T}\sum_{i=1}^T f_i(\mathbf{x})f_i(\mathbf{x}')$ approaches $\mathbb{E}_i[f_i(\mathbf{x})f_i(\mathbf{x}')]$, $\mathbf{A}$ becomes dominated by its top-left $4\times 4$ submatrix, i.e. $\mathbf{A} \approx c\, \mathbf{M}^\top \mathbf{M}$ for a scalar $c$. Thus $\mathbf{A}$ behaves like a scaled projection onto the row space of $\mathbf{M}$, as desired.
  • Figure 5: (Left) Training with a single head, i.e. $\mathbf{a}_1=\mathbf{a}_2=\dots=\mathbf{a}_T=\mathbf{a}$, fails to recover the ground-truth representation, because a shared head does not induce an appropriate pseudo-contrastive loss. (Center) During multi-task pretraining with task-specific heads, the projections of four neurons onto the $r=2$-dimensional ground-truth subspace fan outward from the origin and remain large and isotropic in that subspace, whereas (Right) their projections onto the spurious subspace contract toward the origin.
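
The Figure 4 construction can be sketched numerically. For one neuron $\mathbf{w}$, using the same $n$ samples for $\mathbf{x}$ and $\mathbf{x}'$, the finite-task, finite-sample loss is $-\frac{1}{T}\sum_i\big(\frac{1}{n}\sum_k f_i(\mathbf{x}_k)\sigma(\mathbf{w}^\top\mathbf{x}_k)\big)^2$. Away from ReLU kinks its gradient is $-\mathbf{A}\mathbf{w}$ with $\mathbf{A}=\frac{2}{T}\sum_i \mathbf{v}_i\mathbf{v}_i^\top$ and $\mathbf{v}_i=\frac{1}{n}\sum_k f_i(\mathbf{x}_k)\,\mathbb{1}\{\mathbf{w}^\top\mathbf{x}_k>0\}\,\mathbf{x}_k$ (our own algebra, not quoted from the paper). The sketch below is not the authors' simulation code; it assumes Gaussian inputs, tasks drawn uniformly from nonempty parity subsets of the $r$ hidden bits, and a randomly drawn neuron, and the function names are hypothetical.

```python
import numpy as np

def sample_subsets(r, T, rng):
    # Assumption: each task is a uniformly random nonempty subset of the r hidden bits.
    subsets = []
    while len(subsets) < T:
        mask = rng.integers(0, 2, size=r).astype(bool)
        if mask.any():
            subsets.append(mask)
    return subsets

def parity_labels(X, M, subsets):
    """f_i(x) = prod_{s in S_i} sign(Mx)_s for every task; returns a (T, n) label matrix."""
    bits = np.sign(X @ M.T)                                   # (n, r)
    return np.stack([np.prod(bits[:, S], axis=1) for S in subsets])

def matrix_A(w, X, F):
    """A with grad_w L = -A w, for L(w) = -(1/T) sum_i (mean_k f_i(x_k) relu(w.x_k))^2."""
    T, n = F.shape
    G = X * (X @ w > 0)[:, None]      # rows x_k, zeroed where the neuron is inactive
    V = F @ G / n                     # (T, d): row i is v_i
    return 2.0 * V.T @ V / T          # (d, d), symmetric positive semidefinite

def block_fraction(A, r):
    # Share of A's Frobenius norm in the top-left r x r block (the row space of M = [I_r, 0]).
    return np.linalg.norm(A[:r, :r]) / np.linalg.norm(A)

rng = np.random.default_rng(0)
d, r, n = 16, 4, 4000
M = np.hstack([np.eye(r), np.zeros((r, d - r))])   # M = [I_4, 0_{4x12}], as in Figure 4
X = rng.standard_normal((n, d))                     # assumption: Gaussian inputs
w = rng.standard_normal(d)                          # assumption: one randomly drawn neuron
for T in (1, 4, 64):
    F = parity_labels(X, M, sample_subsets(r, T, rng))
    print(T, round(float(block_fraction(matrix_A(w, X, F), r)), 3))
```

Writing $\mathbf{A}$ as a sum of outer products avoids forming the $n\times n$ matrix of finite-task $\beta$ values: $\frac{1}{n^2}\sum_{k,l}\beta_T(\mathbf{x}_k,\mathbf{x}_l)\,\mathbf{g}_k\mathbf{g}_l^\top$ with $\beta_T=\frac{1}{T}\sum_i f_i f_i'$ factors exactly into $\frac{1}{T}\sum_i \mathbf{v}_i\mathbf{v}_i^\top$.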

Theorems & Definitions (64)

  • Proposition 3.1
  • Theorem 3.2: Representation Learning
  • Theorem 3.3: End-to-End Guarantee
  • Remark 3.4: Necessity of second layer
  • Remark 3.5: Tightness of exponential complexity in $r$ in positive results
  • Theorem 3.6
  • Theorem 3.7
  • Remark 3.8: Single-task training with highly informative task
  • Lemma A.1
  • proof
  • ...and 54 more