Learning Equivariant Models by Discovering Symmetries with Learnable Augmentations
Eduardo Santos-Escriche, Stefanie Jegelka
TL;DR
SEMoLA addresses the challenge of learning symmetry-aware models without prior knowledge of the symmetries. It does so by jointly discovering continuous symmetries via learnable, Lie-algebra-based augmentations and encoding approximate equivariance into unconstrained predictors. It introduces a LieAugmenter, which samples group elements from a learned Lie-algebra basis, and a multi-term training objective that couples symmetry discovery with predictive accuracy and adds regularization for interpretability. Empirical results on RotatedMNIST, N-body dynamics, QM9, and CRC demonstrate robust symmetry discovery, competitive equivariance, and strong task performance, often matching or exceeding hard- or soft-equivariant baselines. The work provides a practical, interpretable framework for flexible symmetry learning, with potential applicability to diverse scientific domains and symmetry structures.
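A LieAugmenter of this kind can be sketched as follows. Take Lie-algebra generators G_1, ..., G_k and draw coefficients c. Form the algebra element sum_i c_i G_i, map it to a group element with the matrix exponential, and apply that element to the input. This is a minimal NumPy sketch under stated assumptions, not the authors' implementation. The following are all hypothetical: the names `sample_group_elements` and `augment`, the uniform coefficient distribution with a fixed `scale`, and the restriction to linear actions on point clouds. In SEMoLA the basis is learned, so in practice `generators` would be a trainable parameter in an autodiff framework. NumPy here only shows the forward pass.

```python
import numpy as np


def expm(a: np.ndarray, terms: int = 20) -> np.ndarray:
    """Matrix exponential via scaling and squaring with a truncated Taylor series."""
    norm = np.linalg.norm(a, ord=1)
    # Pick s so that ||a / 2**s||_1 <= 0.5, where the Taylor series converges fast.
    s = max(0, int(np.ceil(np.log2(norm / 0.5)))) if norm > 0 else 0
    a_scaled = a / (2 ** s)
    result = np.eye(a.shape[0])
    term = np.eye(a.shape[0])
    for k in range(1, terms + 1):
        term = term @ a_scaled / k  # a_scaled**k / k!
        result = result + term
    # Undo the scaling: exp(a) = exp(a / 2**s) ** (2**s).
    for _ in range(s):
        result = result @ result
    return result


def sample_group_elements(generators, num_samples, scale, rng):
    """Sample g = exp(sum_i c_i G_i) with c_i ~ Uniform(-scale, scale) (assumed)."""
    k = generators.shape[0]  # generators: (k, d, d) Lie-algebra basis
    coeffs = rng.uniform(-scale, scale, size=(num_samples, k))
    # tensordot contracts c (k,) with generators (k, d, d) -> algebra element (d, d)
    return np.stack([expm(np.tensordot(c, generators, axes=1)) for c in coeffs])


def augment(points, generators, scale, rng):
    """Apply one sampled group element to each point cloud in a (batch, n, d) array."""
    g = sample_group_elements(generators, points.shape[0], scale, rng)
    # out[b, n, i] = sum_j g[b, i, j] * points[b, n, j]
    return np.einsum("bij,bnj->bni", g, points)


# Usage: a single so(2) generator yields random 2D rotations.
rng = np.random.default_rng(0)
so2 = np.array([[[0.0, -1.0], [1.0, 0.0]]])
clouds = rng.normal(size=(3, 5, 2))
rotated = augment(clouds, so2, np.pi, rng)
print(np.allclose(np.linalg.norm(rotated, axis=-1), np.linalg.norm(clouds, axis=-1)))  # True
```

With the single so(2) generator, every sampled element is a 2D rotation, so augmented points keep their norms. If a learned basis converged to such a generator, that would indicate rotational symmetry in the data, which is the kind of interpretability a learned Lie-algebra basis can offer.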
Abstract
Recently, a trend has emerged that favors moving away from designing constrained equivariant architectures for data in geometric domains and instead either (1) modifying the training protocol, e.g., with a specific loss and data augmentations (soft equivariance), or (2) ignoring equivariance and inferring it only implicitly. However, both options have limitations: soft equivariance still requires a priori knowledge of the underlying symmetries, while implicitly learning equivariance from data lacks interpretability. To address these limitations, we propose SEMoLA, an end-to-end approach that jointly (1) discovers a priori unknown symmetries in the data via learnable data augmentations and (2) uses them to encode the respective approximate equivariance into arbitrary unconstrained models. Hence, it enables learning equivariant models that need no prior knowledge of symmetries, offer interpretability, and remain robust to distribution shifts. Empirically, we demonstrate that SEMoLA robustly discovers relevant symmetries while achieving high predictive performance across various datasets spanning multiple data modalities and underlying symmetry groups.
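Encoding approximate equivariance into an unconstrained model typically means adding a penalty for how far the model is from commuting with the group action, alongside the task loss. The sketch below illustrates this for a vector-valued predictor. It also adds a regularizer on the generators, loosely mirroring the multi-term objective described in the TL;DR. It is a generic soft-equivariance illustration under assumptions, not SEMoLA's actual objective. The following are all hypothetical: the functions `equivariance_penalty`, `generator_regularizer`, and `total_loss`, the orthonormality regularizer, and the weights `lam_eq` and `lam_reg`.

```python
import numpy as np


def equivariance_penalty(f, points, group_elems):
    """Mean squared violation of f(g x) = g f(x) over a batch.

    f maps an (n, d) point cloud to an (m, d) array of vectors (e.g. forces).
    points: (batch, n, d); group_elems: (batch, d, d), e.g. augmenter samples.
    """
    total = 0.0
    for x, g in zip(points, group_elems):
        lhs = f(x @ g.T)  # transform the input, then predict
        rhs = f(x) @ g.T  # predict, then transform the output
        total += np.mean((lhs - rhs) ** 2)
    return total / len(points)


def generator_regularizer(generators):
    """Hypothetical interpretability term: push generators toward an orthonormal set."""
    k = generators.shape[0]
    flat = generators.reshape(k, -1)
    gram = flat @ flat.T  # pairwise Frobenius inner products of the generators
    return np.sum((gram - np.eye(k)) ** 2)


def total_loss(task_loss, eq_penalty, reg, lam_eq=1.0, lam_reg=0.1):
    """Weighted sum of task loss, equivariance penalty, and regularizer (weights assumed)."""
    return task_loss + lam_eq * eq_penalty + lam_reg * reg


# Usage: centering is rotation-equivariant; adding a fixed offset is not.
rot90 = np.array([[0.0, -1.0], [1.0, 0.0]])
rng = np.random.default_rng(1)
clouds = rng.normal(size=(4, 6, 2))
gs = np.stack([rot90] * 4)
print(equivariance_penalty(lambda x: x - x.mean(axis=0), clouds, gs))  # ~0.0
print(equivariance_penalty(lambda x: x + np.array([1.0, 0.0]), clouds, gs))  # ~1.0
```

For an invariant task such as classification, the right-hand side would instead be f(x), without transforming the output. The orthonormality penalty is only one plausible way to keep a learned basis non-degenerate and readable. A sparsity penalty on the generator entries would be another option.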
