Dynamics of Meta-learning Representation in the Teacher-student Scenario

Hui Wang, Cho Tung Yip, Bo Li

TL;DR

The paper addresses the theoretical dynamics of gradient-based meta-learning in nonlinear two-layer networks trained on streaming tasks, seeking to explain how a shared meta-representation emerges. It employs a statistical-physics framework to derive macroscopic order-parameter dynamics that track the overlaps between meta-learner and meta-teacher representations, and it quantifies meta-generalization through the resulting set of ODEs. Key findings include a symmetry-breaking specialization process in which meta-learner units align with distinct meta-teacher units, the critical role of the learning rates and of overparameterization, and robustness to some variability in the task mappings. The work provides a principled lens for studying meta-learning behavior and offers guidance on hyperparameter choices and model design, with potential extensions to other activation functions and regularization schemes.
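As a concrete reading of the order parameters mentioned above, the sketch below computes the overlap matrices from the weight matrices. It assumes the common teacher-student convention $Q_{kl} = \mathbf{J}_k \cdot \mathbf{J}_l / N$, $R_{kn} = \mathbf{J}_k \cdot \mathbf{B}_n / N$, $T_{nm} = \mathbf{B}_n \cdot \mathbf{B}_m / N$; the paper's exact normalization is not stated in this summary, and `order_parameters` is a hypothetical helper, not the authors' code. The usage lines build a meta-teacher with $T_{mn} = m \delta_{m,n}$ and a random meta-learner with $Q_{kl} \approx \frac{1}{2}\delta_{k,l}$, as in the setup listed for Figure 3; random initialization gives $R_{kn}$ of order $N^{-1/2}$ rather than the $10^{-12}$ used there.

```python
import math
import numpy as np

def order_parameters(J, B):
    """Macroscopic overlaps, normalized by the input dimension N (assumed convention).

    J: meta-learner input-to-hidden weights, shape (K, N)
    B: meta-teacher input-to-hidden weights, shape (M, N)
    """
    N = J.shape[1]
    Q = J @ J.T / N  # learner-learner overlaps Q_kl, shape (K, K)
    R = J @ B.T / N  # learner-teacher overlaps R_kn, shape (K, M)
    T = B @ B.T / N  # teacher-teacher overlaps T_nm, shape (M, M)
    return Q, R, T

# Usage: a meta-teacher with T_mn = m * delta_mn (m = 1..M) and a random
# meta-learner whose Q_kl is close to (1/2) * delta_kl for large N.
rng = np.random.default_rng(0)
N, K, M = 2000, 3, 3
B = np.zeros((M, N))
for n in range(M):
    B[n, n] = math.sqrt((n + 1) * N)  # orthogonal rows with |B_n|^2 / N = n + 1
J = rng.standard_normal((K, N)) * math.sqrt(0.5)
Q, R, T = order_parameters(J, B)
```

Because the dimension $N$ only enters through these normalized overlaps, the macroscopic description tracks $K^2 + KM$ numbers instead of $KN$ weights.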

Abstract

Gradient-based meta-learning algorithms have gained popularity for their ability to train models on new tasks using limited data. Empirical observations indicate that such algorithms are able to learn a shared representation across tasks, which is regarded as a key factor in their success. However, the in-depth theoretical understanding of the learning dynamics and the origin of the shared representation remains underdeveloped. In this work, we investigate the meta-learning dynamics of nonlinear two-layer neural networks trained on streaming tasks in the teacher-student scenario. Through the lens of statistical physics analysis, we characterize the macroscopic behavior of the meta-training processes, the formation of the shared representation, and the generalization ability of the model on new tasks. The analysis also points to the importance of the choice of certain hyperparameters of the learning algorithms.
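Figure 1 describes meta-training with FO-ANIL on streaming tasks: each task comes from a teacher that shares the meta-teacher weights $\mathbf{B}$ but has its own head $\boldsymbol{u}^{t}$, the learner adapts only its head $\boldsymbol{w}^{t}$ in the inner loop, and the shared weights $\mathbf{J}$ receive a first-order outer gradient. The following is a minimal NumPy sketch of one such meta-update under stated assumptions: activation $g(x) = \mathrm{erf}(x/\sqrt{2})$, Gaussian inputs and Gaussian task heads, pre-activations $\mathbf{J}\boldsymbol{x}/\sqrt{N}$, a head reset to zero at the start of every task, and $P$ and $V$ read as the support and query set sizes from Figure 2. The helpers `forward`, `sample_task`, and `fo_anil_step` are hypothetical names, and none of this is the authors' implementation.

```python
import math
import numpy as np

# Assumed activation g(x) = erf(x / sqrt(2)) and its derivative; the paper's
# activation may differ.
_erf = np.vectorize(math.erf)

def g(h):
    return _erf(h / math.sqrt(2.0))

def g_prime(h):
    return math.sqrt(2.0 / math.pi) * np.exp(-h ** 2 / 2.0)

def forward(W_in, w_out, X):
    """Two-layer net: y_hat = sum_k w_k * g(W_in[k] . x / sqrt(N))."""
    H = X @ W_in.T / math.sqrt(X.shape[1])  # pre-activations, shape (batch, K)
    return H, g(H) @ w_out

def sample_task(B, rng, P, V):
    """Hypothetical task generator: Gaussian head u^t, Gaussian support/query inputs."""
    M, N = B.shape
    u = rng.standard_normal(M)
    X_s, X_q = rng.standard_normal((P, N)), rng.standard_normal((V, N))
    return X_s, forward(B, u, X_s)[1], X_q, forward(B, u, X_q)[1]

def fo_anil_step(J, B, rng, P, V, eta_w, eta_J, n_inner=1):
    """One streaming meta-update: adapt the head, then a first-order step on J."""
    K, N = J.shape
    X_s, y_s, X_q, y_q = sample_task(B, rng, P, V)
    # Inner loop: gradient descent on 0.5 * mean squared error over the support
    # set, updating only the task-specific head w (reset to zero per task).
    w = np.zeros(K)
    for _ in range(n_inner):
        H, y_hat = forward(J, w, X_s)
        w = w + eta_w * g(H).T @ (y_s - y_hat) / P
    # Outer loop (first order): gradient of the query loss w.r.t. J, treating
    # the adapted head w as a constant.
    H, y_hat = forward(J, w, X_q)
    delta = (y_q - y_hat)[:, None] * w[None, :] * g_prime(H)  # shape (V, K)
    grad_J = -(delta.T @ X_q) / (V * math.sqrt(N))
    return J - eta_J * grad_J, w

# Usage: a few streaming tasks with small, illustrative sizes.
rng = np.random.default_rng(0)
B = rng.standard_normal((3, 100))  # meta-teacher, M = 3, N = 100
J = rng.standard_normal((3, 100))  # meta-learner, K = 3
for _ in range(5):
    J, w = fo_anil_step(J, B, rng, P=20, V=20, eta_w=4.0, eta_J=6.0)
```

With this scaling, a single task changes the overlaps $Q$ and $R$ by an amount of order $1/N$, which is consistent with measuring training time by the normalized task index $\alpha = t/N$ used in Figure 2.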

Paper Structure (23 sections, 57 equations, 12 figures)

Figures (12)

  • Figure 1: The framework of the teacher-student analysis for meta-learning under consideration. All models are two-layer neural networks. (a) The meta-teacher generates different task-specific teachers, which share the same input-to-hidden weights with the meta-teacher (denoted as $\mathbf{B}$) but have different hidden-to-output weights (denoted as $\boldsymbol{u}^{t}$). (b) The meta-training and testing processes. The meta-learner is meta-trained on different tasks using the FO-ANIL algorithm. Each specific learner inherits the input-to-hidden weights $\mathbf{J}$ from the meta-learner but maintains its own task-specific hidden-to-output weights (denoted as $\boldsymbol{w}^{t}$). Meta-testing evaluates the meta-learner's ability to guide a new learner, through an inner-loop optimization, to solve an unseen task $\mathcal{T}_{\text{new}}$.
  • Figure 2: Comparison between theoretical predictions (marked by "$\text{TH}$") and simulated experiments (marked by "$\text{Exp}$") of the online meta-learning dynamics under consideration. The simulation setup is outlined in Sec. \ref{sec:compare}, with system parameters $N=1000, K=M=3, P=V=100, \eta_J=6, \eta_w=4$. The simulations are averaged over $10$ independent runs, each with different realizations of training tasks and datasets, while maintaining the same initial conditions. The error bars represent one standard deviation across these $10$ trials. (a) The meta-generalization error $\epsilon_g^{\text{meta}}$ as a function of the normalized task index $\alpha = t / N$. Since directly evaluating $\epsilon_g^{\text{meta}}$ by testing on a large dataset is time-consuming, we instead prepare a small test dataset at each time step and, for each trial, compute a moving average of $\epsilon_g^{\text{meta}}$ over a sliding window $\Delta \alpha = 0.05$ to smooth out fluctuations. Panels (b) and (c) depict the dynamics of $Q$, where solid lines indicate experimental observations and dashed lines show theoretical results. Panels (d), (e), and (f) illustrate the dynamical evolution of $R$, with solid lines representing experimental data and dashed lines corresponding to theoretical predictions.
  • Figure 3: A case study of meta-representation learning. The system parameters are $K = M = 3, \eta_J = 3, T_{mn} = m \delta_{m,n}$, and the initial conditions are $Q_{kl} = \frac{1}{2}\delta_{k,l}, R_{kn} = 10^{-12}$. (a) $\epsilon_g^{\text{meta}}$ vs $\alpha$. (b) $R_{kn}$ vs $\alpha$ for $\eta_w = 3$. (c) $\rho_{kn}$ vs $\alpha$ for $\eta_w = 3$. (d) Pictorial illustration of the outcome of the meta-representation learning for $\eta_w = 3$. (e) $R_{kn}$ vs $\alpha$ for $\eta_w = 9$. (f) $\rho_{kn}$ vs $\alpha$ for $\eta_w = 9$. (g) Pictorial illustration of the outcome of the meta-representation learning for $\eta_w = 9$.
  • Figure 4: Effect of learning rates on meta-generalization. The system parameters and initial conditions are the same as those in Sec. \ref{sec:representation_KM3_case}. (a) $\epsilon_g^{\text{meta}}$ vs $\eta_w$ when fixing $\eta_J$. (b) $\epsilon_g^{\text{meta}}$ vs $\eta_J$ when fixing $\eta_w$.
  • Figure 5: The values of $\tilde{\alpha}$ needed to reach $\epsilon_g^{\text{meta}} = 0.01$ under different choices of the learning rates $\eta_w$ and $\eta_J$. The yellow region indicates that $\epsilon_g^{\text{meta}}$ fails to drop to the level $0.01$ within the time window $\alpha_{\text{final}} = 450$ considered. The system parameters and initial conditions are set as $M = 3, T_{mn} = m \delta_{m,n}, Q_{kl} = \frac{1}{2} \delta_{k,l}, R_{kn} = 10^{-12}$. Panels (a) to (g) correspond to $K = 3, 4, 5, 6, 7, 8, 9$, respectively.
  • ...and 7 more figures
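Figure 2 smooths the noisy small-test-set estimates of $\epsilon_g^{\text{meta}}$ with a moving average over a window $\Delta\alpha = 0.05$, and Figure 5 reports the time needed to reach $\epsilon_g^{\text{meta}} = 0.01$. A small sketch of both post-processing steps follows. The trailing (rather than centered) window is an assumption, and since this summary does not define how $\tilde{\alpha}$ relates to $\alpha$, the hypothetical helper `alpha_to_threshold` simply returns the first time value at which the curve reaches the threshold, or `None` when it never does within the recorded window.

```python
import numpy as np

def moving_average(alpha, eps, window=0.05):
    """Trailing moving average of eps over the alpha-interval (a - window, a]."""
    alpha = np.asarray(alpha, dtype=float)
    eps = np.asarray(eps, dtype=float)
    smoothed = np.empty_like(eps)
    for i, a in enumerate(alpha):
        in_window = (alpha > a - window) & (alpha <= a)  # always contains index i
        smoothed[i] = eps[in_window].mean()
    return smoothed

def alpha_to_threshold(alpha, eps, threshold=0.01):
    """First alpha at which eps <= threshold, or None if it is never reached."""
    below = np.flatnonzero(np.asarray(eps, dtype=float) <= threshold)
    if below.size == 0:
        return None  # analogous to the "never reached" (yellow) region in Figure 5
    return float(np.asarray(alpha, dtype=float)[below[0]])

# Usage on a synthetic decaying curve (not data from the paper).
alpha = np.linspace(0.0, 10.0, 1001)
eps = 0.5 * np.exp(-alpha)
print(alpha_to_threshold(alpha, moving_average(alpha, eps)))
```

Applying the threshold to the smoothed curve, rather than to the raw per-step estimates, keeps a single noisy low estimate from registering as an early crossing.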