Memory Mosaics

Jianyu Zhang, Niklas Nolte, Ranajoy Sadhukhan, Beidi Chen, Léon Bottou

TL;DR

Memory Mosaics address interpretability and compositional learning in sequence modeling by organizing multiple associative memories that retrieve via Gaussian kernel smoothing. The approach reframes attention as kernel-based retrieval and introduces predictive disentanglement, a training-time decomposition that assigns sub-tasks to individual memories. The paper shows that Memory Mosaics match the i.i.d. performance of decoding transformers on language modeling and can outperform them on out-of-distribution tasks such as in-context learning, with demonstrations on a toy three-moons problem and medium-scale language modeling. This work suggests a principled, interpretable alternative to fully attention-based models and highlights memory-based architectures as a promising path for scalable, modular, and transparent sequence learning.

Abstract

Memory Mosaics are networks of associative memories working in concert to achieve a prediction task of interest. Like transformers, memory mosaics possess compositional capabilities and in-context learning capabilities. Unlike transformers, memory mosaics achieve these capabilities in a comparatively transparent way ("predictive disentanglement"). We illustrate these capabilities on a toy example and also show that memory mosaics perform as well as or better than transformers on medium-scale language modeling tasks.

Paper Structure (35 sections, 9 equations, 14 figures, 5 tables)

Figures (14)

  • Figure 1: Elementary memory unit. The keys $k_T$ are computed as a function of past observations $(x_t)_{t\leq T}$. The values $v_T$ peek into the future. In this example, the value also depends on the next observation $x_{T+1}$. At time $T$, the associative memory uses the known key $k_T$ to compute an estimate $y_T$ of $\mathbb{E}(v_T|k_T)$ using only the previously stored pairs $(k_t,v_t)$, $t<T$. One time step later, the input $x_{T+1}$ is revealed, the value $v_T$ can be computed, and the pair $(k_T,v_T)$ is added to the memory.
  • Figure 2: The curve plots the prediction loss at each index $t\in\{1\dots D\}$ of the training sequence. Minimizing the sum of these losses (the area under the curve) favors memories that produce useful value estimates after fewer time steps.
  • Figure 3: An architecture for the three moons problem. We consider single-layer networks with either $N_h=1$ or $N_h=3$ memory units whose keys and values belong to either $\mathbb{C}^3$ ($N_h=1$) or $\mathbb{C}^1$ ($N_h=3$). Both nets have $3\times3\times2\times3=54$ trainable real parameters that determine how to predict the moon positions using either a single 6-dimensional memory or three 2-dimensional memories.
  • Figure 4:
  • Figure 6: Left: Classic GPT2-small transformer. Right: GPT2-like Memory Mosaic.
  • ...and 9 more figures
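
The memory unit described in the Figure 1 caption and the summed loss described in the Figure 2 caption can be illustrated with a minimal sketch. This is not the authors' implementation. It assumes a toy setting in which the key is simply the current observation ($k_T = x_T$) and the value is the next observation ($v_T = x_{T+1}$), whereas the paper computes keys and values with trained functions. The names `KernelMemory`, `run_sequence`, and the bandwidth parameter `beta` are hypothetical. Retrieval uses Gaussian kernel smoothing as mentioned in the TL;DR: each stored value is weighted by $\exp(-\beta\|k - k_t\|^2)$, normalized over the stored pairs.

```python
import numpy as np


class KernelMemory:
    """Associative memory that retrieves by Gaussian kernel smoothing (illustrative)."""

    def __init__(self, beta=1.0):
        self.beta = beta      # kernel bandwidth (assumed hyperparameter)
        self.keys = []        # stored keys k_t
        self.values = []      # stored values v_t

    def store(self, key, value):
        self.keys.append(np.asarray(key, dtype=float))
        self.values.append(np.asarray(value, dtype=float))

    def retrieve(self, key):
        """Estimate E(v | k) from the stored pairs; None if the memory is empty."""
        if not self.keys:
            return None
        K = np.stack(self.keys)
        V = np.stack(self.values)
        sq_dist = np.sum((K - np.asarray(key, dtype=float)) ** 2, axis=1)
        logits = -self.beta * sq_dist
        weights = np.exp(logits - logits.max())   # subtract max for stability
        weights /= weights.sum()
        return weights @ V


def run_sequence(xs, beta=1.0):
    """Walk a sequence as in Figure 1 and return the per-step squared losses.

    Toy assumption: k_T = x_T and v_T = x_{T+1}.
    """
    memory = KernelMemory(beta)
    losses = []
    for T in range(len(xs) - 1):
        k_T = xs[T]
        y_T = memory.retrieve(k_T)    # uses only pairs (k_t, v_t) with t < T
        v_T = xs[T + 1]               # available once x_{T+1} is revealed
        if y_T is not None:
            losses.append(float(np.sum((y_T - v_T) ** 2)))
        memory.store(k_T, v_T)        # the pair is added after the prediction
    return losses


# A period-3 sequence: the losses should drop once one full period is stored.
xs = np.array([[0.0], [1.0], [2.0]] * 3)
losses = run_sequence(xs, beta=50.0)
area_under_curve = sum(losses)        # the quantity summed in Figure 2
```

In this toy run the memory predicts poorly until it has seen one full period and accurately afterwards, so the summed loss is dominated by the first few steps. This is the sense in which minimizing the area under the curve rewards memories that become useful after few time steps.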