There Are Many Consistent Explanations of Unlabeled Data: Why You Should Average

Ben Athiwaratkun, Marc Finzi, Pavel Izmailov, Andrew Gordon Wilson

TL;DR

The paper analyzes why consistency-regularized semi-supervised learning benefits from averaging multiple plausible solutions rather than relying on a single SGD solution. It shows that the consistency loss encourages flatter regions by implicitly penalizing the Jacobian and Hessian, and that SGD keeps traversing a broad, diverse set of models late in training instead of converging. By applying Stochastic Weight Averaging (SWA) and introducing fast-SWA, the authors achieve the best semi-supervised results reported at the time on CIFAR-10 and CIFAR-100 across varying amounts of labeled data, and also improve a domain-adaptation baseline. Because the averaged weights define a single network, inference costs the same as for one model, unlike a prediction ensemble, and the results indicate that exploiting the diversity along the SGD trajectory is an effective way to improve semi-supervised learning.

Abstract

Presently the most successful approaches to semi-supervised learning are based on consistency regularization, whereby a model is trained to be robust to small perturbations of its inputs and parameters. To understand consistency regularization, we conceptually explore how loss geometry interacts with training procedures. The consistency loss dramatically improves generalization performance over supervised-only training; however, we show that SGD struggles to converge on the consistency loss and continues to make large steps that lead to changes in predictions on the test data. Motivated by these observations, we propose to train consistency-based methods with Stochastic Weight Averaging (SWA), a recent approach which averages weights along the trajectory of SGD with a modified learning rate schedule. We also propose fast-SWA, which further accelerates convergence by averaging multiple points within each cycle of a cyclical learning rate schedule. With weight averaging, we achieve the best known semi-supervised results on CIFAR-10 and CIFAR-100, over many different quantities of labeled training data. For example, we achieve 5.0% error on CIFAR-10 with only 4000 labels, compared to the previous best result in the literature of 6.3%.
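The difference between the two averaging strategies described in the abstract can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation: the function names `averaging_epochs` and `average_weights`, the toy two-parameter "weights", and the specific values for the start epoch, cycle length, and snapshot interval are all assumptions made for the example. The only property taken from the abstract is that SWA averages points along the SGD trajectory, while fast-SWA averages multiple points within each learning-rate cycle.

```python
from typing import Dict, List, Sequence


def averaging_epochs(start: int, total: int, every: int) -> List[int]:
    """Return the epochs at which a weight snapshot joins the average.

    `every` equal to the cycle length gives one snapshot per cycle
    (SWA-style); a smaller `every` gives several snapshots per cycle
    (fast-SWA-style). All three parameters are illustrative assumptions.
    """
    return [e for e in range(start, total + 1) if (e - start) % every == 0]


def average_weights(
    snapshots: Dict[int, Sequence[float]], epochs: Sequence[int]
) -> List[float]:
    """Equal-weight running mean of the snapshots taken at `epochs`."""
    avg: List[float] = []
    for n, epoch in enumerate(epochs):
        weights = snapshots[epoch]
        if n == 0:
            avg = list(weights)
        else:
            # avg currently covers n snapshots; fold in one more.
            avg = [(a * n + w) / (n + 1) for a, w in zip(avg, weights)]
    return avg


# Toy "trajectory": a hypothetical 2-parameter model saved after each epoch.
snapshots = {e: [float(e), -0.5 * e] for e in range(1, 21)}

cycle_len = 5  # assumed length of one learning-rate cycle, in epochs
swa_epochs = averaging_epochs(start=10, total=20, every=cycle_len)
fast_epochs = averaging_epochs(start=10, total=20, every=1)

swa_weights = average_weights(snapshots, swa_epochs)
fast_weights = average_weights(snapshots, fast_epochs)
```

With these assumed settings the SWA-style schedule collects three snapshots over the same span in which the fast-SWA-style schedule collects eleven, which is the sense in which averaging within each cycle accumulates points faster. The running-mean update avoids storing every snapshot: only the current average and a count are needed.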

Paper Structure

This paper contains 42 sections, 17 equations, 15 figures, 7 tables.

Figures (15)

  • Figure 1: (a): The evolution of the gradient norm for the consistency regularization term (Cons) and the cross-entropy term (CE) in the $\Pi$, MT, and standard supervised (CE only) models during training. (b): Train and test errors along rays connecting two SGD solutions for each respective model. (c) and (d): Comparison of errors along rays connecting two SGD solutions, random rays, and adversarial rays for the $\Pi$ and supervised models. See Section \ref{sec:sup_rayplots} for the analogous plot for the Mean Teacher model.
  • Figure 6: (a): Illustration of a convex and non-convex function and Jensen's inequality. (b): Scatter plot of the decrease in error $C_{\text{avg}}$ for weight averaging versus distance. (c): Scatter plot of the decrease in error $C_{\text{ens}}$ for prediction ensembling versus diversity. (d): Train error surface (orange) and test error surface (blue). The SGD solutions (red dots) around a locally flat minimum are far apart due to the flatness of the train surface (see Figure \ref{fig:sgd_sgd_rays}), which leads to a large error reduction for the SWA solution (blue dot).
  • Figure 11: Left: Cyclical cosine learning rate schedule and SWA and fast-SWA averaging strategies. Middle: Illustration of the solutions explored by the cyclical cosine annealing schedule on an error surface. Right: Illustration of SWA and fast-SWA averaging strategies. fast-SWA averages more points but the errors of the averaged points, as indicated by the heat color, are higher.