Norm-Hierarchy Transitions in Representation Learning: When and Why Neural Networks Abandon Shortcuts

Truong Xuan Khanh, Truong Quynh Hoa

TL;DR

This paper introduces the Norm-Hierarchy Transition (NHT) framework, which explains delayed representation learning as the slow traversal of a hierarchy of parameter norms during regularized optimization, and derives a tight bound showing that the transition delay grows logarithmically with the ratio between shortcut and structured norms.

Abstract

Neural networks often rely on spurious shortcuts for many epochs before discovering structured representations. However, the mechanism governing when this transition occurs, and whether its timing can be predicted, remains unclear. Prior work shows that gradient descent converges to low-norm solutions and that neural networks exhibit simplicity bias, but neither explains the timescale of the transition from shortcut features to structured representations. We introduce the Norm-Hierarchy Transition (NHT) framework, which explains delayed representation learning as the slow traversal of a hierarchy of parameter norms during regularized optimization. When multiple interpolating solutions exist with different norms, weight decay gradually moves the model from high-norm shortcut solutions toward lower-norm structured representations. We derive a tight bound showing that the transition delay grows logarithmically with the ratio between shortcut and structured norms. Experiments on modular arithmetic, CIFAR-10 with spurious features, CelebA, and Waterbirds support the predictions of the framework. The results suggest that grokking, shortcut learning, and delayed feature discovery arise from a common mechanism based on norm-hierarchy traversal during training.

Paper Structure (77 sections, 6 theorems, 14 equations, 5 figures, 3 tables)

Key Result

Theorem 3.1

Under (A1)--(A3), for $\theta_t \in \mathcal{M}_\mathrm{sc}$, let $V_t = \|\theta_t\|^2$. The escape time from the shortcut manifold satisfies $T_\mathrm{escape} = \Theta\bigl(\gamma_\mathrm{eff}^{-1} \log(V_\mathrm{sc}/V_\mathrm{st})\bigr)$, where $\eta$ is the learning rate, $\lambda$ the weight-decay coefficient, $\gamma_\mathrm{eff} = \eta\lambda$ for SGD, and $\gamma_\mathrm{eff} \geq \eta\lambda$ for AdamW.
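The logarithmic scaling in Theorem 3.1 can be illustrated with a toy model. The sketch below assumes, purely for illustration, that the loss gradient vanishes on the shortcut manifold, so each SGD step only applies weight decay, $\theta \leftarrow (1-\eta\lambda)\theta$, giving $V_t = V_0(1-\eta\lambda)^{2t}$. Solving $V_T = V_\mathrm{st}$ from $V_0 = V_\mathrm{sc}$ gives the step count; for small $\eta\lambda$ it reduces to $\log(V_\mathrm{sc}/V_\mathrm{st})/(2\eta\lambda)$, one instance of the $\Theta(\gamma_\mathrm{eff}^{-1}\log(V_\mathrm{sc}/V_\mathrm{st}))$ form with the hidden constant fixed by the toy dynamics. The function names are hypothetical, and this is not the paper's proof.

```python
import math


def escape_steps_exact(v_sc: float, v_st: float, eta: float, lam: float) -> float:
    """Steps for V = ||theta||^2 to fall from v_sc to v_st under pure weight decay.

    Toy assumption (not from the paper): each SGD step is
    theta <- (1 - eta * lam) * theta, so V_t = V_0 * (1 - eta * lam) ** (2 * t).
    """
    if not 0.0 < eta * lam < 1.0:
        raise ValueError("requires 0 < eta * lam < 1")
    if not v_sc > v_st > 0.0:
        raise ValueError("requires v_sc > v_st > 0")
    # Solve (1 - eta*lam)^(2T) = v_st / v_sc for T.
    return math.log(v_sc / v_st) / (-2.0 * math.log1p(-eta * lam))


def escape_steps_approx(v_sc: float, v_st: float, gamma_eff: float) -> float:
    """Small-step approximation T ~ log(V_sc / V_st) / (2 * gamma_eff)."""
    return math.log(v_sc / v_st) / (2.0 * gamma_eff)


if __name__ == "__main__":
    eta, lam = 0.1, 0.01  # gamma_eff = eta * lam = 1e-3 for SGD
    for ratio in (10.0, 100.0, 1000.0):
        exact = escape_steps_exact(ratio, 1.0, eta, lam)
        approx = escape_steps_approx(ratio, 1.0, eta * lam)
        print(f"V_sc/V_st = {ratio:>6.0f}: exact {exact:8.1f} steps, approx {approx:8.1f}")
```

In this toy model, multiplying the norm ratio by 10 adds a roughly constant number of steps (about $\log(10)/(2\eta\lambda)$) rather than multiplying the delay, while halving $\gamma_\mathrm{eff}$ doubles it.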

Figures (5)

  • Figure 1: Three-regime structure under a weight-decay sweep (CIFAR-10, $\rho=0.95$, 7 values of $\lambda$). Left: $\|\theta\|^2$ trajectory over 200 epochs. Weak $\lambda$ ($0.001$--$0.01$): norm grows monotonically to 5,000--14,000 with ${<}1\%$ decay, indicating persistent shortcut reliance. Intermediate $\lambda$ ($0.05$--$0.3$): norm peaks, then decays by up to $21.6\%$, the signature of a delayed NHT. Strong $\lambda$ ($0.5$--$1.0$): norm suppressed from epoch 1, decaying $42$--$64\%$, indicating learning suppression. Right: Final clean accuracy vs. $\lambda$ (circles) and norm-decay percentage (crosses). Clean accuracy peaks in the intermediate regime (${\approx}58\%$), confirming that the delayed transition corresponds to real-feature acquisition. Error bars: $\pm1$ std over 4 seeds.
  • Figure 2: Effect of spurious correlation strength $\rho$ on transition outcome (CIFAR-10, $\lambda=0.1$). Each point shows mean clean accuracy at epoch 200 across 3 seeds. Monotonically decreasing accuracy with increasing $\rho$ confirms the NHT prediction: stronger shortcuts produce larger norm gaps ($V_\mathrm{sc}/V_\mathrm{st}$), delaying the transition and ultimately preventing it entirely at $\rho=1.0$ (shortcut reliance $=1.0$, clean accuracy $=10.2\%$). The shaded band marks the intermediate-regime operating point ($\rho=0.95$) used in Experiments A and C.
  • Figure 3: Representation phase diagram over the $(\lambda, \rho)$ plane (CIFAR-10, 28 runs). Colour encodes final clean accuracy; contour lines separate the four predicted regimes. Shortcut-dominated (bottom-right, high $\rho$, low $\lambda$): network stays on $\mathcal{M}_\mathrm{sc}$; clean accuracy $\leq 15\%$. NHT regime (centre): delayed transition produces clean accuracy $55$--$78\%$. Structured (top-left, low $\rho$, moderate $\lambda$): network reaches $\mathcal{M}_\mathrm{st}$ directly; clean accuracy $\geq 75\%$. Suppressed (top-right, high $\lambda$): weight decay overwhelms learning; accuracy $\leq 44\%$. The phase boundary between shortcut-dominated and NHT regimes sharpens with increasing $\rho$, consistent with the norm-gap prediction $\Delta V \propto \rho$.
  • Figure 4: Layer-wise norm dynamics reveal a backward representational transition (CIFAR-10, $\lambda=0.1$, $\rho=0.95$). (a) Per-layer $\|\theta^{(\ell)}\|^2$ normalised to peak value. The classification head (fc, purple) reaches its peak at epoch ${\approx}60$ and contracts $45\%$ by epoch 200; conv1 (blue) peaks later and contracts only $31\%$; intermediate layers fall between. (b) Total norm $\|\theta\|^2$ (black) grows throughout, masking the layer-level contraction in head layers. (c) Ratio $\|\theta^{(L)}\|^2 / \|\theta^{(1)}\|^2$ decreases monotonically after epoch 60, providing a layer-ratio diagnostic that detects the transition when total-norm monitoring would fail. Shaded band: $\pm1$ std over 4 seeds. These results directly confirm Proposition \ref{prop:layerwise}: layers with higher shortcut-encoding capacity $\alpha_\ell$ escape the shortcut manifold faster, producing a backward transition from output to input.
  • Figure 5: CelebA norm-hierarchy validation across three regularisation regimes (6 runs: 3 $\lambda$ values $\times$ 2 seeds). (a) Parameter norm $\|\theta\|^2$: the three regimes are clearly separated. Weak $\lambda=0.001$ (blue) grows monotonically to ${\approx}20{,}000$; intermediate $\lambda=0.1$ (orange) plateaus near $3{,}000$ with non-monotone dynamics (5/20 intervals show norm decrease); strong $\lambda=1.0$ (red) is suppressed below $1{,}000$ from epoch 1. Norm ratio between weak and strong regimes: $37\times$, confirming P1. (b) Worst-group accuracy: all three regimes show stable performance above $86\%$ after epoch 20, with the predicted monotone ordering $89.1\% > 88.0\% > 87.0\%$ (P3 confirmed). The absence of a sharp accuracy jump at intermediate $\lambda$ is consistent with Scenario C (no clean norm separation, $S = -0.11$): there is no delayed transition to observe. (c) Average accuracy: similarly ordered ($90.9\% > 90.5\% > 90.2\%$), with larger variance under strong regularisation (std $= 1.2\%$ vs. $0.1\%$). The fc/conv$_1$ norm ratio increases from $0.65\times$ to $2.33\times$ as $\lambda$ increases (P4 confirmed, not shown), providing the first empirical confirmation of Proposition \ref{prop:layerwise} on a real face-attribute dataset.
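The two trajectory diagnostics used in Figures 1 and 4, the norm-decay percentage and the head-to-first-layer norm ratio, can be computed from per-epoch norm logs. The sketch below is a minimal illustration under stated assumptions: all function names are hypothetical, norms are assumed to be logged once per epoch as plain floats, decay is measured from the trajectory peak to the final epoch, and the monotonicity check is a simple non-increasing test rather than the paper's detection criterion. The layer keys `fc` and `conv1` follow the captions' layer names.

```python
from typing import Dict, List, Sequence


def squared_norm(params: Sequence[Sequence[float]]) -> float:
    """||theta||^2 summed over flattened parameter tensors."""
    return sum(x * x for tensor in params for x in tensor)


def peak_epoch(norms: Sequence[float]) -> int:
    """Index of the first epoch at which the norm trajectory reaches its maximum."""
    return max(range(len(norms)), key=lambda i: norms[i])


def norm_decay_pct(norms: Sequence[float]) -> float:
    """Percentage drop from the peak of a norm trajectory to its final value."""
    peak = norms[peak_epoch(norms)]
    return 100.0 * (peak - norms[-1]) / peak


def layer_ratio(per_layer: Dict[str, Sequence[float]], head: str, first: str) -> List[float]:
    """Per-epoch ratio ||theta^(L)||^2 / ||theta^(1)||^2 (head layer over first layer)."""
    return [h / f for h, f in zip(per_layer[head], per_layer[first])]


def decreasing_after(values: Sequence[float], start: int) -> bool:
    """True if values are non-increasing from index `start` onward."""
    tail = list(values[start:])
    return all(later <= earlier for earlier, later in zip(tail, tail[1:]))


if __name__ == "__main__":
    # Hypothetical per-epoch logs, for illustration only.
    logs = {"fc": [2.0, 4.0, 6.0, 3.0], "conv1": [1.0, 2.0, 2.0, 2.0]}
    ratio = layer_ratio(logs, "fc", "conv1")
    print(ratio, decreasing_after(ratio, peak_epoch(ratio)))
```

A check on the total norm alone would miss the signal in Figure 4(b), where the total keeps growing; the ratio reuses per-layer norms that are already summed to obtain the total, so it adds little logging cost.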
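Figure 5 reports both worst-group and average accuracy. A minimal sketch of the two metrics follows, assuming (as is common for CelebA and Waterbirds benchmarks, though not stated in this summary) that each example's group is the pair (class label, spurious attribute). The function names are hypothetical.

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence, Tuple


def group_accuracies(preds: Sequence[int], labels: Sequence[int],
                     groups: Sequence[Hashable]) -> Dict[Hashable, float]:
    """Accuracy computed separately within each group."""
    correct: Dict[Hashable, int] = defaultdict(int)
    total: Dict[Hashable, int] = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    return {g: correct[g] / total[g] for g in total}


def worst_group_accuracy(preds: Sequence[int], labels: Sequence[int],
                         groups: Sequence[Hashable]) -> Tuple[Hashable, float]:
    """Group with the lowest accuracy, together with that accuracy."""
    accs = group_accuracies(preds, labels, groups)
    worst = min(accs, key=accs.__getitem__)
    return worst, accs[worst]


def average_accuracy(preds: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of all examples classified correctly, ignoring groups."""
    return sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)


if __name__ == "__main__":
    # Hypothetical toy predictions; groups are (label, spurious attribute) pairs.
    preds, labels = [1, 1, 0, 0], [1, 0, 0, 0]
    groups = [(1, 0), (0, 1), (0, 0), (0, 0)]
    print(worst_group_accuracy(preds, labels, groups), average_accuracy(preds, labels))
```

Because worst-group accuracy is a minimum over groups, it can sit well below average accuracy when a small group contradicts the shortcut, which is why the two are reported separately.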

Theorems & Definitions (21)

  • Definition 2.1: Multi-Representation Interpolation
  • Definition 2.2: Norm Hierarchy
  • Remark 2.3: Why shortcut solutions tend to have larger norm
  • Definition 2.4: Shortcut Accessibility
  • Remark 2.5: On Shortcut Accessibility
  • Remark 2.6: Instantiations
  • Theorem 3.1: Generalised Escape under Regularisation
  • Proof
  • Remark 3.2: Intuition
  • Remark 3.3: Robustness to idealised assumptions
  • ...and 11 more