Trainable Weight Averaging: Accelerating Training and Improving Generalization
Tao Li, Zhehao Huang, Yingwen Wu, Zhengbao He, Qinghua Tao, Xiaolin Huang, Chih-Jen Lin
TL;DR
Trainable Weight Averaging (TWA) replaces static averaging schemes such as SWA with learnable coefficients that combine multiple model weights within a low-dimensional subspace. TWA builds this subspace from historical weights and projects gradients onto it, so only the combining coefficients $\boldsymbol{\beta}$ are optimized, yielding $\boldsymbol{w}_{\text{twa}}=\boldsymbol{P}\boldsymbol{\beta}$, where the columns of $\boldsymbol{P}$ span the subspace. The approach has two variants, TWA-t (trained on training data) and TWA-v (trained on validation data); TWA-v consistently delivers better generalization and robustness across CNNs, ViTs, transformers, and NLP models. A distributed training framework and low-bit quantization of the projection matrix make TWA scalable to large models, and layer-wise processing further improves efficiency. Empirically, TWA substantially accelerates training (e.g., over 40% fewer epochs on CIFAR and over 30% fewer on ImageNet) and improves generalization, especially in fine-tuning and in transformer-based architectures.
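To make the mechanism concrete, here is a minimal NumPy sketch on a toy problem; it is an illustration, not the authors' implementation. It assumes the candidate weights are flattened into vectors, orthonormalizes them with a QR decomposition to obtain $\boldsymbol{P}$, starts from the equal-weight (SWA-style) average, and runs plain gradient descent on $\boldsymbol{\beta}$ using the chain rule $\nabla_{\boldsymbol{\beta}} L = \boldsymbol{P}^\top \nabla_{\boldsymbol{w}} L$. The least-squares loss, the QR step, the step size, and the helper names `build_subspace` and `twa_fit` are all hypothetical choices. Supplying gradients from training batches or from held-out validation batches as `grad_fn` would loosely mirror TWA-t and TWA-v, respectively.

```python
import numpy as np

def build_subspace(checkpoints):
    """Return P (d x n): an orthonormal basis for the span of the candidate weights."""
    W = np.stack(checkpoints, axis=1)   # one flattened checkpoint per column
    P, _ = np.linalg.qr(W)              # reduced QR; assumes the checkpoints are linearly independent
    return P

def twa_fit(P, w_init, grad_fn, lr, steps=200):
    """Gradient descent on the coefficients beta, keeping w = P @ beta inside the subspace."""
    beta = P.T @ w_init                 # coordinates of the start point (exact if w_init lies in span(P))
    for _ in range(steps):
        g = grad_fn(P @ beta)           # full-dimensional gradient at the current averaged weights
        beta -= lr * (P.T @ g)          # chain rule: dL/dbeta = P^T dL/dw
    return P @ beta, beta

# Toy problem (illustrative assumption): least-squares loss on random data.
rng = np.random.default_rng(0)
d, n = 50, 5
A = rng.normal(size=(80, d))
b = rng.normal(size=80)

def loss(w):
    return 0.5 * np.sum((A @ w - b) ** 2)

def grad(w):
    return A.T @ (A @ w - b)

checkpoints = [rng.normal(size=d) for _ in range(n)]   # hypothetical stand-ins for historical weights
P = build_subspace(checkpoints)
w_swa = np.mean(checkpoints, axis=0)                   # equal-weight, SWA-style average
lr = 1.0 / np.linalg.norm(A @ P, 2) ** 2               # 1 / largest eigenvalue of P^T A^T A P
w_twa, beta = twa_fit(P, w_swa, grad, lr)
print(f"SWA loss: {loss(w_swa):.3f}  TWA loss: {loss(w_twa):.3f}")
```

Because $\boldsymbol{P}$ has orthonormal columns, updating $\boldsymbol{\beta}$ with $\boldsymbol{P}^\top\boldsymbol{g}$ moves $\boldsymbol{w}$ by the projected gradient $\boldsymbol{P}\boldsymbol{P}^\top\boldsymbol{g}$, while only $n$ coefficients are trained. The step size $1/L$, with $L$ the largest eigenvalue of the restricted Hessian, keeps gradient descent monotone on this convex toy loss, so the result is never worse than the SWA starting point.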
Abstract
Weight averaging is a widely used technique for accelerating training and improving the generalization of deep neural networks (DNNs). Existing approaches such as stochastic weight averaging (SWA) rely on pre-set weighting schemes, which can be suboptimal when the candidate weights are diverse. We introduce Trainable Weight Averaging (TWA), a novel optimization method that operates within a reduced subspace spanned by the candidate weights and learns the weighting coefficients through optimization. TWA offers greater flexibility and can be applied to different training scenarios. For large-scale applications, we develop a distributed training framework that combines parallel computation with low-bit compression of the projection matrix, effectively managing memory and computational demands. TWA can be implemented using either training data (TWA-t) or validation data (TWA-v), with the latter providing more effective averaging. Extensive experiments showcase TWA's advantages: (i) it consistently outperforms SWA in generalization performance and flexibility; (ii) when applied during early training, it reduces training time by over 40% on CIFAR datasets and over 30% on ImageNet while maintaining comparable performance; and (iii) during fine-tuning, it significantly improves generalization through weighted averaging of model checkpoints. In summary, we present an efficient and effective framework for trainable weight averaging. The code is available at https://github.com/nblt/TWA.
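One generic way to combine parallel computation with low-bit compression of the projection matrix is sketched below, under stated assumptions; the specific scheme used by the authors may differ. Here $\boldsymbol{P}$ is split by rows across workers (simulated in a single process), each row block is stored as int8 values with one float32 scale per column, and the partial products $\boldsymbol{P}_i^\top \boldsymbol{g}_i$ are summed in place of a real all-reduce. The symmetric 8-bit scheme, the row partition, and the helper names `quantize_columns`, `dequantize_columns`, and `sharded_project` are hypothetical.

```python
import numpy as np

def quantize_columns(P):
    """Symmetric 8-bit quantization with one float32 scale per column of this block."""
    scale = (np.abs(P).max(axis=0) / 127.0).astype(np.float32)
    scale[scale == 0] = 1.0                               # guard against all-zero columns
    q = np.clip(np.round(P / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize_columns(q, scale):
    return q.astype(np.float32) * scale                   # broadcast the per-column scales

def sharded_project(shards, grad_shards):
    """Each worker holds a row block P_i of P; summing P_i^T g_i recovers P^T g."""
    partials = [dequantize_columns(q, s).T @ g for (q, s), g in zip(shards, grad_shards)]
    return np.sum(partials, axis=0)                       # stands in for an all-reduce across workers

rng = np.random.default_rng(0)
d, n, workers = 1200, 10, 4
P, _ = np.linalg.qr(rng.normal(size=(d, n)))              # toy orthonormal basis (float64)
g = rng.normal(size=d)                                    # toy full-dimensional gradient

row_blocks = np.array_split(np.arange(d), workers)        # hypothetical row partition across workers
shards = [quantize_columns(P[idx]) for idx in row_blocks]
grad_shards = [g[idx] for idx in row_blocks]

approx = sharded_project(shards, grad_shards)
exact = P.T @ g
rel_err = np.linalg.norm(approx - exact) / np.linalg.norm(exact)
quant_bytes = sum(q.nbytes + s.nbytes for q, s in shards)
print(f"relative error: {rel_err:.4f}")
print(P.nbytes, quant_bytes)   # 96000 12160
```

In this toy setting the quantized shards take roughly 8x less memory than the float64 matrix (about 4x less than float32), while the projected gradient stays close to the exact $\boldsymbol{P}^\top\boldsymbol{g}$. Splitting by rows keeps each worker's memory proportional to its share of the parameters, and only the small $n$-dimensional partial results need to be communicated.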
