Scaling Laws for Associative Memories

Vivien Cabannes, Elvis Dohmatob, Alberto Bietti

TL;DR

This work studies associative memory in a high-dimensional, transformer-like setting where memories are stored as outer products of input and output embeddings. It derives scaling laws for the generalization error as a function of model capacity $d$ and data size $T$ under Zipf-distributed data, compares memory-storage schemes, and analyzes optimization-driven memorization via SGD, Adam, and layer normalization. Key findings show explicit error bounds for random embeddings, the dramatic capacity gains achievable by learning embeddings, and practical guidance on step size, batch size, and normalization to optimize memory storage. The results offer theoretical and empirical insights into memorization in deep networks and suggest design principles for memory-augmented models and training strategies.

Abstract

Learning arguably involves the discovery and memorization of abstract rules. The aim of this paper is to study associative memory mechanisms. Our model is based on high-dimensional matrices consisting of outer products of embeddings, which relates to the inner layers of transformer language models. We derive precise scaling laws with respect to sample size and parameter size, and discuss the statistical efficiency of different estimators, including optimization-based algorithms. We provide extensive numerical experiments to validate and interpret theoretical results, including fine-grained visualizations of the stored memory associations.
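As a rough illustration of the outer-product memory described above, the sketch below stores associations $x \mapsto f_*(x)$ in $W = \sum_x q(x)\, u_{f_*(x)} e_x^\top$ and predicts with $f_W(x) = \arg\max_y u_y^\top W e_x$. It assumes random Gaussian embeddings normalized to unit norm, a uniform weighting $q(x) = 1$, and the toy target $f_*(x) = x \bmod M$. The helper names are hypothetical and this is not the authors' code.

```python
import numpy as np


def random_embeddings(n, d, rng):
    """Return n random unit-norm embeddings of dimension d (one per row)."""
    emb = rng.standard_normal((n, d))
    return emb / np.linalg.norm(emb, axis=1, keepdims=True)


def build_memory(E, U, targets, q):
    """Stack associations as W = sum_x q(x) u_{f*(x)} e_x^T (a d x d matrix)."""
    # U[targets] has one row u_{f*(x)} per input x; each row is scaled by q(x).
    return (U[targets] * q[:, None]).T @ E


def predict(W, E, U):
    """Compute f_W(x) = argmax_y u_y^T W e_x for every input x at once."""
    scores = U @ W @ E.T  # scores[y, x] = u_y^T W e_x
    return scores.argmax(axis=0)


def recall_accuracy(N, M, d, seed=0):
    """Fraction of the N stored associations that the memory recalls correctly."""
    rng = np.random.default_rng(seed)
    E = random_embeddings(N, d, rng)  # input embeddings e_x
    U = random_embeddings(M, d, rng)  # output embeddings u_y
    targets = np.arange(N) % M        # toy target f*(x) = x mod M (0-based indices)
    q = np.ones(N)                    # hypothetical choice: every association has weight 1
    W = build_memory(E, U, targets, q)
    return float(np.mean(predict(W, E, U) == targets))


if __name__ == "__main__":
    for d in (2, 10, 100, 2000):
        print(d, recall_accuracy(N=100, M=5, d=d))
```

Running the loop should show recall improving as $d$ grows, with interference between stored pairs causing errors when $d$ is small relative to the number of associations.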

Paper Structure (46 sections, 8 theorems, 128 equations, 16 figures, 1 table)

Key Result

Proposition 1

Consider an infinite memory model $\hat{f}$, which at time $T$ correctly predicts every $x$ seen during training, i.e., every $x\in\{X_t\}_{t\in[T]}$, where the pairs $(X_t, Y_t)$ are drawn independently at random from a distribution $p\in\Delta_{[N]\times[M]}$. Under the data model, the generalization error satisfies, in expectation over the random dataset, $\mathbb{E}[\mathcal{E}(\hat{f})] \asymp T^{-1+1/\alpha}$. Here, the notation $a \asymp b$ means that there exist two constants $c_1$ and $c_2$ such that $c_1 b \leq a \leq c_2 b$.
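To see where the $T^{-1+1/\alpha}$ rate comes from, assume deterministic labels and that every unseen input is mispredicted (a worst case). The error of $\hat{f}$ is then the probability that a fresh input never appeared in training, $\sum_x p(x)(1-p(x))^T$. Heuristically, replacing $(1-p(x))^T$ by $e^{-Tp(x)}$ and the sum by an integral shows that this mass scales as $T^{-1+1/\alpha}$ when $\alpha > 1$. The sketch below assumes a Zipf marginal $p(x) \propto x^{-\alpha}$ over a large vocabulary ($N = 10^6$ by default, much larger than $T^{1/\alpha}$). It evaluates this sum exactly and fits its log-log slope; the function names are hypothetical.

```python
import numpy as np


def zipf_marginal(N, alpha):
    """Return p(x) proportional to x^{-alpha} for x = 1, ..., N."""
    weights = np.arange(1, N + 1, dtype=float) ** (-alpha)
    return weights / weights.sum()


def expected_unseen_mass(p, T):
    """E[P(X was not seen in T samples)] = sum_x p(x) (1 - p(x))^T."""
    # log1p keeps (1 - p)^T accurate when p is tiny and T is large.
    return float(np.sum(p * np.exp(T * np.log1p(-p))))


def fitted_exponent(alpha, N=10**6, Ts=(1e2, 1e3, 1e4)):
    """Least-squares slope of log(error) versus log(T); theory predicts -1 + 1/alpha."""
    p = zipf_marginal(N, alpha)
    errors = [expected_unseen_mass(p, T) for T in Ts]
    slope, _ = np.polyfit(np.log(Ts), np.log(errors), deg=1)
    return float(slope)


if __name__ == "__main__":
    for alpha in (1.5, 2.0, 3.0):
        print(alpha, fitted_exponent(alpha), -1 + 1 / alpha)
```

The vocabulary size $N$ should be large enough that inputs with $p(x) \approx 1/T$ are not cut off. Otherwise the finite support of $p$ flattens the curve and the fitted slope drifts away from $-1+1/\alpha$.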

Figures (16)

  • Figure 1: Scaling laws with respect to the model capacity $d$ (left) and the number of data seen $T$ (right), for various dataset sizes $T$ and various model capacities $d$, respectively. These plots empirically validate the theory developed in the paper, which proves the scaling law ${\mathcal{E}}(f_q) \asymp d^{-\alpha+1} + T^{-1+1/\alpha}$ (dashed lines) under our setting (data model, model $f_W$, and loss) with $\alpha=2$, using the thresholded association scheme with $\rho=0$ and $P=d/8$. The experiments are averaged over $100$ runs; standard deviations are shown in solid color.
  • Figure 2: Error due to finite memory capacity: the stacking of associative memories in a matrix $W = \sum_{x} u_{f_*(x)}e_x^\top$ may exhibit a pattern where three inputs mapped to three different outputs interact in such a way that $u_2^\top W e_1 = e_2^\top e_1 + u_2^\top u_3 e_3^\top e_1 \geq 1 + u_1^\top u_3 e_3^\top e_1 = u_1^\top W e_1$ (with unit-norm embeddings and $u_1^\top u_2 = 0$), so that $f_W(x=1) = 2 \neq 1 = f_*(x=1)$. In other words, memory interference may lead to wrong predictions, illustrating the finite capacity of the model $f_W$ to store all data associations.
  • Figure 3: Generalization error as a function of $d$ and $T$ for the model $f_q$, averaged over $100$ runs. The data follows a Zipf law with $\alpha=0.5$, $N=100$, $M=5$ and $f_*(x) = x \bmod M$. Left: error for the scheme $q_0$: either $d$ is too small and memory overflow leads to a large error (red area), or it is large enough and, with enough data, the error vanishes (blue area). Middle: error for the scheme $q_1$: for small $d$ and large $T$, it avoids memory overflow, allowing a smaller error than $q_0$; however, for large $d$ it does not allocate enough memory to rare associations, leading to a larger error. These results can be interpreted mechanistically by looking at the corresponding memory matrices, visualized in a separate figure. Right: generalization error when $T=+\infty$, $N = 100$ and $\alpha=2$: the scheme $q_0$ leads to a zero-one type of plot where the error is high if $d < N$ and decreases quickly to zero if $d > N$ (in blue); the scheme $q_1$ leads to an error decreasing as $d^{-(\alpha-1) / 2\alpha} = d^{-1/4}$, as predicted by theory (in orange); the thresholded scheme $q_{0, P}$ with $P = d / 8$ decreases as $d^{-(\alpha-1)} = d^{-1}$ until reaching the tipping point when $d/8 > N$ (in green).
  • Figure 4: Comparison between the error found by optimizing $W$ with SGD on the cross-entropy loss and its approximation with $q(x)$ under the approximate SGD update rule. We consider $N=100$, $M=5$, $f_*(x) = x \bmod M$, $\alpha=2$, and batch size one. Left: one run with $d=N=100$ and $\gamma=10$. Middle: average over 100 runs with $d=N=100$ and $\gamma=1$. Right: average when $d=N/10=10$ and $\gamma=1$, a regime where our approximation is no longer valid. The same results can be obtained for larger batch sizes, as shown in a supplementary figure.
  • Figure 5: Theoretical approximation of the association scheme found by stochastic gradient descent with batch size one and fixed learning rates. Left: plot of $f^{n}(0)$ as a function of $n$, where $f$ is the effect of one gradient update on $q(x)$. Right: plot of the resulting $q_\gamma(x)$ when $n_x \propto p(x) \propto (x+3)^{-\alpha}$ with $\alpha=2$ and $n_N = 1$. Dashed lines show $q_\rho$ for $\rho=0.05$, $\rho=0.35$ and $\rho=1$; these curves closely match $q_\gamma$ for $\gamma=10$, $\gamma=10^{-1}$ and $\gamma=10^{-3}$, respectively.
  • ...and 11 more figures
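The comparison of association schemes at $T = +\infty$ in the Figure 3 caption can be explored with the following sketch, where $p$ is known exactly. It assumes $q_\rho(x) = p(x)^\rho$, so that $q_0$ stores every input with equal weight and $q_1$ weights inputs by frequency. It models the thresholded scheme $q_{0,P}$ as keeping only the $P = d/8$ most frequent inputs. These definitions, the unit-norm Gaussian embeddings, and the helper names are assumptions of this sketch rather than the paper's exact setup.

```python
import numpy as np


def zipf_marginal(N, alpha):
    """p(x) proportional to x^{-alpha}; index 0 is the most frequent input."""
    w = np.arange(1, N + 1, dtype=float) ** (-alpha)
    return w / w.sum()


def unit_rows(n, d, rng):
    """n random unit-norm vectors of dimension d (one per row)."""
    v = rng.standard_normal((n, d))
    return v / np.linalg.norm(v, axis=1, keepdims=True)


def scheme_weights(p, name, d):
    """Hypothetical association schemes q(x), indexed by frequency rank."""
    if name == "q0":   # rho = 0: equal weight for every input
        return np.ones_like(p)
    if name == "q1":   # rho = 1: weight proportional to p(x)
        return p.copy()
    if name == "q0P":  # rho = 0, keep only the P = d // 8 most frequent inputs
        q = np.zeros_like(p)
        q[: d // 8] = 1.0
        return q
    raise ValueError(f"unknown scheme: {name}")


def population_error(N, M, d, alpha, name, seed=0):
    """E(f_q) = sum_x p(x) 1[f_W(x) != f*(x)], with a toy target x mod M."""
    rng = np.random.default_rng(seed)
    p = zipf_marginal(N, alpha)
    E, U = unit_rows(N, d, rng), unit_rows(M, d, rng)
    targets = np.arange(N) % M               # 0-based toy target
    q = scheme_weights(p, name, d)
    W = (U[targets] * q[:, None]).T @ E      # sum_x q(x) u_{f*(x)} e_x^T
    preds = (U @ W @ E.T).argmax(axis=0)     # f_W(x) = argmax_y u_y^T W e_x
    return float(np.sum(p * (preds != targets)))


if __name__ == "__main__":
    for d in (10, 50, 100, 400, 2000):
        errs = [population_error(100, 5, d, 2.0, s) for s in ("q0", "q1", "q0P")]
        print(d, errs)
```

A single seed gives noisy numbers. Averaging `population_error` over several seeds is the natural next step before comparing trends in $d$ with the rates quoted in the caption.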

Theorems & Definitions (10)

  • Proposition 1: Finite data, infinite memory
  • Theorem 1: Infinite data, finite memory
  • Proposition 2: Without thresholding
  • Proposition 3: With thresholding
  • Theorem 2: Minimax performance
  • Proposition 4: Finite data and finite memory
  • Definition 1: Quasi-orthogonality
  • Lemma 1
  • Proposition 5: Improvement for learned inputs embeddings
  • proof