Learning Associative Memories with Gradient Descent

Vivien Cabannes, Berfin Simsek, Alberto Bietti

TL;DR

This work reframes the training of a linear associative memory layer as an interacting-particle system, linking gradient dynamics to embeddings, data distributions, and memory interference. It shows that in overparameterized regimes with orthogonal embeddings, margins grow logarithmically, while imbalanced token frequencies and correlated embeddings induce oscillations and potential loss spikes, especially at large learning rates; in underparameterized regimes, the cross-entropy loss can misallocate capacity, leading to suboptimal memorization. The authors provide exact analyses for orthogonal settings, illustrate two-particle interference phenomena, and validate key insights on small Transformer models, highlighting how data geometry governs convergence speed and stability. These findings offer a fine-grained view of training dynamics that can help in understanding, diagnosing, and potentially improving learning in larger neural networks. Overall, the study connects associative-memory training to concrete dynamical phenomena (oscillations, loss spikes, and memorization limits) that can inform optimization choices and architectural design.

Abstract

This work focuses on the training dynamics of one associative memory module storing outer products of token embeddings. We reduce this problem to the study of a system of particles, which interact according to properties of the data distribution and correlations between embeddings. Through theory and experiments, we provide several insights. In overparameterized regimes, we obtain logarithmic growth of the "classification margins." Yet, we show that imbalance in token frequencies and memory interferences due to correlated embeddings lead to oscillatory transitory regimes. The oscillations are more pronounced with large step sizes, which can create benign loss spikes, although these learning rates speed up the dynamics and accelerate the asymptotic convergence. In underparameterized regimes, we illustrate how the cross-entropy loss can lead to suboptimal memorization schemes. Finally, we assess the validity of our findings on small Transformer models.
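
The logarithmic margin growth can be illustrated in the simplest overparameterized case. The sketch below is a toy illustration, not the authors' code, and assumes one-hot (hence orthogonal) input and output embeddings, two output classes, a single input seen with probability $p$, zero initialization, and full-batch gradient descent with step size $\eta$ on the cross-entropy. Under these assumptions the margin $m_t$ (correct logit minus wrong logit) obeys $m_{t+1} = m_t + 2\eta p/(1+e^{m_t})$; its continuous-time limit satisfies $m + e^{m} = 2\eta p\,t + 1$, so $m_t \approx \log(2\eta p\,t)$. The helper name `binary_margin_trajectory` is hypothetical.

```python
import math

def binary_margin_trajectory(eta, p, steps):
    """Toy GD on one input with one-hot embeddings and two classes (hypothetical helper).

    With one-hot embeddings, the logits u_y^T W e_x of this input are the
    entries W[y, x], which the gradients of the other inputs never touch,
    so this input evolves independently of the rest of the memory.
    """
    correct, wrong = 0.0, 0.0  # logits u_y^T W e_x, zero initialization
    margins = []
    for _ in range(steps):
        m = correct - wrong
        q = 1.0 / (1.0 + math.exp(m))  # softmax probability of the wrong class
        # Cross-entropy gradient wrt the logits is p * (softmax - one_hot).
        correct += eta * p * q
        wrong -= eta * p * q
        margins.append(correct - wrong)
    return margins

eta, p, steps = 1.0, 0.5, 10_000
margins = binary_margin_trajectory(eta, p, steps)
print(margins[-1], math.log(2 * eta * p * steps))
```

In this toy run, multiplying the number of steps by 10 adds roughly $\log 10$ to the margin, which is the logarithmic regime described above.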

Paper Structure (27 sections, 5 theorems, 141 equations, 7 figures)

Key Result

Theorem 1

Define the particle $w_{ij}$, as well as the constant correlation parameters […]. The projected gradient can be rewritten as […]. Hence, all variations of the gradient dynamics (eq:GF, eq:GD, eq:SGF and eq:SGD) can be expressed as a (stochastic) system of interacting particles. For example, the gradient descent dynamics eq:GD read […]. Similarly, the dynamics for stochastic gradient descent consist in replaci…
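
The idea behind Theorem 1 can be checked numerically. The sketch below assumes, as a hypothetical stand-in for the paper's definition (whose equations are not reproduced here), that the particles are the logits $w_{ij} = u_j^\top W e_i$ and that the loss is $\mathcal{L}(W) = \sum_x p_x\,\mathrm{CE}(\mathrm{softmax}_y(u_y^\top W e_x), f(x))$ for a deterministic target map $f$. One gradient step on $W$ then moves each particle by $-\eta\sum_{x,y} p_x(\pi_y(x) - \mathbf{1}[y=f(x)])\,(e_x^\top e_i)(u_y^\top u_j)$, where $\pi(x)$ is the softmax of the logits of input $x$, so particles interact only through the embedding correlations. All helper names and the random embeddings are hypothetical.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(z):
    top = max(z)
    ex = [math.exp(v - top) for v in z]
    total = sum(ex)
    return [v / total for v in ex]

def particles(W, E, U):
    """Assumed particles: w[i][j] = u_j^T W e_i for every input i and output j."""
    d = len(W)
    WE = [[dot(W[r], e) for r in range(d)] for e in E]  # vectors W e_i
    return [[dot(u, we) for u in U] for we in WE]

def residuals(w, p, f):
    """c[x][y] = p_x * (softmax(w[x])_y - 1[y == f(x)])."""
    out = []
    for x, row in enumerate(w):
        pi = softmax(row)
        out.append([p[x] * (pi[y] - (1.0 if y == f[x] else 0.0)) for y in range(len(row))])
    return out

def gd_step_on_W(W, E, U, p, f, eta):
    """One GD step on W; the gradient is sum_{x,y} c[x][y] * u_y e_x^T."""
    c = residuals(particles(W, E, U), p, f)
    d = len(W)
    new = [row[:] for row in W]
    for x, e in enumerate(E):
        for y, u in enumerate(U):
            for r in range(d):
                for s in range(d):
                    new[r][s] -= eta * c[x][y] * u[r] * e[s]
    return new

def gd_step_on_particles(w, E, U, p, f, eta):
    """The same step written only with particles and embedding correlations."""
    c = residuals(w, p, f)
    new = [row[:] for row in w]
    for i, ei in enumerate(E):
        for j, uj in enumerate(U):
            for x, ex in enumerate(E):
                for y, uy in enumerate(U):
                    new[i][j] -= eta * c[x][y] * dot(ex, ei) * dot(uy, uj)
    return new

random.seed(0)
d, N, M = 3, 4, 2  # more inputs than dimensions: embeddings cannot all be orthogonal
E = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
U = [[random.gauss(0, 1) for _ in range(d)] for _ in range(M)]
p = [0.4, 0.3, 0.2, 0.1]
f = [0, 1, 0, 1]
W = [[0.0] * d for _ in range(d)]

W_t, w_t = W, particles(W, E, U)
for _ in range(5):
    W_t = gd_step_on_W(W_t, E, U, p, f, eta=0.5)
    w_t = gd_step_on_particles(w_t, E, U, p, f, eta=0.5)
gap = max(abs(a - b) for ra, rb in zip(particles(W_t, E, U), w_t) for a, b in zip(ra, rb))
```

Under these assumptions `gap` stays at floating-point precision, which is the sense in which the matrix dynamics and the particle dynamics coincide; with orthonormal embeddings the cross-correlations vanish and each particle evolves on its own.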

Figures (7)

  • Figure 1: Level lines of ${\mathcal{L}}(W)$ for $N=d=2$ as a function of $\gamma_i(W):= (u_2-u_1)^\top Wf_i$, where $(f_i)$ is a basis of ${\mathbb{R}}^2$. Token embeddings have correlation $\alpha$ (eq:correlation-binary). We also plot the value of ${\mathcal{L}}_{01}(W)$, dark blue meaning perfect accuracy and white meaning zero accuracy.
  • Figure 2: Loss spikes. Trajectories of $W_t$ in the setting of Figure 1 for two learning rates, $\eta=10$ (green) and $\eta=1$ (red), together with their loss traces as a function of the number of epochs, here $t\in[35]$.
  • Figure 3: Level lines of the (logarithm of the) number of steps needed to reach perfect accuracy in the setting of Theorem 4, as a function of the learning rate $\eta$, the interaction parameter $\alpha$, and the class imbalance $\log(p_1 / p_2)$. Red means more steps to reach perfect accuracy.
  • Figure 4: Forgetting. Plots similar to those of Figures 1 and 2, yet in the limited-capacity case $d < N$. In those situations, competition between memories can lead to suboptimal minimizers of ${\mathcal{L}}$, which we illustrate with SGD in the bottom plots. The suboptimality is reflected in the excess risk ${\mathcal{E}} = {\mathcal{L}}_{01}(\mathop{\mathrm{arg\,min}}\limits_W {\mathcal{L}}(W)) - \min_W{\mathcal{L}}_{01}(W)$.
  • Figure 5: Sharpness profile. Gradient descent trajectories in the setting of Figures 2 and 4 with learning rates $\eta=10$ (green) and $\eta=1$ (red). We plot the level lines of the sharpness, i.e. the operator norm of $\nabla^2 {\mathcal{L}}(W)$, as well as the trace of the trajectories in terms of sharpness. The left plots are in the overparameterized regime, the right ones in the underparameterized regime.
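
The two-dimensional setting of Figures 1 to 3 can be mimicked with a toy model. The sketch below is not the authors' code; it assumes two inputs with unit-norm embeddings written in the basis $(f_1, f_2)$ as $e_1 = f_1$ and $e_2 = \alpha f_1 + \sqrt{1-\alpha^2} f_2$ (a hypothetical choice with correlation $\alpha$), input 1 labeled with output 1 and input 2 with output 2, frequencies $p_1$ and $p_2 = 1 - p_1$, and gradient descent run directly on $\gamma = (\gamma_1, \gamma_2)$ rather than on $W$. The loss is then $p_1\log(1+e^{\gamma^\top e_1}) + p_2\log(1+e^{-\gamma^\top e_2})$. Its Hessian has operator norm (sharpness) at most $1/4$, so any $\eta < 8$ keeps the loss non-increasing, whereas larger steps such as $\eta = 10$ fall outside this guarantee. The helper names are hypothetical.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def softplus(z):
    """log(1 + e^z) without overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def embeddings(alpha):
    # Hypothetical unit-norm embeddings with correlation alpha (|alpha| < 1).
    return (1.0, 0.0), (alpha, math.sqrt(1.0 - alpha ** 2))

def loss_and_grad(g, alpha, p1):
    e1, e2 = embeddings(alpha)
    m1 = g[0] * e1[0] + g[1] * e1[1]  # logit gap (u_2 - u_1)^T W e_1, want m1 < 0
    m2 = g[0] * e2[0] + g[1] * e2[1]  # logit gap (u_2 - u_1)^T W e_2, want m2 > 0
    p2 = 1.0 - p1
    loss = p1 * softplus(m1) + p2 * softplus(-m2)
    grad = [p1 * sigmoid(m1) * e1[k] - p2 * sigmoid(-m2) * e2[k] for k in range(2)]
    return loss, grad, (m1 < 0 < m2)

def sharpness(g, alpha, p1):
    """Largest eigenvalue of the Hessian p1 s'(m1) e1 e1^T + p2 s'(m2) e2 e2^T."""
    e1, e2 = embeddings(alpha)
    m1 = g[0] * e1[0] + g[1] * e1[1]
    m2 = g[0] * e2[0] + g[1] * e2[1]
    h1 = p1 * sigmoid(m1) * sigmoid(-m1)
    h2 = (1.0 - p1) * sigmoid(m2) * sigmoid(-m2)
    a = h1 * e1[0] ** 2 + h2 * e2[0] ** 2
    b = h1 * e1[0] * e1[1] + h2 * e2[0] * e2[1]
    c = h1 * e1[1] ** 2 + h2 * e2[1] ** 2
    return (a + c) / 2 + math.sqrt(((a - c) / 2) ** 2 + b ** 2)

def run_gd(eta, alpha, p1, steps):
    """Loss trace and first step with perfect accuracy (None if never reached)."""
    g = [0.0, 0.0]
    losses, first_perfect = [], None
    for t in range(steps + 1):
        loss, grad, perfect = loss_and_grad(g, alpha, p1)
        losses.append(loss)
        if perfect and first_perfect is None:
            first_perfect = t
        g = [g[k] - eta * grad[k] for k in range(2)]
    return losses, first_perfect

small, t_small = run_gd(eta=1.0, alpha=0.8, p1=0.9, steps=300)
large, t_large = run_gd(eta=10.0, alpha=0.8, p1=0.9, steps=300)
```

Sweeping `eta`, `alpha` and `p1` and recording `first_perfect` gives a rough, qualitative analogue of Figure 3, and comparing the `small` and `large` traces mirrors the two learning rates of Figure 2; the $\eta = 10$ trace is not guaranteed to be monotone.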

Theorems & Definitions (6)

  • Theorem 1: Particle system
  • proof
  • Theorem 2: Binary orthogonal
  • Theorem 3: Multi-class orthogonal
  • Theorem 4: Two particles interacting
  • Proposition 5: Loss spikes