Breaking Data Symmetry is Needed For Generalization in Feature Learning Kernels

Marcel Tomàs Bernal, Neil Rohit Mallinar, Mikhail Belkin

Abstract

Grokking occurs when a model achieves high training accuracy but generalizes to unseen test points only long afterward. This phenomenon was initially observed on a class of algebraic problems, such as learning modular arithmetic (Power et al., 2022). We study grokking on algebraic tasks in a class of feature learning kernels via the Recursive Feature Machine (RFM) algorithm (Radhakrishnan et al., 2024), which iteratively updates feature matrices through the Average Gradient Outer Product (AGOP) of an estimator in order to learn task-relevant features. Our main experimental finding is that generalization occurs only when a certain symmetry in the training set is broken. Furthermore, we empirically show that RFM generalizes by recovering the underlying invariance group action inherent in the data. We find that the learned feature matrices encode specific elements of the invariance group, explaining the dependence of generalization on symmetry.
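
To make the iteration concrete, here is a minimal NumPy sketch of an RFM-style loop, assuming a kernel ridge regression estimator with a Mahalanobis Gaussian kernel $K_M(x,z)=\exp(-(x-z)^\top M (x-z)/(2L^2))$, a scalar target, and the plain update $M \leftarrow \mathrm{AGOP}$. The function names, bandwidth $L$, and ridge parameter are illustrative choices for this sketch, not the authors' implementation.

    # Illustrative RFM-style sketch (not the authors' code).
    import numpy as np

    def sq_dists_M(X, Z, M):
        # Mahalanobis squared distances (x - z)^T M (x - z) for all pairs.
        XM, ZM = X @ M, Z @ M
        return (np.sum(XM * X, axis=1)[:, None]
                - 2.0 * XM @ Z.T
                + np.sum(ZM * Z, axis=1)[None, :])

    def gaussian_M_kernel(X, Z, M, L=1.0):
        # K_M(x, z) = exp(-(x - z)^T M (x - z) / (2 L^2)).
        return np.exp(-sq_dists_M(X, Z, M) / (2.0 * L**2))

    def rfm(X, y, n_iters=10, reg=1e-3, L=1.0):
        # Alternate kernel ridge regression with feature matrix M and an AGOP update of M.
        n, d = X.shape
        M = np.eye(d)                                        # initial feature matrix
        for _ in range(n_iters):
            K = gaussian_M_kernel(X, X, M, L)
            alpha = np.linalg.solve(K + reg * np.eye(n), y)  # ridge solve
            # Gradient of f(x) = sum_j alpha_j K_M(x, x_j) at each training point:
            #   grad f(x_i) = -(1/L^2) M sum_j alpha_j K_M(x_i, x_j) (x_i - x_j).
            diffs = X[:, None, :] - X[None, :, :]            # (n, n, d)
            S = np.einsum('ij,j,ijd->id', K, alpha, diffs)   # (n, d)
            grads = -(S @ M) / L**2
            # AGOP (Definition 2.1): average outer product of the gradients,
            # used as the feature matrix in the next iteration.
            M = grads.T @ grads / n
        return M, alpha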

Paper Structure

This paper contains 47 sections, 6 theorems, 43 equations, 16 figures, and 1 table.

Key Result

Proposition 3.2

The symmetry group of modular addition $f(a,b)=a+b \,\mathrm{mod}\, p$ is isomorphic to the dihedral group $D_{2p}$, with a rotation $r(a,b)=(a+1,b-1)$ and a reflection $s(a,b)=(b,a)$. The group has $2p$ elements, which can be classified into rotations $r^k$ and reflections $sr^k$ for $k \in \mathbb{Z}_p$.
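
The following is a short, purely illustrative check of this structure (not from the paper): the maps $r$ and $s$ below preserve $f$, satisfy the dihedral relation $srs=r^{-1}$, and the fixed points of a reflection $sr^k$ are exactly the pairs $(a,\,a+k \bmod p)$, the kind of points withheld in the experiments reported in the figures.

    # Illustrative check of the dihedral symmetry of addition mod p.
    p = 53

    def f(a, b):                      # the target function: addition mod p
        return (a + b) % p

    def r(a, b):                      # rotation r: (a, b) -> (a + 1, b - 1)
        return ((a + 1) % p, (b - 1) % p)

    def s(a, b):                      # reflection s: (a, b) -> (b, a)
        return (b, a)

    pairs = [(a, b) for a in range(p) for b in range(p)]
    assert all(f(*r(a, b)) == f(a, b) for a, b in pairs)    # r preserves f
    assert all(f(*s(a, b)) == f(a, b) for a, b in pairs)    # s preserves f

    # Dihedral relation s r s = r^{-1}: composing s r s with r gives the identity.
    srs_then_r = lambda a, b: r(*s(*r(*s(a, b))))
    assert all(srs_then_r(a, b) == (a, b) for a, b in pairs)

    # Fixed points of the reflection s r^k are exactly the pairs (a, a + k) mod p.
    k = 5
    srk = lambda a, b: s(*((a + k) % p, (b - k) % p))
    fixed = [(a, b) for a, b in pairs if srk(a, b) == (a, b)]
    assert fixed == [(a, (a + k) % p) for a in range(p)]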

Figures (16)

  • Figure 1: For addition modulo $p=53$, we train Gaussian (top row) and quadratic (bottom row) kernels for 60 iterations without the fixed points under the reflection $s$, i.e., the points of the form $(a,a)$, and move random points from train to test, which enables generalization. We also move points from train to test in symmetric pairs under the reflection $s$, which does not help with generalization. The loss is normalized so that its highest value is 1 and is shown to illustrate its evolution. Curves are averaged over 5 independent runs.
  • Figure 2: RFM with a Gaussian kernel on modular arithmetic mod $p = 29$ trained for 60 iterations. The first row shows the features learned when trained on a random partition for modular addition, subtraction, multiplication and division. We compare the AGOP learned by RFM after withholding the fixed points under a given reflection (second row) to the permutation representation of said reflection (third row); a sketch of how such a permutation representation can be constructed is given after this figure list.
  • Figure 3: AGOPs learned by RFM with a Gaussian kernel trained on addition mod 32 on all data samples except for the fixed points of all the reflections $sr^k$ of the dihedral subgroups, in order from left to right: $H=\langle s\rangle$, $H=\langle r^{16},s\rangle$, $H=\langle r^{8},s\rangle$, $H=\langle r^{4},s\rangle$. RFM doesn't generalize to the withheld points in any of these settings.
  • Figure 4: We train RFM with a Gaussian kernel on addition modulo $p=29$ on 50% of the data, with $M_0$ encoding the subgroup $H=\{\text{id},s\}$ (top row) and $H=\{\text{id},sr^{10}\}$ (bottom row). The reflection axis is marked by the black line. There is a perfect match between the correct predictions and our theoretical prediction (Finding \ref{claim:generalization}).
  • Figure 5: (A) For addition with reflection $s$ (left) and multiplication with reflection $sr^{13}$ (right) mod 53, we train RFM with a Gaussian kernel for 60 iterations without the fixed points under that reflection, and move fixed points from test to train. This doesn't recover generalization. (B) AGOPs learned by RFM with a Gaussian kernel on addition modulo 61 on the full dataset except two random samples. In the first image, both withheld points are fixed under the same reflection $s$ (non-generalizing partition), while in the rest the withheld points are fixed under different reflections (generalizing partition).
  • ...and 11 more figures
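
Figures 2 and 3 compare the learned AGOP to the permutation representation of a reflection $sr^k$. Below is a minimal sketch of how such a representation could be formed, under the assumption (made for this illustration only, and not necessarily the paper's exact setup) that a pair $(a,b)$ is encoded as the concatenation of two one-hot vectors in $\mathbb{R}^{2p}$; since $sr^k$ maps $(a,b)$ to $(b-k,\,a+k)$, its matrix consists of two cyclic shifts placed in the off-diagonal blocks.

    # Illustrative construction, assuming a concatenated one-hot encoding of (a, b).
    import numpy as np

    def shift_matrix(p, k):
        # S_k with S_k e_a = e_{(a + k) mod p}: cyclic shift by k positions.
        return np.roll(np.eye(p), k, axis=0)

    def perm_rep_reflection(p, k):
        # Permutation representation of s r^k on the encoding (e_a, e_b) in R^{2p}.
        # s r^k sends (a, b) to (b - k, a + k), so the first output block reads the
        # second input block shifted by -k, and the second reads the first shifted by +k.
        Z = np.zeros((p, p))
        return np.block([[Z, shift_matrix(p, -k)],
                         [shift_matrix(p, k), Z]])

    # Sanity check against the action on a sample pair (a, b).
    p, k, a, b = 29, 10, 3, 17
    x = np.concatenate([np.eye(p)[a], np.eye(p)[b]])
    y = perm_rep_reflection(p, k) @ x
    assert y[:p].argmax() == (b - k) % p and y[p:].argmax() == (a + k) % p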

Theorems & Definitions (11)

  • Definition 2.1: Average Gradient Outer Product (AGOP)
  • Proposition 3.2
  • Proposition 3.3
  • Lemma A.1
  • proof
  • Theorem A.2
  • proof
  • Proposition A.3
  • proof
  • Proposition C.1
  • ...and 1 more