Guarantees for Nonlinear Representation Learning: Non-identical Covariates, Dependent Data, Fewer Samples

Thomas T. Zhang, Bruce D. Lee, Ingvar Ziemann, George J. Pappas, Nikolai Matni

TL;DR

As the number of tasks $T$ increases, both the sample requirement and risk bound converge to that of $r$-dimensional regression as if $g_\star$ had been given, and the effect of dependency only enters the sample requirement, leaving the risk bound matching the iid setting.

Abstract

A driving force behind the diverse applicability of modern machine learning is the ability to extract meaningful features across many sources. However, many practical domains involve data that are non-identically distributed across sources and statistically dependent within each source, violating vital assumptions in existing theoretical studies. Toward addressing these issues, we establish statistical guarantees for learning general $\textit{nonlinear}$ representations from multiple data sources that admit different input distributions and possibly dependent data. Specifically, we study the sample complexity of learning $T+1$ functions $f_\star^{(t)} \circ g_\star$ from a function class $\mathcal F \times \mathcal G$, where $f_\star^{(t)}$ are task-specific linear functions and $g_\star$ is a shared nonlinear representation. A representation $\hat g$ is estimated using $N$ samples from each of $T$ source tasks, and a fine-tuning function $\hat f^{(0)}$ is fit using $N'$ samples from a target task passed through $\hat g$. We show that when $N \gtrsim C_{\mathrm{dep}} (\mathrm{dim}(\mathcal F) + \mathrm{C}(\mathcal G)/T)$, the excess risk of $\hat f^{(0)} \circ \hat g$ on the target task decays as $\nu_{\mathrm{div}} \big(\frac{\mathrm{dim}(\mathcal F)}{N'} + \frac{\mathrm{C}(\mathcal G)}{N T} \big)$, where $C_{\mathrm{dep}}$ denotes the effect of data dependency, $\nu_{\mathrm{div}}$ denotes an (estimatable) measure of $\textit{task-diversity}$ between the source and target tasks, and $\mathrm C(\mathcal G)$ denotes the complexity of the representation class $\mathcal G$. In particular, our analysis reveals: as the number of tasks $T$ increases, both the sample requirement and risk bound converge to that of $r$-dimensional regression as if $g_\star$ had been given, and the effect of dependency only enters the sample requirement, leaving the risk bound matching the iid setting.
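The two-stage procedure described in the abstract (estimate $\hat g$ on $T$ source tasks, then fit $\hat f^{(0)}$ on the target task through the frozen $\hat g$) can be sketched in a toy setting. This is an illustrative sketch under stated assumptions, not the paper's experimental setup: the class $\mathcal G$ is taken to be a small finite set of hypothetical maps $x \mapsto \tanh(A^\top x)$, the data are iid (no dependence is modeled), the true representation is assumed to lie in the class, and all names (`make_representation`, `pretrain_representation`, `fine_tune`, `sample_task`) are invented for illustration.

```python
import numpy as np

def make_representation(A):
    """Hypothetical representation g(x) = tanh(A^T x); a factory avoids late-binding closures."""
    return lambda X: np.tanh(X @ A)

def fit_head(features, targets):
    """Least-squares linear head on fixed features; returns (weights, training MSE)."""
    w, *_ = np.linalg.lstsq(features, targets, rcond=None)
    return w, float(np.mean((features @ w - targets) ** 2))

def pretrain_representation(candidates, source_tasks):
    """Stage 1: ERM over a finite class G, refitting one linear head per (g, task) pair."""
    losses = [sum(fit_head(g(X), y)[1] for X, y in source_tasks) for g in candidates]
    return candidates[int(np.argmin(losses))]

def fine_tune(g_hat, X_target, y_target):
    """Stage 2: freeze g_hat and fit only the target head on N' target samples."""
    w, _ = fit_head(g_hat(X_target), y_target)
    return lambda X: g_hat(X) @ w

rng = np.random.default_rng(0)
d, r, T, N, N_target = 6, 2, 5, 200, 30
candidates = [make_representation(rng.normal(size=(d, r))) for _ in range(3)]
g_star = candidates[1]  # assumption: the true representation lies in the class

def sample_task(n, scale):
    """One task with its own covariate scale and its own linear head (iid samples)."""
    X = scale * rng.normal(size=(n, d))
    w = rng.normal(size=r)
    y = g_star(X) @ w + 0.01 * rng.normal(size=n)
    return X, y, w

# Source tasks share g_star but differ in covariate distribution and head.
source_tasks = [sample_task(N, scale)[:2] for scale in np.linspace(0.5, 1.5, T)]
g_hat = pretrain_representation(candidates, source_tasks)

# Target task: only N' samples, passed through the frozen g_hat.
X_tr, y_tr, w_target = sample_task(N_target, 1.0)
predict = fine_tune(g_hat, X_tr, y_tr)
X_te = rng.normal(size=(1000, d))
test_mse = float(np.mean((predict(X_te) - g_star(X_te) @ w_target) ** 2))
```

Once the representation is fixed, the target stage reduces to an $r$-dimensional least-squares problem, which is the regime the abstract's $\mathrm{dim}(\mathcal F)/N'$ term describes.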

Paper Structure

This paper contains 22 sections, 27 theorems, 116 equations, 3 figures.

Key Result

Theorem 1.1

Let there be $T$ tasks and $N$ samples per task. Assume $N \geq C_{\mathrm{mix}}\,\Omega\left(d_{\mathsf{Y}} r + \mathrm{C}(\mathcal{G})/T\right)$, where $C_{\mathrm{mix}}$ characterizes the dependency of the covariates of each task. Then the excess transfer risk of ERM is bounded with h… where $C_{\mathrm{task\;div}}$ characterizes the relatedness between the source tasks and the target…
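To make the scaling in $T$ concrete, the sample requirement and the risk rate quoted in the abstract can be evaluated numerically. This is only an arithmetic sketch: the constants hidden by $\gtrsim$ and $\Omega(\cdot)$ are set to 1 as an assumption, `dim_f` stands in for $\mathrm{dim}(\mathcal F)$ (written $d_{\mathsf{Y}} r$ in the theorem), `c_dep` for the dependency constant, and the function names and example values are hypothetical.

```python
def sample_requirement(dim_f, c_g, num_tasks, c_dep=1.0):
    """Per-task sample requirement N >~ C_dep * (dim(F) + C(G)/T), hidden constants set to 1."""
    return c_dep * (dim_f + c_g / num_tasks)

def risk_rate(dim_f, c_g, num_tasks, n_source, n_target, nu_div=1.0):
    """Excess-risk rate nu_div * (dim(F)/N' + C(G)/(N*T)), hidden constants set to 1."""
    return nu_div * (dim_f / n_target + c_g / (n_source * num_tasks))

# Hypothetical values: dim(F) = 4, C(G) = 1000, N = 500 source samples, N' = 100 target samples.
for T in (1, 10, 100, 1000):
    print(T, sample_requirement(4, 1000, T), risk_rate(4, 1000, T, 500, 100))
```

As $T$ grows, the $\mathrm{C}(\mathcal G)$ terms vanish from both expressions, leaving `dim_f` in the requirement and `dim_f / n_target` in the rate. Note that `c_dep` multiplies only the sample requirement, matching the claim that dependency does not enter the risk bound.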

Figures (3)

  • Figure 1: An example camera observation of the PyBullet-simulated cartpole environment, with the cartpole at the state $x = \begin{bmatrix} 0 & 0 & 0 & 0 \end{bmatrix}^\top$, shown together with an illustration of the ideal keypoints extracted from a cartpole image.
  • Figure 2: Three evaluation metrics comparing the performance of multi-task versus single-task imitation learning: the MSE between the inputs of the learned and expert controllers when evaluated on the expert trajectory, the deviation between the state trajectories generated by the learned and expert controllers, and the percentage of trials in which the learned controller keeps the pole balanced for all $500$ timesteps (the dynamics are discretized to $\Delta t = 0.02$ seconds). Three curves are shown for multi-task imitation learning, each generated by pre-training with a different number of source tasks. In all metrics, multi-task learning improves over single-task learning when few target trajectories are available.
  • Figure 3: Input imitation error of the policies trained with a shared representation, plotted on a log-log scale against the number of source tasks used to train the representation. The number of target trajectories used for fine-tuning is fixed at $100$.

Theorems & Definitions (47)

  • Theorem 1.1: Main result, informal
  • Definition 2.1: Task-Diversity (Tripuraneni et al., 2020)
  • Lemma 2.2
  • Definition 2.4
  • Remark 2.5
  • Proposition 2.5: $\mu$-task-coverage $\implies$ $\nu$-task-diversity
  • Remark 2.6: Robustness to overspecified $r$
  • Definition 2.7
  • Proposition 2.7: Convergence of $\hat\nu_N(g)$
  • Remark 2.8
  • ...and 37 more