Composing Global Solutions to Reasoning Tasks via Algebraic Objects in Neural Nets

Yuandong Tian

TL;DR

The paper studies 2-layer neural networks with quadratic activation trained to predict Abelian-group products under the $L_2$ loss, and uncovers a semi-ring structure in the weight space together with sum potentials (SPs) that are ring homomorphisms. This enables analytical construction of global solutions from partial ones, yielding Fourier-based global solutions of per-frequency order $4$ (i.e., $2\times2$) and order $6$ (i.e., $2\times3$), as well as a global perfect memorization solution of order $d^2$. Empirically, about 95% of gradient-descent solutions align with the theory and can be factorized into the proposed components; overparameterization helps training by decoupling the SP dynamics, while weight decay biases training toward simpler, low-order solutions. The framework suggests a paradigm shift from gradient-based optimization to loss decomposition and algebraic composition, with implications for reasoning tasks and broader group-action settings.

Abstract

We prove rich algebraic structures of the solution space for 2-layer neural networks with quadratic activation and $L_2$ loss, trained on reasoning tasks in Abelian groups (e.g., modular addition). Such a rich structure enables \emph{analytical} construction of globally optimal solutions from partial solutions that only satisfy part of the loss, despite its high nonlinearity. We call the framework CoGS (\emph{\underline{Co}mposing \underline{G}lobal \underline{S}olutions}). Specifically, we show that the weight space over different numbers of hidden nodes of the 2-layer network is equipped with a semi-ring algebraic structure, and that the loss function to be optimized consists of \emph{sum potentials}, which are ring homomorphisms, allowing partial solutions to be composed into global ones by ring addition and multiplication. Our experiments show that around $95\%$ of the solutions obtained by gradient descent exactly match our theoretical constructions. Although the constructed global solutions require only a small number of hidden nodes, our analysis of gradient dynamics shows that overparameterization asymptotically decouples the training dynamics and is beneficial. We further show that training dynamics favor simpler solutions under weight decay, so high-order global solutions such as perfect memorization are unfavorable. The code is open sourced at https://github.com/facebookresearch/luckmatters/tree/yuandong3/ssl/real-dataset.
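The composition idea can be illustrated with a small numerical sketch. It assumes a hypothetical layout in which the weights are stored as Fourier-domain coefficients `z[p][k][j]` (slot `p` in $\{a, b, c\}$, frequency `k`, hidden node `j`), takes ring addition to be concatenation of hidden nodes, and takes ring multiplication to be a per-$(p, k)$ Kronecker product over hidden nodes. These operations are assumptions chosen so that orders add under addition and multiply under multiplication (e.g., $2\times3 = 6$); they are not the paper's code. The sketch checks that a sum potential $r_{k_1k_2k}$ then behaves as a ring homomorphism.

```python
import random

# Hypothetical layout (assumption): z[p][k] is the list of complex weights over
# hidden nodes j for slot p in {"a", "b", "c"} and frequency k in range(d).
SLOTS = ("a", "b", "c")

def order(z):
    """Order of z = number of hidden nodes."""
    return len(z["a"][0])

def ring_add(z1, z2):
    """Assumed ring addition: concatenate hidden nodes (order q1 + q2)."""
    return {p: [u + v for u, v in zip(z1[p], z2[p])] for p in SLOTS}

def ring_mul(z1, z2):
    """Assumed ring multiplication: Kronecker product over hidden nodes, per (p, k)."""
    return {p: [[x * y for x in u for y in v] for u, v in zip(z1[p], z2[p])]
            for p in SLOTS}

def sp(z, k1, k2, k):
    """Sum potential r_{k1 k2 k}(z) = sum_j z_{a k1 j} z_{b k2 j} z_{c k j}."""
    return sum(x * y * t for x, y, t in zip(z["a"][k1], z["b"][k2], z["c"][k]))

def random_z(d, q, seed):
    """Random complex weights of order q (for testing only)."""
    rng = random.Random(seed)
    return {p: [[complex(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(q)]
                for _ in range(d)] for p in SLOTS}

d = 5
z2, z3 = random_z(d, 2, seed=0), random_z(d, 3, seed=1)  # orders 2 and 3
assert order(ring_add(z2, z3)) == 5 and order(ring_mul(z2, z3)) == 6
for k1, k2, k in [(1, 1, 1), (1, 2, 3), (4, 0, 2)]:
    lhs_add = sp(ring_add(z2, z3), k1, k2, k)
    lhs_mul = sp(ring_mul(z2, z3), k1, k2, k)
    assert abs(lhs_add - (sp(z2, k1, k2, k) + sp(z3, k1, k2, k))) < 1e-9
    assert abs(lhs_mul - sp(z2, k1, k2, k) * sp(z3, k1, k2, k)) < 1e-9
```

Because every sum potential is a sum over hidden nodes of a product of three weights, concatenation splits the sum and the Kronecker product factorizes it, so both checks hold for arbitrary inputs rather than only for these random draws.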

Paper Structure

This paper contains 22 sections, 42 theorems, 90 equations, 12 figures, and 3 tables.

Key Result

Theorem 1

The objective of a 2-layer MLP with quadratic activation can be written as $\ell = d^{-1}\sum_{k\neq 0} \ell_k + (d-1)/d$, where each per-frequency term $\ell_k$ is a function of the sum potentials (SPs) $r_{k_1k_2k} := \sum_j z_{a k_1 j} z_{b k_2 j} z_{ckj}$ and $r_{pk_1k_2k} := \sum_j z_{pk_1j} z_{pk_2j} z_{ckj}$ with $p\in\{a,b\}$.
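To make the two families of sum potentials concrete, the sketch below evaluates them by brute force for small $d$. It assumes a hypothetical storage `z[p][k][j]` for the complex weights, with frequencies indexed by `range(d)` so that the frequency $-k$ sits at index `(-k) % d`; all helper names are illustrative only. It also computes $\sum_{k\neq 0} r_{kkk}$, the quantity that Figure 3 reports converging towards $d-1$ during training.

```python
import random

# Hypothetical layout (assumption): z[p][k][j] is the complex weight for slot
# p in {"a", "b", "c"}, frequency k in range(d), and hidden node j in range(q).

def triple(u, v, w):
    """sum_j u[j] * v[j] * w[j] over hidden nodes."""
    return sum(x * y * t for x, y, t in zip(u, v, w))

def sum_potentials(z, d):
    """Brute-force evaluation of both SP families.

    r[(k1, k2, k)]      = sum_j z_{a k1 j} z_{b k2 j} z_{c k j}
    r_p[(p, k1, k2, k)] = sum_j z_{p k1 j} z_{p k2 j} z_{c k j},  p in {a, b}
    """
    r, r_p = {}, {}
    for k1 in range(d):
        for k2 in range(d):
            for k in range(d):
                r[(k1, k2, k)] = triple(z["a"][k1], z["b"][k2], z["c"][k])
                for p in ("a", "b"):
                    r_p[(p, k1, k2, k)] = triple(z[p][k1], z[p][k2], z["c"][k])
    return r, r_p

def diagonal_sum(r, d):
    """sum_{k != 0} r_{kkk} (Figure 3 reports this approaching d - 1)."""
    return sum(r[(k, k, k)] for k in range(1, d))

def random_z(d, q, seed=0):
    """Random complex weights of order q (for illustration only)."""
    rng = random.Random(seed)
    return {p: [[complex(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(q)]
                for _ in range(d)] for p in ("a", "b", "c")}

d, q = 7, 4
r, r_p = sum_potentials(random_z(d, q), d)
# r_{p k1 k2 k} is symmetric in (k1, k2) by definition.
assert all(abs(r_p[(p, k1, k2, k)] - r_p[(p, k2, k1, k)]) < 1e-12
           for (p, k1, k2, k) in r_p)
# Frequencies live in Z_d, so r_{p, k, -k, k} is looked up at index (-k) % d.
print(r_p[("a", 2, (-2) % d, 2)], diagonal_sum(r, d))
```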

Figures (12)

  • Figure 1: Overview of the proposed theoretical framework CoGS. (1) The family of 2-layer neural networks, $\mathcal{Z}$, forms a semi-ring algebraic structure (Theorem \ref{thm:semi-ring}) with ring addition and multiplication (Def. \ref{def:operationsinz}). $\mathcal{Z} = \bigcup_{q\ge 0} \mathcal{Z}_q$, where $\mathcal{Z}_q$ is the collection of all weights of order $q$ (i.e., $q$ hidden nodes). (2) For outcome prediction of Abelian group multiplication, the MSE loss $\ell({\bm{z}})$ is a function of sum potentials (SPs) $r_{k_1k_2k}({\bm{z}})$ and $r_{pk_1k_2k}({\bm{z}})$ (Theorem \ref{thm:analyticform}), which are ring homomorphisms (Theorem \ref{thm:pothomo}). (3) Thanks to the ring-homomorphism property, global solutions to the MSE loss $\ell({\bm{z}})$ with quadratic activation can be constructed algebraically from partial solutions that only satisfy a subset of constraints (Sec. \ref{sec:composing-solutions}) using ring addition and multiplication, instead of running gradient descent. Examples include the Fourier solutions ${\bm{z}}_{F6}$ (Corollary \ref{co:order-6}) and ${\bm{z}}_{F4/6}$ (Corollary \ref{co:order-4}), and the perfect memorization solution ${\bm{z}}_M$ (Corollary \ref{co:perfectmem}). In Sec. \ref{sec:gradientdynamics}, we analyze the role played by SPs in gradient dynamics, showing that the dynamics favor low-order global solutions under weight decay regularization (Theorem \ref{thm:loworderfirst}), and that the dynamics of SPs become decoupled in the infinite-width limit (Theorem \ref{lemma:infinitem}). In Sec. \ref{sec:exp} we show that most gradient descent solutions (around 95%) match our theoretical constructions exactly.
  • Figure 2: Solutions obtained by the Adam optimizer on the $L_2$ loss for the modular addition task with $|G| = d = 7$ and $q = 20$ hidden nodes. Top: For each frequency $\pm k$, exactly $6$ hidden nodes exist (Corollary \ref{co:order-6}). Bottom: Optimizing Eqn. \ref{eq:obj} without the last term $\sum_{m\neq 0} \sum_{p\in \{a,b\}} |\sum_{k'} r_{p,k',m-k',k}|^2$ (i.e., without constraint $R_*$). Now each frequency has exactly $3$ hidden nodes, corresponding to the solution ${\bm{z}}_{\mathrm{syn}} = \boldsymbol{\rho}({\bm{u}}_{\mathrm{syn}})$ in Tbl. \ref{tab:poly-construction}.
  • Figure 3: Dynamics of sum potentials (SPs) over the training process for modular addition with $d = 23$ and $q = 1024$ hidden nodes. Top row. Left: Training/test accuracy reaches 100% and the loss approaches $0$. Test accuracy jumps after training accuracy reaches 100% (grokking). Mid: After 10k epochs, the distribution of solution orders is concentrated at 4 and 6 (Corollaries \ref{co:order-6} and \ref{co:order-4}). Right: Dynamics of $r_{k_1k_2k}$. The sum of the diagonal terms $r_{kkk}$ converges towards $d-1$ (dotted line) with ripple effects, while the off-diagonal $r_{k_1k_2k}$ converge towards $0$. Bottom row. Dynamics of different SPs. Order-4 and order-6 solutions behave differently on $r_{p,k,-k,k}$, because order-4 does not satisfy the sufficient condition (Lemma \ref{co:globalminimizer}), yet a mixture of order-4 and order-6 (i.e., ${\bm{z}}_{F4/6}$) is still a global solution to the $L_2$ loss (Corollary \ref{co:foursixsol}).
  • Figure 4: Solution distribution (accumulated over 5 random seeds) under different weight decay strengths for $q = 512$, trained for 10k epochs using Adam with learning rate $0.01$ on modular addition (i.e., predicting $a+b \bmod d$) with $d\in \{23,71,127\}$. Red dashed lines correspond to order-4/6 solutions.
  • Figure 5: Visualization of $\hat{{\bm{z}}}^{(k_0)}_{F6}$.
  • ...and 7 more figures
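The experimental setting behind Figures 2 to 4 can be written out as a short sketch: a 2-layer network with quadratic activation on modular addition, with its $L_2$ loss evaluated over all $d^2$ input pairs. The parameterization (hypothetical input embeddings `W_a`, `W_b`, readout `V`, one-hot targets, averaging over pairs) is an assumption for illustration; the paper's exact objective may differ in normalization or target centering, and no training loop is included.

```python
import random

def init_params(d, q, seed=0, scale=0.1):
    """Random W_a, W_b (q x d) and readout V (d x q); hypothetical initialization."""
    rng = random.Random(seed)
    W_a = [[rng.gauss(0, scale) for _ in range(d)] for _ in range(q)]
    W_b = [[rng.gauss(0, scale) for _ in range(d)] for _ in range(q)]
    V = [[rng.gauss(0, scale) for _ in range(q)] for _ in range(d)]
    return W_a, W_b, V

def forward(params, a, b):
    """One-hot inputs a, b select columns of W_a, W_b; activation is x ** 2."""
    W_a, W_b, V = params
    h = [(W_a[j][a] + W_b[j][b]) ** 2 for j in range(len(W_a))]
    return [sum(v_ij * h_j for v_ij, h_j in zip(row, h)) for row in V]

def l2_loss(params, d):
    """Mean over all (a, b) of ||f(a, b) - onehot((a + b) mod d)||^2."""
    total = 0.0
    for a in range(d):
        for b in range(d):
            out = forward(params, a, b)
            target = (a + b) % d
            total += sum((o - (1.0 if i == target else 0.0)) ** 2
                         for i, o in enumerate(out))
    return total / (d * d)

d, q = 7, 20  # the setting of Figure 2
params = init_params(d, q)
print(l2_loss(params, d))
# With a zero readout every prediction is 0, so each pair contributes exactly 1.
W_a, W_b, _ = params
assert l2_loss((W_a, W_b, [[0.0] * q for _ in range(d)]), d) == 1.0
```

Squaring the summed embedding makes each output a product of three weight factors (two from the embeddings, one from the readout), which is consistent with the three-factor form of the sum potentials in Theorem 1.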

Theorems & Definitions (73)

  • Theorem 1: Analytic form of $L_2$ loss with quadratic activation
  • Definition 1: 0/1-set
  • Lemma 1: A Sufficient Condition for Global Solutions of Eqn. \ref{eq:obj}
  • Definition 2: Order of ${\bm{z}}$
  • Definition 3: Scalar multiplication
  • Definition 4: Identification of $\mathcal{Z}$
  • Definition 5: Addition and Multiplication in $\mathcal{Z}$
  • Theorem 2: Algebraic Structure of $\mathcal{Z}$
  • Definition 6: Sum potential (SP)
  • Theorem 3
  • ...and 63 more