Learning Associative Memories with Gradient Descent
Vivien Cabannes, Berfin Simsek, Alberto Bietti
TL;DR
This work reframes the training of a linear associative memory layer as an interacting-particle system, linking gradient dynamics to embeddings, data distributions, and memory interference. In overparameterized regimes with orthogonal embeddings, classification margins grow logarithmically, whereas imbalanced token frequencies and correlated embeddings induce transient oscillations that can produce loss spikes, especially at large learning rates. In underparameterized regimes, the cross-entropy loss can misallocate capacity, leading to suboptimal memorization. The authors provide exact analyses for orthogonal settings, illustrate two-particle interference phenomena, and validate key insights on small Transformer models, highlighting how data geometry governs convergence speed and stability. Overall, the study connects associative-memory training with concrete dynamical phenomena (oscillations, loss spikes, and memorization limits) that can help in understanding, diagnosing, and potentially improving optimization and architectural choices in larger neural networks.
Abstract
This work focuses on the training dynamics of one associative memory module storing outer products of token embeddings. We reduce this problem to the study of a system of particles, which interact according to properties of the data distribution and correlations between embeddings. Through theory and experiments, we provide several insights. In overparameterized regimes, we obtain logarithmic growth of the "classification margins." Yet, we show that imbalance in token frequencies and memory interference due to correlated embeddings lead to oscillatory transient regimes. The oscillations are more pronounced with large step sizes, which can create benign loss spikes, although these learning rates speed up the dynamics and accelerate the asymptotic convergence. In underparameterized regimes, we illustrate how the cross-entropy loss can lead to suboptimal memorization schemes. Finally, we assess the validity of our findings on small Transformer models.
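To make the overparameterized, orthogonal case concrete, the following is a minimal sketch rather than the authors' implementation. It assumes one-hot input and output embeddings (so the logit of output y given input x is simply the matrix entry W[y][x]), a hypothetical association x -> (x + 1) mod n, and full-batch gradient descent on the population cross-entropy. The names `softmax` and `train_margins` are illustrative and do not come from the paper.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def train_margins(n_tokens=4, steps=1000, lr=1.0, probs=None):
    """Run gradient descent on the cross-entropy of a linear associative memory.

    Assumption: embeddings are one-hot, hence orthogonal, so each input token
    owns one column of W and the columns evolve independently.
    Returns a list with the per-token margins after every step.
    """
    if probs is None:
        probs = [1.0 / n_tokens] * n_tokens  # uniform token frequencies
    # Hypothetical association to memorize: token x maps to (x + 1) mod n.
    target = [(x + 1) % n_tokens for x in range(n_tokens)]
    W = [[0.0] * n_tokens for _ in range(n_tokens)]
    history = []
    for _ in range(steps):
        for x in range(n_tokens):
            p = softmax([W[y][x] for y in range(n_tokens)])
            for y in range(n_tokens):
                # Gradient of p(x) * (-log softmax(W[:, x])[target[x]]).
                indicator = 1.0 if y == target[x] else 0.0
                W[y][x] -= lr * probs[x] * (p[y] - indicator)
        # Margin: correct logit minus the best incorrect logit.
        margins = [
            W[target[x]][x]
            - max(W[y][x] for y in range(n_tokens) if y != target[x])
            for x in range(n_tokens)
        ]
        history.append(margins)
    return history


if __name__ == "__main__":
    uniform = train_margins()
    for t in (10, 100, 1000):
        print(t, round(uniform[t - 1][0], 3))
    skewed = train_margins(probs=[0.7, 0.1, 0.1, 0.1])
    print("frequent vs rare margin:", skewed[-1][0], skewed[-1][1])
```

In this sketch the margin keeps increasing but ever more slowly, with each tenfold increase in the number of steps adding a roughly constant amount, which is the logarithmic growth described above. Passing non-uniform `probs` shows frequent tokens gaining margin faster than rare ones. Because the embeddings are orthogonal, no interference or oscillation appears here; reproducing those effects would require correlated embeddings, which this sketch deliberately omits.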
