Training Dynamics of Learning 3D-Rotational Equivariance

Max W. Shen, Ewa Nowara, Michael Maser, Kyunghyun Cho

TL;DR

This work introduces a principled twirl/twist framework to quantify how much of a model’s loss stems from failing to be equivariant under 3D rotations, decomposing the total loss into a mean-prediction term and an equivariance term. It provides exact and unbiased finite-sample estimators for these components, and demonstrates across three high-dimensional molecular tasks that 3D-rotational equivariance is learned rapidly, with the equivariance error shrinking to a small fraction of the total loss early in training. The authors show the equivariance loss landscape is markedly smoother than the main loss, enabling quick optimization, and they connect equivariance error to gradients and parameter-space deviations, including a quadratic relationship in a decomposed parameter space. The results explain when symmetry-respecting architectures offer an advantage and highlight avenues to narrow the efficiency gap, such as architectural design and test-time twirling, with broader applicability to other symmetry groups.

Abstract

While data augmentation is widely used to train symmetry-agnostic models, it remains unclear how quickly and effectively they learn to respect symmetries. We investigate this by deriving a principled measure of equivariance error that, for convex losses, gives the percentage of total loss attributable to imperfections in learned symmetry. We focus our empirical investigation on 3D-rotation equivariance in high-dimensional molecular tasks (flow matching, force field prediction, denoising voxels) and find that models quickly reduce equivariance error to $\leq 2\%$ of held-out loss within 1k-10k training steps, a result robust to model and dataset size. This happens because learning 3D-rotational equivariance is an easier learning task, with a smoother and better-conditioned loss landscape, than the main prediction task. For 3D rotations, the loss penalty for non-equivariant models is small throughout training, so they may achieve lower test loss than equivariant models per GPU-hour unless the equivariant "efficiency gap" is narrowed. We also experimentally and theoretically investigate the relationships between relative equivariance error, learning gradients, and model parameters.

Paper Structure

This paper contains 39 sections, 14 theorems, 56 equations, 14 figures.

Key Result

Proposition 1

If $l(z,y) = \frac{1}{D} \|z-y \|^2$ is mean-squared error, then the total loss decomposes as: $\mathcal{L}(f) = \mathbb{E}_{x,y} [ l( \mu(x), y )] + \frac{1}{D} \mathbb{E}_{x,T} \left[ \| (T^{-1} \circ f \circ T)(x) - \mu(x) \|^2 \right]$.
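
The decomposition in Proposition 1 can be checked numerically. The sketch below assumes that $\mu(x)$ denotes the twirled prediction, i.e. the average of the twisted predictions $(T^{-1} \circ f \circ T)(x)$ over rotations, and that the total loss is averaged over the same rotations; neither definition is stated in this excerpt. The names `random_rotation`, `toy_model`, and `decompose_mse` are hypothetical, and the sampled rotations are treated as the entire rotation distribution, which is why the identity holds exactly here. It is a plug-in illustration for a single $(x, y)$ pair, not the paper's unbiased finite-sample estimator.

```python
import numpy as np

def random_rotation(rng):
    """Sample a 3x3 rotation matrix via QR of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))   # fix column signs so the factorization is unique
    if np.linalg.det(q) < 0:      # force det = +1 (proper rotation)
        q[:, 0] = -q[:, 0]
    return q

def toy_model(x, w):
    """Hypothetical non-equivariant map from (N, 3) coordinates to (N, 3) outputs."""
    return np.tanh(x @ w) + 0.1 * x

def decompose_mse(x, y, w, rotations):
    """Return (total loss, mean-prediction term, equivariance term) for MSE."""
    D = y.size
    # Twisted predictions T^{-1} f(T x). Points are stored as rows,
    # so T x is x @ R.T and T^{-1} z is z @ R.
    twisted = np.stack([toy_model(x @ R.T, w) @ R for R in rotations])
    mu = twisted.mean(axis=0)  # twirled prediction (assumed meaning of mu)
    total = np.mean(np.sum((twisted - y) ** 2, axis=(1, 2))) / D
    mean_term = np.sum((mu - y) ** 2) / D
    equiv_term = np.mean(np.sum((twisted - mu) ** 2, axis=(1, 2))) / D
    return total, mean_term, equiv_term

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))   # 5 points in 3D
y = rng.normal(size=(5, 3))   # arbitrary regression target
w = rng.normal(size=(3, 3))   # random weights for the toy model
rotations = [random_rotation(rng) for _ in range(64)]

total, mean_term, equiv_term = decompose_mse(x, y, w, rotations)
percent_equiv = 100.0 * equiv_term / total
print(f"total={total:.4f} mean={mean_term:.4f} "
      f"equiv={equiv_term:.4f} ({percent_equiv:.1f}% from equivariance error)")

# With w = 0 the toy model reduces to f(x) = 0.1 x, which is rotation-equivariant,
# so its equivariance term should vanish up to floating-point error.
_, _, equiv_zero = decompose_mse(x, y, np.zeros((3, 3)), rotations)
```

In this toy setting, `percent_equiv` plays the role of the percent of loss attributable to equivariance error. The cross term vanishes because `mu` is the mean of the twisted predictions, which is the same argument as a bias-variance split.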

Figures (14)

  • Figure 1: Overview of the paper. (a) Schematic of twisting and twirling, which underpin a principled measure of equivariance error. (b) Loss decomposition by Taylor expansion around the twirled prediction. (c) Loss landscapes for each loss component at early model checkpoints (step=500). (d) Architectures of three non-equivariant models studied here. (e) For MSE loss, the loss decomposition holds exactly, enabling computing the percent validation loss from equivariance error, which is plotted by training step in three settings.
  • Figure 2: Training dynamics of learning equivariance in EScAIP (Force field prediction). (a-c) Validation losses and percent validation loss from equivariance error during training, early in training (a), with log-log axes (b), and decomposed into separate terms (c). (d-f) Impact of varying training set size (d), model size (e), and optimizer or learning rate (f).
  • Figure 3: Training dynamics of learning equivariance in Proteína (Flow matching). Colors indicate flow matching time, with noise at $t=0$ and data at $t=1$. (a) Percent validation loss from equivariance error during training. (b) Bar plot of the percent validation loss from equivariance error, by flow matching time, at a final checkpoint after 1M training steps. (c-d) Validation losses by training step. (e-h) Impact of varying model size (e), training set size (f-h).
  • Figure 4: Training dynamics of learning equivariance in VoxMol (Denoising voxelized atomic densities). (a) Percent validation loss from equivariance error during training. (b-c) Validation losses by training step.
  • Figure 5: Bootstrapped estimate of the standard error of the percent loss from equivariance error, for a model whose percent loss from equivariance error is 4%.
  • ...and 9 more figures

Theorems & Definitions (27)

  • Proposition 1
  • Proposition 2
  • proof
  • Proposition 3
  • proof
  • Theorem 4
  • proof
  • Theorem 5
  • proof
  • Theorem 6
  • ...and 17 more