Few-Shot Learning via Learning the Representation, Provably
Simon S. Du, Wei Hu, Sham M. Kakade, Jason D. Lee, Qi Lei
TL;DR
This work gives theoretical guarantees showing that representation learning can exploit all source-task data to enable few-shot learning on a target task. By replacing the i.i.d. task assumption with structural conditions on input covariances and task diversity, the authors derive fast excess-risk rates in several regimes: (i) low-dimensional linear representations, with a rate of $O\left(\frac{dk}{n_1T}+\frac{k}{n_2}\right)$; (ii) general nonlinear low-dimensional representations, with data-dependent complexity bounds; (iii) high-dimensional linear representations with nuclear-norm regularization, with rates tied to $\mathrm{Tr}(\Sigma)$ and $\|\Sigma\|_2$; and (iv) neural networks (including two-layer ReLU networks), with similar pooling advantages. In all cases, the $n_1T$ source samples are effectively pooled to learn a representation that reduces the number of target samples required, improving on the $O\left(\frac{d}{n_2}\right)$ rate obtained without representation learning. The result is a principled framework for meta-learning and few-shot learning with provable guarantees for linear representations, nonlinear representations, and neural networks.
Abstract
This paper studies few-shot learning via representation learning, where one uses $T$ source tasks with $n_1$ data points per task to learn a representation in order to reduce the sample complexity of a target task for which there are only $n_2 (\ll n_1)$ data points. Specifically, we focus on the setting where there exists a good \emph{common representation} between source and target, and our goal is to understand how much of a sample size reduction is possible. First, we study the setting where this common representation is low-dimensional and provide a fast rate of $O\left(\frac{\mathcal{C}\left(\Phi\right)}{n_1T} + \frac{k}{n_2}\right)$; here, $\Phi$ is the representation function class, $\mathcal{C}\left(\Phi\right)$ is its complexity measure, and $k$ is the dimension of the representation. When specialized to linear representation functions, this rate becomes $O\left(\frac{dk}{n_1T} + \frac{k}{n_2}\right)$, where $d (\gg k)$ is the ambient input dimension. This is a substantial improvement over the rate of $O\left(\frac{d}{n_2}\right)$ obtained without representation learning. This result bypasses the $\Omega\left(\frac{1}{T}\right)$ barrier under the i.i.d. task assumption, and captures the desired property that all $n_1T$ samples from source tasks can be \emph{pooled} together for representation learning. Next, we consider the setting where the common representation may be high-dimensional but is capacity-constrained (say, in norm); here, we again demonstrate the advantage of representation learning in both high-dimensional linear regression and neural network learning. Our results demonstrate that representation learning can fully utilize all $n_1T$ samples from source tasks.
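To make the linear low-dimensional setting concrete, the following is a minimal simulation sketch. It is not the estimator analyzed in the paper: as a simple stand-in, it fits each source task by ordinary least squares and takes the top-$k$ left singular vectors of the stacked weight estimates as the learned representation. All sizes, the noise level `sigma`, the isotropic Gaussian inputs, and the names `B_hat`, `risk_rep`, and `risk_direct` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions): ambient dim d, representation dim k,
# T source tasks with n1 samples each, and n2 << n1 target samples.
d, k, T, n1, n2, sigma = 50, 3, 40, 200, 20, 0.1

# Ground-truth shared linear representation: orthonormal d x k matrix B.
B, _ = np.linalg.qr(rng.standard_normal((d, k)))


def sample_task(n, a):
    """Draw n samples from the linear task with weight vector w = B @ a."""
    X = rng.standard_normal((n, d))  # isotropic Gaussian inputs (assumption)
    y = X @ (B @ a) + sigma * rng.standard_normal(n)
    return X, y


# Stage 1 (stand-in, not the paper's estimator): per-task least squares,
# then a rank-k SVD of the d x T matrix of estimated weights.
W_hat = np.empty((d, T))
for t in range(T):
    X, y = sample_task(n1, rng.standard_normal(k))
    W_hat[:, t] = np.linalg.lstsq(X, y, rcond=None)[0]
U, _, _ = np.linalg.svd(W_hat, full_matrices=False)
B_hat = U[:, :k]  # learned d x k representation

# Stage 2: target task with only n2 samples.
a_target = rng.standard_normal(k)
w_target = B @ a_target
X2, y2 = sample_task(n2, a_target)

# With the learned representation: regress on k features only.
a_hat = np.linalg.lstsq(X2 @ B_hat, y2, rcond=None)[0]
w_rep = B_hat @ a_hat

# Without it: minimum-norm least squares in all d dimensions (n2 < d).
w_direct = np.linalg.lstsq(X2, y2, rcond=None)[0]

# For isotropic inputs, excess risk equals squared parameter error.
risk_rep = float(np.sum((w_rep - w_target) ** 2))
risk_direct = float(np.sum((w_direct - w_target) ** 2))
print(f"excess risk with learned representation: {risk_rep:.4f}")
print(f"excess risk without representation:      {risk_direct:.4f}")
```

Under these assumptions, the target regression only has to estimate $k$ coefficients from $n_2$ samples rather than $d$, which is the mechanism behind the $\frac{k}{n_2}$ term in the abstract; the quality of `B_hat` is governed by the pooled source data.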
