SeWA: Selective Weight Average via Probabilistic Masking

Peng Wang, Shengchao Hu, Zerui Tao, Guoxia Wang, Dianhai Yu, Li Shen, Quan Zheng, Dacheng Tao

TL;DR

SeWA removes the reliance on manually designed checkpoint sampling in weight averaging by adaptively selecting a small set of final-stage checkpoints through a probabilistic mask learned with the Gumbel-Softmax (GS) estimator. It reframes discrete subset selection as a continuous optimization problem and derives stability-based generalization bounds that are sharper than those of SGD in both the convex and non-convex cases; the bounds on $\epsilon_{gen}$ depend on the window size $k$, the total number of iterations $T$, and the mask sparsity $s$. In practice, the binary masks are relaxed to differentiable variables, which enables gradient-based optimization via GS sampling and Monte Carlo estimation. Empirically, SeWA matches or exceeds methods that average many more checkpoints across behavior cloning, image classification, and text classification, which indicates improved efficiency and robustness on unstable training trajectories.

Abstract

Weight averaging has become a standard technique for enhancing model performance. However, methods such as Stochastic Weight Averaging (SWA) and Latest Weight Averaging (LAWA) often require manually designed procedures to sample from the training trajectory, and the results depend heavily on hyperparameter tuning. To minimize human effort, this paper proposes a simple yet efficient algorithm called Selective Weight Averaging (SeWA), which adaptively selects checkpoints during the final stages of training for averaging. Based on SeWA, we show that only a few points are needed to achieve better generalization and faster convergence. Solving the discrete subset selection problem directly is inherently challenging. To address this, we transform it into a continuous probabilistic optimization framework and employ the Gumbel-Softmax estimator to learn the non-differentiable mask for each checkpoint. Further, we derive SeWA's stability-based generalization bounds, which are sharper than those of SGD under both convex and non-convex assumptions. Finally, extensive experiments in various domains, including behavior cloning, image classification, and text classification, further validate the effectiveness of our approach.
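The masking idea described above can be illustrated with a minimal sketch. It is not the authors' implementation: the function names, the two-class (keep/drop) form of the Gumbel-Softmax sample, the temperature `tau`, and the choice to normalize the average by the sum of the mask are all assumptions made for illustration, and the keep-probabilities are fixed here rather than learned.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sample_gumbel(rng):
    """Draw one standard Gumbel(0, 1) sample by inverse transform."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both logs finite
    return -math.log(-math.log(u))


def relaxed_mask(probs, tau, rng):
    """Relaxed binary mask: one two-class Gumbel-Softmax sample per checkpoint.

    probs[i] is the (hypothetical) keep-probability of checkpoint i.
    A softmax over the two perturbed logits equals a sigmoid of their difference.
    """
    mask = []
    for p in probs:
        p = min(max(p, 1e-12), 1.0 - 1e-12)
        keep = (math.log(p) + sample_gumbel(rng)) / tau
        drop = (math.log(1.0 - p) + sample_gumbel(rng)) / tau
        mask.append(sigmoid(keep - drop))
    return mask


def masked_average(checkpoints, mask):
    """Mask-weighted average of flat weight vectors (normalization is an assumption)."""
    total = sum(mask)
    if total <= 0.0:
        raise ValueError("mask selects no checkpoint")
    dim = len(checkpoints[0])
    return [
        sum(m * w[j] for m, w in zip(mask, checkpoints)) / total
        for j in range(dim)
    ]


# Toy usage: three "checkpoints" of a two-parameter model.
rng = random.Random(0)
checkpoints = [[1.0, 2.0], [1.2, 1.8], [0.9, 2.1]]
soft_mask = relaxed_mask([0.9, 0.1, 0.5], tau=0.5, rng=rng)
averaged = masked_average(checkpoints, soft_mask)
```

Lower values of `tau` push each mask entry toward 0 or 1, so the relaxed average approaches an average over a hard subset; a gradient-based learner would differentiate through `relaxed_mask` with respect to the keep-probabilities, which this sketch does not do.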

Paper Structure

This paper contains 28 sections, 7 theorems, 48 equations, 6 figures, 2 tables.

Key Result

Lemma 3.4

Assume that the function $F$ is $\beta$-smooth. Then: (1) (non-expansive) if $F$ is convex, for any $\alpha \leq \frac{2}{\beta}$, we have $\Vert w_{T+1}-w_{T+1}^{\prime} \Vert \leq \Vert w_{T}-w_{T}^{\prime}\Vert$; (2) ($(1+\alpha\beta)$-expansive) if $F$ is non-convex, for any $\alpha$, we have $\Vert w_{T+1}-w_{T+1}^{\prime} \Vert \leq (1+\alpha\beta)\Vert w_{T}-w_{T}^{\prime}\Vert$.
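The two cases of the lemma can be checked numerically on toy functions. The sketch below assumes a plain gradient step $w \mapsto w - \alpha \nabla F(w)$ and reads "$(1+\alpha\beta)$-expansive" in the usual sense that one step grows the distance between two iterates by at most a factor of $1+\alpha\beta$. The quadratic test functions, the curvature values, and the starting points are hypothetical choices made for illustration.

```python
import math


def gd_step(w, grad, alpha):
    """One gradient step: w - alpha * grad(w)."""
    return [wi - alpha * gi for wi, gi in zip(w, grad(w))]


def dist(u, v):
    """Euclidean distance between two vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


# Convex quadratic F(w) = 0.5 * sum_i c_i * w_i^2, which is beta-smooth
# with beta = max_i c_i.
CURV = [0.5, 2.0, 4.0]
BETA = max(CURV)


def grad_convex(w):
    return [c * wi for c, wi in zip(CURV, w)]


# Non-convex (concave) example F(w) = -0.5 * beta * ||w||^2, also beta-smooth.
def grad_concave(w):
    return [-BETA * wi for wi in w]


w, w_prime = [1.0, -2.0, 0.5], [0.0, 1.0, -1.5]
alpha = 2.0 / BETA  # largest step size allowed in case (1)

before = dist(w, w_prime)
after_convex = dist(gd_step(w, grad_convex, alpha),
                    gd_step(w_prime, grad_convex, alpha))
after_concave = dist(gd_step(w, grad_concave, alpha),
                     gd_step(w_prime, grad_concave, alpha))

# Case (1): the step does not increase the distance.
assert after_convex <= before + 1e-9
# Case (2): the distance grows by at most a factor (1 + alpha * beta).
assert after_concave <= (1.0 + alpha * BETA) * before + 1e-9
```

The concave example multiplies every coordinate by $1+\alpha\beta$, so it meets the bound of case (2) with equality up to rounding, which suggests the factor cannot be improved without further assumptions on $F$.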

Figures (6)

  • Figure 1: Comparison of SeWA with different models on convergence performance.
  • Figure 2: Comparison of different methods on the D4RL benchmark. Each data point represents the average cumulative reward across multiple tasks, averaged over 3 random seeds and 20 trajectories per seed. Detailed results are provided in the appendix.
  • Figure 3: From left to right, the figures illustrate the impact of the hyperparameter $K$ on the CIFAR-100 task. Each point corresponds to intervals of 100 checkpoints, with $K$ checkpoints selected and averaged from these intervals using different strategies.
  • Figure 4: From left to right, the figures illustrate the impact of the hyperparameter $K$ on the AG News corpus. Each point corresponds to intervals of 100 checkpoints, with $K$ checkpoints selected and averaged from these intervals using different strategies.
  • Figure 5: From left to right, the figures illustrate the impact of the hyperparameter $K$ on the CIFAR-100 task. Each data point represents performance based on intervals of 100 checkpoints, with $K$ checkpoints selected from these intervals using various strategies. The first row corresponds to a network architecture with 1 block, the second row represents a network with 3 blocks, and the third row depicts results for a network with 5 blocks.
  • ...and 1 more figure

Theorems & Definitions (17)

  • Lemma 3.4
  • Definition 3.5: $\epsilon$-Uniformly Stable
  • Theorem 3.6
  • Lemma 4.1
  • proof
  • Theorem 4.2
  • proof : Proof sketch
  • Remark 4.3
  • Remark 4.4
  • Lemma 4.5
  • ...and 7 more