Train Faster, Perform Better: Modular Adaptive Training in Over-Parameterized Models

Yubin Shi, Yixuan Chen, Mingzhi Dong, Xiaochen Yang, Dongsheng Li, Yujiang Wang, Robert P. Dick, Qin Lv, Yingying Zhao, Fan Yang, Tun Lu, Ning Gu, Li Shang

TL;DR

The paper addresses the high computational cost of training over-parameterized models by introducing the modular neural tangent kernel (mNTK) to quantify per-module learning dynamics. It demonstrates that the principal eigenvalue $\lambda_{\max}$ of each module's mNTK is a strong indicator of trainability and generalization. This motivates Modular Adaptive Training (MAT), which selectively updates only those modules whose $\lambda_{\max}$ exceeds a dynamic threshold. By restricting back-propagation to a sparse subset of modules and preventing overfitting in less informative ones, MAT substantially reduces training FLOPs while matching or improving accuracy on BERT, Switch-Transformer, and VGG. These findings offer a practical, theory-informed path to faster, more efficient training of large, structured neural networks and can complement existing pruning and efficiency methods.
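The TL;DR treats each module's mNTK as a per-module learning signal. A minimal sketch, assuming the mNTK is the empirical NTK restricted to one module's parameters (the paper's Definition 1 may differ in details such as output dimension or normalization): stack the per-sample gradients with respect to that module's parameters, form their Gram matrix, and take its principal eigenvalue $\lambda_{\max}$ by power iteration. The two-module `toy_model` and the helpers `module_gradient`, `mntk`, and `principal_eigenvalue` are hypothetical names, and central finite differences stand in for automatic differentiation to keep the example dependency-free.

```python
import math
from typing import Callable, Dict, List

# Parameters grouped by module name, e.g. one entry per attention head (hypothetical layout).
Params = Dict[str, List[float]]
Model = Callable[[Params, float], float]


def module_gradient(f: Model, params: Params, module: str, x: float, eps: float = 1e-6) -> List[float]:
    """Central-difference gradient of the scalar output f(params, x) w.r.t. params[module] only."""
    grad = []
    for k in range(len(params[module])):
        plus = {m: list(v) for m, v in params.items()}
        minus = {m: list(v) for m, v in params.items()}
        plus[module][k] += eps
        minus[module][k] -= eps
        grad.append((f(plus, x) - f(minus, x)) / (2.0 * eps))
    return grad


def mntk(f: Model, params: Params, module: str, xs: List[float]) -> List[List[float]]:
    """Empirical modular NTK: Gram matrix of per-sample gradients restricted to one module."""
    jac = [module_gradient(f, params, module, x) for x in xs]
    return [[sum(a * b for a, b in zip(gi, gj)) for gj in jac] for gi in jac]


def principal_eigenvalue(mat: List[List[float]], iters: int = 500) -> float:
    """Power iteration; adequate here because a Gram matrix is symmetric positive semi-definite."""
    n = len(mat)
    v = [1.0 / math.sqrt(n)] * n
    lam = 0.0
    for _ in range(iters):
        w = [sum(mat[i][j] * v[j] for j in range(n)) for i in range(n)]
        lam = math.sqrt(sum(c * c for c in w))  # ||M v|| for a unit vector v
        if lam == 0.0:
            return 0.0
        v = [c / lam for c in w]
    return lam


def toy_model(params: Params, x: float) -> float:
    """Hypothetical two-module network: a tanh 'feature' module feeding a linear 'head' module."""
    a0, a1 = params["feature"]
    (b0,) = params["head"]
    return b0 * math.tanh(a0 * x + a1)


params = {"feature": [0.5, -0.1], "head": [1.5]}
xs = [-1.0, 0.0, 0.5, 2.0]
lam_max = {m: principal_eigenvalue(mntk(toy_model, params, m, xs)) for m in params}
print(lam_max)
```

Because the gradient inner product splits over disjoint parameter groups, the mNTKs of all modules sum to the full-model empirical NTK under this construction, so each module's kernel is a slice of the usual NTK rather than a separate object.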

Abstract

Despite their prevalence in the deep-learning community, over-parameterized models incur high computational costs to train properly. This work studies the fine-grained, modular-level learning dynamics of over-parameterized models to attain a more efficient and fruitful training strategy. Empirical evidence reveals that, when zooming in on individual network modules, such as heads in self-attention models, we observe varying learning patterns implicitly associated with each module's trainability. To describe such modular-level learning capabilities, we introduce a novel concept dubbed the modular neural tangent kernel (mNTK), and we demonstrate that the quality of a module's learning is tightly associated with its mNTK's principal eigenvalue $\lambda_{\max}$. A large $\lambda_{\max}$ indicates that the module learns features with better convergence, while a small one may harm generalization. Inspired by this discovery, we propose a novel training strategy termed Modular Adaptive Training (MAT), which selectively updates only those modules whose $\lambda_{\max}$ exceeds a dynamic threshold, concentrating the model on learning common features and ignoring inconsistent ones. Unlike most existing training schemes, which run a complete back-propagation (BP) cycle across all network modules, MAT significantly reduces computation through its partial-update strategy and can further improve performance. Experiments show that MAT nearly halves the computational cost of model training while outperforming baselines in accuracy.
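The selective-update rule described in the abstract can be sketched as follows. The threshold rule here (a fixed fraction `ratio` of the mean $\lambda_{\max}$ across modules, recomputed on every call so that it moves with the spectrum) is a hypothetical stand-in for the paper's dynamic threshold, and `select_modules` and `masked_sgd_step` are illustrative names, not the authors' code.

```python
from typing import Dict, List, Set

Params = Dict[str, List[float]]


def select_modules(lam_max: Dict[str, float], ratio: float = 0.5) -> Set[str]:
    """Return the modules whose lambda_max exceeds a dynamic threshold.

    Hypothetical rule: threshold = ratio * mean(lambda_max over modules). It is
    recomputed on every call, so it tracks the spectrum as training evolves.
    """
    threshold = ratio * sum(lam_max.values()) / len(lam_max)
    return {m for m, lam in lam_max.items() if lam > threshold}


def masked_sgd_step(params: Params, grads: Params, active: Set[str], lr: float = 0.1) -> Params:
    """Plain SGD on active modules only; inactive modules keep their weights unchanged.

    `grads` only needs entries for active modules, mirroring the fact that the
    frozen modules' gradients are never computed.
    """
    return {
        m: [p - lr * g for p, g in zip(values, grads[m])] if m in active else list(values)
        for m, values in params.items()
    }


lam_max = {"head0": 4.0, "head1": 0.2, "head2": 2.5, "head3": 0.1}
active = select_modules(lam_max)  # threshold = 0.5 * mean = 0.5 * 1.7 = 0.85
params = {m: [1.0] for m in lam_max}
grads = {"head0": [0.5], "head2": [-1.0]}  # only active modules receive gradients
params = masked_sgd_step(params, grads, active)
print(sorted(active), params)
```

In a real framework, the savings would come from skipping gradient computation for the frozen modules during back-propagation (in PyTorch, for instance, by calling `requires_grad_(False)` on parameters that a frozen module owns exclusively); the update loop itself is cheap.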

Paper Structure (18 sections, 1 theorem, 18 equations, 6 figures, 7 tables, 1 algorithm)

Key Result

Lemma 2

Given $R>0$, with probability at least $1-\delta$ over the random initialization $(\boldsymbol{\theta}(0), \boldsymbol{a})$, simultaneously for every $B>0$, the following function class has empirical Rademacher complexity bounded as:
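For reference, the quantity bounded in Lemma 2 is the empirical Rademacher complexity. Its textbook definition, for a function class $\mathcal{F}$ and a sample $S = \{\boldsymbol{x}_1, \dots, \boldsymbol{x}_n\}$, is given below; this is the standard definition, not the lemma's bound, and the paper's normalization may differ.

```latex
\hat{\mathcal{R}}_S(\mathcal{F})
  = \mathbb{E}_{\boldsymbol{\epsilon}}\!\left[\,\sup_{f \in \mathcal{F}}
    \frac{1}{n} \sum_{i=1}^{n} \epsilon_i \, f(\boldsymbol{x}_i)\right],
\qquad \epsilon_1, \dots, \epsilon_n \overset{\text{i.i.d.}}{\sim} \operatorname{Unif}\{-1, +1\}.
```

A smaller value means the class can correlate less well with random sign patterns on the sample, which is why bounds on this quantity translate into generalization guarantees.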

Figures (6)

  • Figure 1: Characteristics of training BERT on WikiText-2. Panel (a) shows the joint variation of effective rank and $\lambda_{\max}$ across the attention heads, where $\text{Head}^l_i$ refers to the $i^\text{th}$ attention head in the $l^\text{th}$ layer. Panel (b-left) illustrates the idea of MAT, which governs the training of the heads with a dynamic threshold. Using MAT speeds up convergence and achieves a lower validation loss (b-right).
  • Figure 2: Training dynamics characterized by the layer-wise mNTKs of BERT trained on WikiText-2. (a) Distribution of the first 16 eigenvalues of the layer-wise mNTKs at the $10^{\text{th}}$ epoch. (b) Variation of $\lambda_{\max}$ of the layer-wise mNTKs. (c) Variation of $\kappa$ of the layer-wise mNTKs during the first 20 epochs.
  • Figure 3: Training dynamics in the overfitting case of a 4-layer BERT trained on a 64-token MLM task.
  • Figure 4: The normalized eigenspectrum exhibits two distinct regions, termed the information space and the nuisance space.
  • Figure 5: Histogram of the epochs in which the heads are trained with back-propagation.
  • ...and 1 more figure
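The figures summarize mNTK spectra through effective rank, $\lambda_{\max}$, $\kappa$, and a normalized eigenspectrum. A minimal sketch of those statistics from a list of mNTK eigenvalues, assuming effective rank means the entropy-based definition of Roy and Vetterli, $\kappa$ is the condition number $\lambda_{\max}/\lambda_{\min}$, and the spectrum is normalized by its sum; the paper may define any of these differently. Since an mNTK is symmetric positive semi-definite, its eigenvalues coincide with its singular values, so the entropy-based definition applies directly. The function names are hypothetical.

```python
import math
from typing import List


def normalized_spectrum(eigs: List[float]) -> List[float]:
    """Eigenvalues sorted in descending order and scaled to sum to 1."""
    total = sum(eigs)
    return [e / total for e in sorted(eigs, reverse=True)]


def effective_rank(eigs: List[float]) -> float:
    """Entropy-based effective rank: exp(-sum p_i log p_i) over the normalized spectrum."""
    p = [q for q in normalized_spectrum(eigs) if q > 0.0]
    return math.exp(-sum(q * math.log(q) for q in p))


def condition_number(eigs: List[float]) -> float:
    """kappa = lambda_max / lambda_min; infinite if the kernel is singular."""
    lo = min(eigs)
    return math.inf if lo <= 0.0 else max(eigs) / lo


spectrum = [4.0, 2.0, 0.5]  # hypothetical mNTK eigenvalues
print(normalized_spectrum(spectrum), effective_rank(spectrum), condition_number(spectrum))
```

A spectrum dominated by one eigenvalue gives an effective rank near 1, while a flat spectrum of $n$ equal eigenvalues gives exactly $n$, which is what makes the quantity a convenient companion to $\lambda_{\max}$.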

Theorems & Definitions (2)

  • Definition 1
  • Lemma 2