Improving Convergence and Generalization Using Parameter Symmetries

Bo Zhao, Robert M. Gower, Robin Walters, Rose Yu

TL;DR

This paper addresses how parameter-space symmetries in neural networks can be exploited to accelerate optimization and improve generalization by using teleportation, a loss-invariant move within a loss level set. It develops a formal framework where a symmetry group $G$ acts on parameters so that $\mathcal{L}(\mathbf{w}) = \mathcal{L}(g\cdot \mathbf{w})$, and introduces SGD with teleportation, proving orbit-wide convergence guarantees and linking a single teleportation to (damped) Newton steps under suitable conditions. It further introduces curvature-based objectives for the minimum, showing that teleportation toward higher-curvature minima can improve generalization, with empirical correlations between curvature and validation loss across MNIST, Fashion-MNIST, and CIFAR-10. Finally, it demonstrates that teleportation is broadly compatible with standard optimizers and can be incorporated via meta-learning to learn when and where to teleport, highlighting the potential of symmetry-inspired optimization to enhance convergence and generalization in deep learning.
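
As a worked illustration of the invariance $\mathcal{L}(\mathbf{w}) = \mathcal{L}(g\cdot \mathbf{w})$, consider a two-layer linear network. This example and its symbols ($W_1$, $W_2$, $X$, $Y$, $R$) are assumptions chosen for illustration and are not taken from the summary above. The loss is unchanged by the group action, but the gradient is not, which is what leaves room for a loss-preserving move to change the descent direction.

```latex
% Two-layer linear network with squared loss (illustrative assumption)
\mathcal{L}(W_1, W_2) = \tfrac{1}{2}\,\lVert Y - W_2 W_1 X \rVert_F^2,
\qquad R := W_2 W_1 X - Y.

% Action of an invertible matrix g on the parameters
g \cdot (W_1, W_2) = (g W_1,\; W_2 g^{-1})
\;\Longrightarrow\;
(W_2 g^{-1})(g W_1) = W_2 W_1
\;\Longrightarrow\;
\mathcal{L}(g \cdot \mathbf{w}) = \mathcal{L}(\mathbf{w}).

% Gradients before the action
\nabla_{W_1}\mathcal{L} = W_2^{\top} R X^{\top},
\qquad
\nabla_{W_2}\mathcal{L} = R X^{\top} W_1^{\top}.

% Gradients after the action: R is unchanged, but the gradients are rescaled by g
\nabla_{W_1}\mathcal{L}\big|_{g\cdot\mathbf{w}} = g^{-\top}\,\nabla_{W_1}\mathcal{L},
\qquad
\nabla_{W_2}\mathcal{L}\big|_{g\cdot\mathbf{w}} = \nabla_{W_2}\mathcal{L}\; g^{\top}.
```

For a scalar $g = sI$, the squared gradient norm becomes $s^{-2}\lVert\nabla_{W_1}\mathcal{L}\rVert_F^2 + s^{2}\lVert\nabla_{W_2}\mathcal{L}\rVert_F^2$, so different points on the same level set can have very different gradient magnitudes.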

Abstract

In many neural networks, different values of the parameters may result in the same loss value. Parameter space symmetries are loss-invariant transformations that change the model parameters. Teleportation applies such transformations to accelerate optimization. However, the exact mechanism behind this algorithm's success is not well understood. In this paper, we show that teleportation not only speeds up optimization in the short term, but also gives overall faster time to convergence. Additionally, teleporting to minima with different curvatures improves generalization, which suggests a connection between the curvature of the minimum and generalization ability. Finally, we show that integrating teleportation into a wide range of optimization algorithms and optimization-based meta-learning improves convergence. Our results showcase the versatility of teleportation and demonstrate the potential of incorporating symmetry in optimization.
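
The idea of a loss-invariant move can be checked numerically. The sketch below is a minimal illustration under stated assumptions, not the paper's implementation: it uses a two-layer linear network with squared loss, the action $(W_1, W_2) \mapsto (gW_1, W_2 g^{-1})$ of an invertible matrix $g$, and a crude grid search over scalar rescalings $g = sI$ in place of a proper optimization over group elements. All function names (`loss`, `grads`, `act`, `teleport`) are hypothetical.

```python
import numpy as np

def loss(W1, W2, X, Y):
    """Squared loss of a two-layer linear network (illustrative model)."""
    R = W2 @ W1 @ X - Y
    return 0.5 * np.sum(R ** 2)

def grads(W1, W2, X, Y):
    """Analytic gradients of the loss with respect to W1 and W2."""
    R = W2 @ W1 @ X - Y
    return W2.T @ R @ X.T, R @ (W1 @ X).T

def grad_norm_sq(W1, W2, X, Y):
    g1, g2 = grads(W1, W2, X, Y)
    return np.sum(g1 ** 2) + np.sum(g2 ** 2)

def act(g, W1, W2):
    """Group action: leaves the product W2 @ W1, and hence the loss, unchanged."""
    return g @ W1, W2 @ np.linalg.inv(g)

def teleport(W1, W2, X, Y, scales):
    """Pick the scalar rescaling g = s * I (from a fixed grid) that maximizes
    the squared gradient norm. A grid search is an assumption made for brevity."""
    eye = np.eye(W1.shape[0])
    best_s = max(scales, key=lambda s: grad_norm_sq(*act(s * eye, W1, W2), X, Y))
    return act(best_s * eye, W1, W2)

rng = np.random.default_rng(0)
d_in, d_hidden, d_out, n = 4, 3, 2, 16
X = rng.normal(size=(d_in, n))
Y = rng.normal(size=(d_out, n))
W1 = rng.normal(size=(d_hidden, d_in))
W2 = rng.normal(size=(d_out, d_hidden))

# Invariance under a generic (well-conditioned) invertible matrix.
g = rng.normal(size=(d_hidden, d_hidden)) + 3.0 * np.eye(d_hidden)
loss_before = loss(W1, W2, X, Y)
loss_random_g = loss(*act(g, W1, W2), X, Y)

# Teleport along the scaling subgroup; the grid contains s = 1 (the identity),
# so the selected point cannot have a smaller gradient norm than the start.
T1, T2 = teleport(W1, W2, X, Y, scales=[0.5, 1.0, 2.0])
loss_after = loss(T1, T2, X, Y)
gn_before = grad_norm_sq(W1, W2, X, Y)
gn_after = grad_norm_sq(T1, T2, X, Y)

print("loss before / after teleport:", loss_before, loss_after)
print("squared grad norm before / after:", gn_before, gn_after)
```

The loss agrees before and after the move up to floating-point error, while the gradient norm is free to change; a subsequent gradient step would then start from the teleported parameters `T1`, `T2`.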

Paper Structure

This paper contains 42 sections, 11 theorems, 96 equations, 13 figures, 1 table, 2 algorithms.

Key Result

Theorem 3.1

(Smooth non-convex) Let $\mathcal{L}({\bm{w}}, \xi)$ be $\beta$-smooth and let [displayed definition missing in source]. Consider the iterates ${\bm{w}}^t$ given by equation (eq:stochgradg), where [displayed condition missing in source], which we assume exists. If $\eta = \frac{1}{\beta \sqrt{T-1}}$, then [displayed bound missing in source], where the expectation is the total expectation with respect to the data $\xi^t$ for $t = 0, \ldots, T-1$.

Figures (13)

  • Figure 1: With teleportation, SGD converges to a basin where all points on the level set are stationary points.
  • Figure 2: Gradient flow ($\mathcal{L}({\bm{w}})$) and a curve on the minimum ($\gamma$). The curvature of both curves may affect generalization.
  • Figure 3: Illustration of the effect of sharpness (a,b) and curvature (c,d) of minima on generalization. See Figure 2 for a 3D visualization of the curves $\mathcal{L}({\bm{w}})$ and $\gamma$. When the loss landscape shifts due to a change in data distribution, sharper minima have a larger increase in loss. In the example shown, minima with larger curvature move farther away from the shifted minima.
  • Figure 4: Changing sharpness (left) or curvature (right) using teleportation and its effect on generalization on CIFAR-10. Solid lines represent the average test loss, and dashed lines represent the average training loss. Teleporting to decrease sharpness improves validation loss slightly. Teleportation that changes curvature has a more significant impact on generalization ability.
  • Figure 5: Integrating teleportation with AdaGrad, momentum, RMSProp, and Adam improves the convergence rate on MNIST. Solid line represents the average test loss, and dashed line represents the average training loss. Shaded areas are 1 standard deviation of the test loss across 5 runs.
  • ...and 8 more figures

Theorems & Definitions (23)

  • Theorem 3.1
  • Proposition 3.2: Quadratic term in convergence rate
  • Definition 3.3
  • Proposition 3.4
  • Proposition 3.5
  • Proposition 3.6
  • Lemma A.1: Descent Lemma
  • proof
  • proof
  • Proposition A.2
  • ...and 13 more