Train Faster, Perform Better: Modular Adaptive Training in Over-Parameterized Models
Yubin Shi, Yixuan Chen, Mingzhi Dong, Xiaochen Yang, Dongsheng Li, Yujiang Wang, Robert P. Dick, Qin Lv, Yingying Zhao, Fan Yang, Tun Lu, Ning Gu, Li Shang
TL;DR
The paper addresses the high computational cost of training over-parameterized models by introducing the modular Neural Tangent Kernel (mNTK), which quantifies learning dynamics per module. It shows that the principal eigenvalue $λ_{\max}$ of each module's mNTK is a strong indicator of that module's trainability and generalization. This motivates Modular Adaptive Training (MAT), which selectively updates only the modules whose $λ_{\max}$ exceeds a dynamic threshold. By enforcing sparse, modular backpropagation and avoiding overfitting in less informative modules, MAT significantly reduces training FLOPs while matching or improving accuracy on BERT, Switch-Transformer, and VGG. These findings offer a practical, theory-informed path to faster and more efficient training of large, structured neural networks, and they can complement existing pruning and efficiency methods.
Abstract
Despite their prevalence in the deep-learning community, over-parameterized models demand high computational cost to train properly. This work studies the fine-grained, modular-level learning dynamics of over-parameterized models to obtain a more efficient and effective training strategy. Empirical evidence reveals that, at the level of network modules such as heads in self-attention models, learning patterns vary and are implicitly associated with each module's trainability. To describe such modular-level learning capabilities, we introduce a novel concept dubbed the modular neural tangent kernel (mNTK), and we demonstrate that the quality of a module's learning is tightly associated with the principal eigenvalue $λ_{\max}$ of its mNTK. A large $λ_{\max}$ indicates that the module learns features with better convergence, while a small $λ_{\max}$ may impact generalization negatively. Inspired by this discovery, we propose a novel training strategy termed Modular Adaptive Training (MAT), which selectively updates the modules whose $λ_{\max}$ exceeds a dynamic threshold, concentrating the model on learning common features and ignoring inconsistent ones. Unlike most existing training schemes, which run a complete backpropagation (BP) cycle across all network modules, MAT can significantly save computation through its partial-update strategy and can further improve performance. Experiments show that MAT nearly halves the computational cost of model training and achieves higher accuracy than the baselines.
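
To make the selection rule concrete, the following is a minimal, illustrative sketch and not the authors' implementation. It assumes that a module's mNTK on a batch can be formed as $J J^\top$, where $J$ is the Jacobian of the model outputs with respect to that module's parameters, and it estimates $λ_{\max}$ by power iteration. The toy Jacobians, the module names `head_a` and `head_b`, the helper names, and the use of the mean eigenvalue as a stand-in for the dynamic threshold are all assumptions made for illustration.

```python
def gram(jacobian):
    """Return J J^T for a Jacobian given as a list of rows (n_samples x n_params)."""
    return [[sum(a * b for a, b in zip(ri, rj)) for rj in jacobian] for ri in jacobian]


def principal_eigenvalue(kernel, iters=200):
    """Estimate the largest eigenvalue of a symmetric PSD matrix by power iteration."""
    n = len(kernel)
    v = [1.0 / n ** 0.5] * n
    lam = 0.0
    for _ in range(iters):
        w = [sum(kernel[i][j] * v[j] for j in range(n)) for i in range(n)]
        norm = sum(x * x for x in w) ** 0.5
        if norm == 0.0:
            return 0.0
        v = [x / norm for x in w]
        lam = norm  # for a unit vector v, ||K v|| approaches lambda_max
    return lam


def select_modules(jacobians, threshold=None):
    """Compute lambda_max per module and keep modules at or above the threshold."""
    lam = {name: principal_eigenvalue(gram(J)) for name, J in jacobians.items()}
    if threshold is None:
        # Assumption: the mean eigenvalue stands in for the dynamic threshold.
        threshold = sum(lam.values()) / len(lam)
    active = [name for name, value in lam.items() if value >= threshold]
    return lam, active


def module_gradient(jacobian, residual):
    """Gradient J^T r of the loss 0.5 * ||f - y||^2 for a model linear in its parameters."""
    n_params = len(jacobian[0])
    return [sum(row[k] * r for row, r in zip(jacobian, residual)) for k in range(n_params)]


def masked_sgd_step(params, grads, active, lr=0.1):
    """Apply an SGD step to active modules only; inactive modules are left unchanged."""
    return {
        name: [p - lr * g for p, g in zip(w, grads[name])] if name in active else list(w)
        for name, w in params.items()
    }


# Toy setup (hypothetical): two modules with very different Jacobian scales.
jacobians = {
    "head_a": [[2.0, 0.0], [0.0, 1.0]],
    "head_b": [[0.1, 0.0], [0.0, 0.1]],
}
params = {"head_a": [0.0, 0.0], "head_b": [0.0, 0.0]}
residual = [-1.0, -1.0]  # f(X) - y with f(X) = 0 and y = [1, 1]

lam, active = select_modules(jacobians)
# Gradients are computed only for the selected modules, mirroring the partial update.
grads = {name: module_gradient(jacobians[name], residual) for name in active}
new_params = masked_sgd_step(params, grads, active)
```

In this toy case only `head_a` is updated, and no gradient is ever computed for `head_b`. In a real network the per-module Jacobians would come from automatic differentiation, and the threshold would follow whatever schedule the training method prescribes.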
