Few-Shot Learning via Learning the Representation, Provably

Simon S. Du, Wei Hu, Sham M. Kakade, Jason D. Lee, Qi Lei

TL;DR

This work provides the first theory showing that representation learning can fully exploit all source-task data to enable few-shot learning on a target task. By replacing the i.i.d. task assumption with structural conditions on input covariances and task diversity, the authors derive fast excess-risk rates in four regimes: (i) low-dimensional linear representations with $O\left(\frac{dk}{n_1T}+\frac{k}{n_2}\right)$-type bounds, (ii) general nonlinear low-dimensional representations with data-dependent complexity bounds, (iii) high-dimensional linear representations with nuclear-norm regularization, yielding rates tied to $\mathrm{Tr}(\Sigma)$ and $\|\Sigma\|_2$, and (iv) neural networks (including two-layer ReLU networks) with similar pooling advantages. In all cases, the $n_1T$ source samples are effectively pooled to learn a representation that substantially reduces target-sample requirements, improving on the standard $\frac{d}{n_2}$ baseline. The result is a principled meta-learning and few-shot framework with provable guarantees for linear representations, nonlinear representations, and neural architectures.
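To make the gap between the two rates concrete, the short Python sketch below evaluates the order-level expressions $\frac{dk}{n_1T}+\frac{k}{n_2}$ and $\frac{d}{n_2}$ for one set of problem sizes. The function names and the specific values of $d$, $k$, $n_1$, $T$, and $n_2$ are illustrative assumptions, not taken from the paper, and all constants and logarithmic factors hidden by the $O(\cdot)$ notation are ignored.

```python
def rate_with_representation(d, k, n1, T, n2):
    """Order-level rate d*k/(n1*T) + k/n2 (constants and log factors dropped)."""
    return d * k / (n1 * T) + k / n2


def rate_without_representation(d, n2):
    """Order-level baseline rate d/n2 for learning the target task alone."""
    return d / n2


# Hypothetical problem sizes, chosen so that k << d and n2 << n1.
d, k, n1, T, n2 = 1000, 10, 5000, 100, 50

with_rep = rate_with_representation(d, k, n1, T, n2)
baseline = rate_without_representation(d, n2)

print(f"with representation learning: {with_rep:.3f}")
print(f"target-only baseline:         {baseline:.3f}")
```

For these sizes the source term $\frac{dk}{n_1T}$ is negligible because it is divided by all $n_1T$ pooled samples, so the bound is dominated by $\frac{k}{n_2}$ rather than $\frac{d}{n_2}$.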

Abstract

This paper studies few-shot learning via representation learning, where one uses $T$ source tasks with $n_1$ data per task to learn a representation in order to reduce the sample complexity of a target task for which there is only $n_2 (\ll n_1)$ data. Specifically, we focus on the setting where there exists a good common representation between source and target, and our goal is to understand how much of a sample size reduction is possible. First, we study the setting where this common representation is low-dimensional and provide a fast rate of $O\left(\frac{\mathcal{C}\left(\Phi\right)}{n_1T} + \frac{k}{n_2}\right)$; here, $\Phi$ is the representation function class, $\mathcal{C}\left(\Phi\right)$ is its complexity measure, and $k$ is the dimension of the representation. When specialized to linear representation functions, this rate becomes $O\left(\frac{dk}{n_1T} + \frac{k}{n_2}\right)$ where $d (\gg k)$ is the ambient input dimension, which is a substantial improvement over the rate without using representation learning, i.e. over the rate of $O\left(\frac{d}{n_2}\right)$. This result bypasses the $\Omega(\frac{1}{T})$ barrier under the i.i.d. task assumption, and can capture the desired property that all $n_1T$ samples from source tasks can be pooled together for representation learning. Next, we consider the setting where the common representation may be high-dimensional but is capacity-constrained (say in norm); here, we again demonstrate the advantage of representation learning in both high-dimensional linear regression and neural network learning. Our results demonstrate representation learning can fully utilize all $n_1T$ samples from source tasks.
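The low-dimensional linear setting can be illustrated with a small synthetic experiment. The sketch below is a minimal simulation under stated assumptions, not the estimator analyzed in the paper: it assumes isotropic Gaussian inputs, a ground-truth representation `B` with orthonormal columns, and a simple swapped-in procedure (per-task least squares followed by an SVD of the stacked estimates) to recover the shared subspace. All names and problem sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, T = 50, 3, 40      # ambient dim, representation dim, number of source tasks
n1, n2 = 100, 15         # samples per source task, samples for the target task
sigma = 0.5              # label-noise standard deviation (assumed)

# Ground-truth shared linear representation: d x k with orthonormal columns.
B, _ = np.linalg.qr(rng.standard_normal((d, k)))


def sample_task(w, n):
    """Draw n samples from the linear task y = x^T B w + noise, x ~ N(0, I_d)."""
    X = rng.standard_normal((n, d))
    y = X @ (B @ w) + sigma * rng.standard_normal(n)
    return X, y


# Source phase (illustrative surrogate): per-task OLS, then take the top-k
# left singular vectors of the stacked estimates as the learned subspace.
theta_hats = []
for _ in range(T):
    X, y = sample_task(rng.standard_normal(k), n1)
    theta_hats.append(np.linalg.lstsq(X, y, rcond=None)[0])
U, _, _ = np.linalg.svd(np.column_stack(theta_hats), full_matrices=False)
B_hat = U[:, :k]

# Target phase: only n2 < d samples are available.
w_target = np.ones(k)
theta_target = B @ w_target
X2, y2 = sample_task(w_target, n2)

# Few-shot fit in the learned k-dimensional representation.
w_hat = np.linalg.lstsq(X2 @ B_hat, y2, rcond=None)[0]
theta_rep = B_hat @ w_hat

# Baseline: minimum-norm least squares directly in d dimensions.
theta_base = np.linalg.lstsq(X2, y2, rcond=None)[0]

# With x ~ N(0, I_d), excess risk equals squared parameter error.
risk_rep = float(np.sum((theta_rep - theta_target) ** 2))
risk_base = float(np.sum((theta_base - theta_target) ** 2))
print(f"excess risk with learned representation: {risk_rep:.4f}")
print(f"excess risk of target-only baseline:     {risk_base:.4f}")
```

Because the target regression is carried out over $k$ coefficients instead of $d$, the few-shot fit needs only on the order of $k$ target samples once the subspace has been estimated from the pooled source data, which is the qualitative behavior the $\frac{k}{n_2}$ term describes.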

Paper Structure

This paper contains 16 sections, 14 theorems, 120 equations.

Key Result

Theorem 4.1

Fix a failure probability $\delta\in(0, 1)$. Under the assumptions labeled assump:linear_subgaussian, assump:linear_covariance_dominance, assump:linear_diverse_source, and assump:wT+1, we further assume $2k\le \min\{ d, T\}$ and that the sample sizes in source and target tasks satisfy $n_1 \gg \rho^4(d+\log\frac

Theorems & Definitions (56)

  • Theorem 4.1: main theorem for linear representations
  • Remark 4.1: multi-class problems
  • Remark 4.2: deterministic target task
  • Definition 5.1: Gaussian width
  • Definition 5.2: covariance between two representations
  • Theorem 5.1: main theorem for general representations
  • Claim 5.3: analogue of the claim labeled claim:linear_training_guarantee
  • proof
  • Theorem 6.1: main theorem for high-dimensional representations
  • Remark 6.1: The non-convex landscape
  • ...and 46 more