An analytic theory of generalization dynamics and transfer learning in deep linear networks
Andrew K. Lampinen, Surya Ganguli
TL;DR
This work develops an analytic theory of generalization dynamics for deep linear networks in a teacher-student framework, where a low-rank teacher generates noisy training data. It yields closed-form expressions for training and test error as functions of training time, dataset size, initialization, and task structure, and shows that early learning prioritizes high-SNR task structure and can be largely independent of network size. The theory also identifies a phase transition in how teacher information imprints onto the training data, derives a non-gradient shrinkage algorithm, and describes transfer learning as driven by input-feature alignment. The results extend qualitatively to nonlinear networks and deeper architectures, offering insight into task design and transfer strategies that could improve generalization in real-world deep learning.
Abstract
Much attention has been devoted recently to the generalization puzzle in deep learning: large, deep networks can generalize well, but existing theories bounding generalization error are exceedingly loose, and thus cannot explain this striking performance. Furthermore, a major hope is that knowledge may transfer across tasks, so that multi-task learning can improve generalization on individual tasks. However, we lack analytic theories that can quantitatively predict how the degree of knowledge transfer depends on the relationship between the tasks. We develop an analytic theory of the nonlinear dynamics of generalization in deep linear networks, both within and across tasks. In particular, our theory provides analytic solutions to the training and testing error of deep networks as a function of training time, number of examples, network size and initialization, and the task structure and SNR. Our theory reveals that deep networks progressively learn the most important task structure first, so that generalization error at the early stopping time primarily depends on task structure and is independent of network size. This suggests any tight bound on generalization error must take into account task structure, and explains observations about real data being learned faster than random data. Intriguingly, our theory also reveals the existence of a learning algorithm that provably outperforms neural network training through gradient descent. Finally, for transfer learning, our theory reveals that knowledge transfer depends sensitively, but computably, on the SNRs and input feature alignments of pairs of tasks.
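As a rough illustration of the setting the abstract describes, the following NumPy sketch trains a two-layer linear student on noisy outputs of a low-rank teacher and tracks training and test error over time. It is not the authors' code: the function name `simulate`, the dimensions, teacher singular values, noise scale, learning rate, the use of orthonormal training inputs, and the definition of test error against the noise-free teacher are all assumptions made for this example.

```python
import numpy as np


def simulate(n=20, hidden=20, teacher_svs=(10.0, 6.0, 4.0), noise_std=1.0,
             steps=3000, lr=0.01, init_scale=1e-3, seed=0):
    """Train a two-layer linear student on a noisy low-rank teacher.

    All default values are illustrative assumptions, not taken from the paper.
    Returns (train_err, test_err), each an array of length `steps`.
    """
    rng = np.random.default_rng(seed)
    rank = len(teacher_svs)

    # Low-rank teacher W_teacher = U diag(s) V^T with orthonormal U, V.
    U, _ = np.linalg.qr(rng.standard_normal((n, rank)))
    V, _ = np.linalg.qr(rng.standard_normal((n, rank)))
    W_teacher = U @ np.diag(teacher_svs) @ V.T

    # Assumption: n orthonormal training inputs, so X @ X.T is the identity.
    X, _ = np.linalg.qr(rng.standard_normal((n, n)))

    # Noisy training targets; noise entries have variance noise_std**2 / n.
    Z = (noise_std / np.sqrt(n)) * rng.standard_normal((n, n))
    Y = W_teacher @ X + Z

    # Small random initialization of the student's two weight matrices.
    W1 = init_scale * rng.standard_normal((hidden, n))
    W2 = init_scale * rng.standard_normal((n, hidden))

    train_err = np.empty(steps)
    test_err = np.empty(steps)
    for t in range(steps):
        W = W2 @ W1          # end-to-end linear map of the student
        E = Y - W @ X        # residual on the training set

        # Normalized errors; test error is measured against the noise-free
        # teacher, assuming isotropic (whitened) test inputs.
        train_err[t] = np.sum(E ** 2) / np.sum(Y ** 2)
        test_err[t] = np.sum((W_teacher - W) ** 2) / np.sum(W_teacher ** 2)

        # Full-batch gradient descent on 0.5 * ||Y - W2 W1 X||_F^2.
        grad_step_W2 = E @ (W1 @ X).T
        grad_step_W1 = W2.T @ E @ X.T
        W2 += lr * grad_step_W2
        W1 += lr * grad_step_W1

    return train_err, test_err


if __name__ == "__main__":
    train_err, test_err = simulate()
    best = int(np.argmin(test_err))
    print(f"early-stopping step: {best}")
    print(f"test error at early stopping: {test_err[best]:.4f}")
    print(f"test error at final step:     {test_err[-1]:.4f}")
    print(f"train error at final step:    {train_err[-1]:.4f}")
```

Under these assumed settings, test error typically falls while the strong teacher modes are learned and then rises as the student begins to fit noise, so its minimum lies at an intermediate (early stopping) step. Varying `hidden`, kept at or above the teacher rank, is one way to probe how weakly that minimum depends on network size.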
