Trainable Weight Averaging: Accelerating Training and Improving Generalization

Tao Li, Zhehao Huang, Yingwen Wu, Zhengbao He, Qinghua Tao, Xiaolin Huang, Chih-Jen Lin

TL;DR

Trainable Weight Averaging (TWA) introduces learnable coefficients for combining multiple model weights within a low-dimensional subspace, replacing static averaging schemes like SWA. By constructing a subspace from historical weights and projecting gradients onto this subspace, TWA achieves efficient optimization of the combining coefficients $\boldsymbol{\beta}$, forming $\boldsymbol{w}_{twa}=\boldsymbol{P}\boldsymbol{\beta}$. The approach supports two variants, TWA-t (training data) and TWA-v (validation data), with TWA-v consistently delivering superior generalization and robustness across CNNs, ViTs, transformers, and NLP models. A distributed training framework and low-bit quantization of the projection matrix enable scalable use on large models, while layer-wise processing further enhances efficiency. Empirical results show significant training acceleration (e.g., reductions of 40%+ in epochs on CIFAR and 30%+ on ImageNet) and improved generalization, especially in fine-tuning and transformer-based architectures.

Abstract

Weight averaging is a widely used technique for accelerating training and improving the generalization of deep neural networks (DNNs). Existing approaches such as stochastic weight averaging (SWA) rely on pre-set weighting schemes, which can be suboptimal when the candidate weights are diverse. We introduce Trainable Weight Averaging (TWA), a novel optimization method that operates within a reduced subspace spanned by candidate weights and learns the weighting coefficients through optimization. TWA offers greater flexibility and can be applied to different training scenarios. For large-scale applications, we develop a distributed training framework that combines parallel computation with low-bit compression for the projection matrix, effectively managing memory and computational demands. TWA can be implemented using either training data (TWA-t) or validation data (TWA-v), with the latter providing more effective averaging. Extensive experiments showcase TWA's advantages: (i) it consistently outperforms SWA in generalization performance and flexibility, (ii) when applied during early training, it reduces training time by over 40% on CIFAR datasets and 30% on ImageNet while maintaining comparable performance, and (iii) during fine-tuning, it significantly enhances generalization by weighted averaging of model checkpoints. In summary, we present an efficient and effective framework for trainable weight averaging. The code is available at https://github.com/nblt/TWA.
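
The core idea, learning coefficients $\boldsymbol{\beta}$ so that $\boldsymbol{w}_{twa}=\boldsymbol{P}\boldsymbol{\beta}$, can be illustrated with a minimal pure-Python sketch. This is not the authors' implementation (see the linked repository for that). It assumes the basis $\boldsymbol{P}$ is obtained by Gram-Schmidt orthonormalization of the candidate weights, that $\boldsymbol{\beta}$ is initialized at the uniform (SWA-style) average, and that plain gradient descent is used. The helper names (`gram_schmidt`, `combine`, `twa`) and the toy quadratic loss are hypothetical.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def gram_schmidt(vectors, eps=1e-12):
    """Orthonormalize candidate weight vectors; each returned vector is one column of P."""
    basis = []
    for v in vectors:
        r = list(v)
        for q in basis:
            c = dot(q, r)
            r = [ri - c * qi for ri, qi in zip(r, q)]
        norm = math.sqrt(dot(r, r))
        if norm > eps:  # skip candidates that are linearly dependent on earlier ones
            basis.append([ri / norm for ri in r])
    return basis

def combine(basis, beta):
    """Return w = P @ beta, where the columns of P are stored in `basis`."""
    n = len(basis[0])
    return [sum(b * q[j] for b, q in zip(beta, basis)) for j in range(n)]

def twa(checkpoints, grad_fn, lr=0.1, steps=200):
    """Optimize the combining coefficients beta inside the checkpoint subspace."""
    basis = gram_schmidt(checkpoints)
    m, n = len(checkpoints), len(checkpoints[0])
    # Assumption: start from the uniform average of the checkpoints.
    avg = [sum(c[j] for c in checkpoints) / m for j in range(n)]
    beta = [dot(q, avg) for q in basis]
    for _ in range(steps):
        g = grad_fn(combine(basis, beta))  # full-space gradient at w = P @ beta
        # Chain rule: the gradient with respect to beta is P^T g.
        beta = [b - lr * dot(q, g) for b, q in zip(beta, basis)]
    return combine(basis, beta), beta

# Hypothetical toy problem: loss(w) = 0.5 * ||w - target||^2, so grad(w) = w - target.
target = [2.0, -1.0, 5.0]
checkpoints = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
w_twa, beta = twa(checkpoints, lambda w: [wi - ti for wi, ti in zip(w, target)])
print(w_twa)
```

In this toy setting the result approaches the best point inside the span of the two checkpoints (roughly `[2, -1, 0]`), which neither checkpoint nor their uniform average reaches. The third coordinate stays at zero because it lies outside the subspace. Passing a gradient computed on training data or on held-out data to `grad_fn` would correspond loosely to the TWA-t and TWA-v variants.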

Paper Structure (41 sections, 16 equations, 3 figures, 22 tables, 1 algorithm)

Figures (3)

  • Figure 1: An efficient parallel scheme for subspace training. Suppose there are $k$ GPUs available for distributed training. We begin by uniformly partitioning $\boldsymbol{P}$ into $k$ sub-matrices, i.e., $[\boldsymbol{P}_1, \boldsymbol{P}_2, \cdots, \boldsymbol{P}_k]$, with each GPU storing one sub-matrix. First, we synchronize the local gradients $\boldsymbol{g}_i$ with an all-reduce operation to obtain the global gradient $\boldsymbol{g}$. Next, each node independently computes its local projected gradient $\boldsymbol{P}_i\boldsymbol{P}_i^\top \boldsymbol{g}$. Finally, we perform a second all-reduce operation to obtain the global projected gradient $\boldsymbol{P}\boldsymbol{P}^\top\boldsymbol{g} = \sum_{i=1}^{k}\boldsymbol{P}_i\boldsymbol{P}_i^\top\boldsymbol{g}$. In this way, the memory and computation burden is evenly distributed across all nodes.
  • Figure 2: Performance comparison with LAWA across different averaging epochs. The experiments are conducted on ImageNet using ViT-B/16.
  • Figure 3: Performance of TWA-v with different quantization bit-widths for the projection matrix.
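
The parallel scheme described in the Figure 1 caption can be simulated on a single process. The sketch below is an illustration under stated assumptions, not the paper's distributed code: workers are modeled as list entries, the first all-reduce is assumed to average the local gradients (the caption only says "synchronize"), and the columns of $\boldsymbol{P}$ are assumed to be orthonormal. The function names (`partition`, `all_reduce_sum`, `project`, `parallel_projected_gradient`) are hypothetical.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def partition(columns, k):
    """Split the columns of P into k contiguous, near-uniform shards [P_1, ..., P_k]."""
    size = -(-len(columns) // k)  # ceiling division
    return [columns[i * size:(i + 1) * size] for i in range(k)]

def all_reduce_sum(vectors):
    """Stand-in for an all-reduce: element-wise sum over the workers' vectors."""
    return [sum(col) for col in zip(*vectors)]

def project(columns, g):
    """Compute P_i P_i^T g, where `columns` holds the columns of P_i."""
    out = [0.0] * len(g)
    for q in columns:
        c = dot(q, g)
        out = [o + c * qj for o, qj in zip(out, q)]
    return out

def parallel_projected_gradient(basis, local_grads):
    k = len(local_grads)
    shards = partition(basis, k)  # worker i stores only shards[i]
    # All-reduce 1 (assumed to average): local gradients g_i -> global gradient g.
    g = [s / k for s in all_reduce_sum(local_grads)]
    # Each worker projects g onto its own shard independently.
    partial = [project(shard, g) for shard in shards]
    # All-reduce 2: summing the partial results yields P P^T g.
    return all_reduce_sum(partial)

# Hypothetical example: a 3-dimensional subspace of R^4 shared by 2 workers.
basis = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
local_grads = [[1.0, 2.0, 3.0, 4.0], [3.0, 2.0, 1.0, 0.0]]
result = parallel_projected_gradient(basis, local_grads)
print(result)
```

Because the shards partition the columns of $\boldsymbol{P}$, the sum of the per-worker projections equals the single-device projection, so the result does not depend on the number of workers. Each worker only ever stores and multiplies by its own shard.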