Dense associative memory for Gaussian distributions

Chandan Tankala, Krishnakumar Balasubramanian

TL;DR

This work extends dense associative memories to operate on Gaussian distributions by endowing the space of Gaussians with the Bures–Wasserstein metric and defining a log-sum-exp energy over stored distributions. Retrieval is performed via Gibbs-weighted aggregation of optimal transport maps, with fixed points corresponding to Wasserstein barycenters, enabling distributional recall. The authors prove exponential storage capacity and retrieval guarantees under Wasserstein perturbations, and they validate BW-DAM on synthetic data and on real-world Gaussian embeddings for images, text, and sentences, where it is more robust and more accurate than Euclidean baselines. Overall, the framework bridges classical DAMs with modern distributional representations, opening avenues for uncertainty-aware memory and probabilistic reasoning in memory-augmented learning.

Abstract

Dense associative memories (DAMs) store and retrieve patterns via energy-function-based fixed points, but existing models are limited to vector representations. We extend DAMs to Gaussian densities equipped with the 2-Wasserstein distance. Our framework defines a log-sum-exp energy over stored distributions and retrieval dynamics that aggregate optimal transport maps in a Gibbs-weighted manner. Stationary points correspond to self-consistent Wasserstein barycenters, generalizing classical DAM fixed points. We prove exponential storage capacity and provide quantitative retrieval guarantees under Wasserstein perturbations. We validate the method on synthetic data and on real-world image (CelebA, CIFAR-10) and text (text8, NLI corpus) datasets. By generalizing from vectors to distributions, our work bridges classical DAMs with modern generative modeling and paves the way for distributional storage and retrieval in memory-augmented learning.
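The retrieval dynamics described above can be sketched in the one-dimensional case, where everything has a closed form. This is a minimal illustration under stated assumptions, not the authors' implementation: the Gibbs weights are assumed proportional to $\exp(-\beta W_2^2(\xi, X_i))$ (the exact scaling is not given in this excerpt), each Gaussian is represented as a hypothetical `(mean, std)` tuple, and the function names are invented for the example. For one-dimensional Gaussians, $W_2^2 = (\mu_1 - \mu_2)^2 + (\sigma_1 - \sigma_2)^2$ and the optimal transport map from $\mathcal{N}(\mu, \sigma^2)$ to $\mathcal{N}(\mu_i, \sigma_i^2)$ is $x \mapsto \mu_i + (\sigma_i / \sigma)(x - \mu)$.

```python
import math

def w2_squared(p, q):
    """Squared 2-Wasserstein distance between 1-D Gaussians given as (mean, std)."""
    return (p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2

def gibbs_weights(query, stored, beta):
    """Softmax weights proportional to exp(-beta * W2^2(query, X_i)) (assumed scaling)."""
    logits = [-beta * w2_squared(query, x) for x in stored]
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - top) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def retrieval_step(query, stored, beta):
    """Push the query through the Gibbs-weighted average of the transport maps.

    Averaging the affine maps x -> m_i + (s_i / s) * (x - m) with weights w_i
    sends N(m, s^2) to N(sum_i w_i m_i, (sum_i w_i s_i)^2).
    """
    w = gibbs_weights(query, stored, beta)
    mean = sum(wi * m for wi, (m, _) in zip(w, stored))
    std = sum(wi * s for wi, (_, s) in zip(w, stored))
    return (mean, std)

def retrieve(query, stored, beta, steps=50, tol=1e-12):
    """Iterate the update until the iterate stops moving in W2 (or steps run out)."""
    for _ in range(steps):
        nxt = retrieval_step(query, stored, beta)
        if w2_squared(nxt, query) < tol:
            return nxt
        query = nxt
    return query

# Hypothetical stored patterns and a perturbed query near the third one.
stored = [(-3.0, 0.5), (0.0, 1.0), (2.5, 0.3)]
query = (2.1, 0.45)
fixed_point = retrieve(query, stored, beta=10.0)
print(fixed_point)
```

In one dimension the update reduces to a softmax-weighted average in $(\mu, \sigma)$ coordinates; in higher dimensions the transport maps involve matrix square roots of the covariances, which this sketch does not cover.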

Paper Structure

This paper contains 17 sections, 8 theorems, 146 equations, 8 figures, 2 algorithms.

Key Result

Theorem 1

Let $0 < p < 1$ be a constant, and let $\lambda_{\min}, \lambda_{\mathsf{max}}, \kappa, \alpha$ be constants as defined in the separation condition assumption (Assumption:SeparationCondition). Consider a Wasserstein sphere $\mathcal{S}_R$ of radius $R = \sqrt{2d\lambda_{\mathsf{max}}(2 + \log \kappa)}$ centered at $\delta_0$,

Figures (8)

  • Figure 1: Comparison of BW-DAM and Euclidean-DAM dynamics for retrieving stored Gaussian measures. Each row corresponds to perturbing a different stored Gaussian $X_i$ ($i = 1, \ldots, 5$). Column 1: sampled Gaussians with the target highlighted. Columns 2-3: retrieval dynamics showing the query $\xi$ (green), fixed point $\xi^\ast$ (red), and target (dashed). Parameters: $d = 2$, $N = 5$, $\lambda_{\min} = 0.8$, $\lambda_{\mathsf{max}} = 1.0$, $\beta = 2$, $W_2(\xi, X_i) = 0.5r$ where $r = \sqrt{\lambda_{\min}}$.
  • Figure 2: Energy landscape $E(\xi)$ (Equation \ref{['Eq:EnergyFunctionalDefinition']}) for query one-dimensional Gaussians $\xi = \mathcal{N}(\mu, \sigma^2)$ evaluated on a $200 \times 200$ grid with $\mu \in [-4, 4]$ and $\sigma \in [0.01, 2]$. Red dots indicate $N = 5$ stored Gaussian measures sampled uniformly at random with means in $[-3, 3]$ and standard deviations in $[0.2, 1.0]$. As $\beta$ increases from $0.1$ to $1000$, the energy transitions from nearly flat with overlapping basins to sharp, well-separated minima.
  • Figure 3: Memory retrieval on CelebA images. (Top) Qualitative results for five randomly chosen images with $20\%$ of pixels masked in gray ($\beta = 0.6$). Rows from top to bottom: (1) original images, (2) masked images, (3) BW-DAM retrieval, (4) Eu-DAM retrieval, (5) Pixel-DAM retrieval. (Bottom) Retrieval accuracy as a function of inverse temperature $\beta$ for BW-DAM, Eu-DAM, and Pixel-DAM, evaluated on 100 randomly chosen images. Error bars indicate standard deviation over 5 trials of 100 randomly chosen images each.
  • Figure 4: BW-DAM retrieval on Word2Gauss embeddings trained on text8 corpus. Each word's Gaussian embedding is perturbed by $W_2$ distance $\sqrt{\lambda_{\min}}$ and recovered using Algorithm \ref{['Alg:GaussianRetrieval']}. (a) Word evolution showing the nearest word to the current iterate at each step for $\beta = 10$; Iter 0 corresponds to the perturbed query before any updates. (b) Retrieval accuracy as a function of inverse temperature $\beta$.
  • Figure 5: Convergence of BW-DAM retrieval dynamics for varying dimension $d$ and temperature $\beta$. We sample $N=1000$ Gaussians from a Wasserstein sphere of radius $R=\sqrt{2d}$ with eigenvalues in $[0.8, 1.2]$, perturb $75\%$ of the Gaussians to distance $W_2 = \sqrt{\lambda_{\min}}$, and run the dynamics in Algorithm \ref{['Alg:GaussianRetrieval']}. Shaded regions show $\pm 1$ standard deviation.
  • ...and 3 more figures
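The sharpening of the energy landscape with $\beta$ described in the Figure 2 caption can be checked numerically. The sketch below assumes the energy has the form $E(\xi) = -\frac{1}{\beta} \log \sum_i \exp(-\beta W_2^2(X_i, \xi))$; the paper's exact normalization is not reproduced in this excerpt, so treat the formula, the `(mean, std)` representation, and the sample values as assumptions. Under this form, the standard log-sum-exp bound gives $\min_i W_2^2(X_i, \xi) - \frac{\log N}{\beta} \le E(\xi) \le \min_i W_2^2(X_i, \xi)$, so the energy approaches the squared distance to the nearest stored Gaussian as $\beta$ grows.

```python
import math

def w2_squared(p, q):
    """Squared 2-Wasserstein distance between 1-D Gaussians given as (mean, std)."""
    return (p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2

def energy(query, stored, beta):
    """Assumed log-sum-exp energy, evaluated with the max-subtraction trick."""
    logits = [-beta * w2_squared(x, query) for x in stored]
    top = max(logits)
    log_sum = top + math.log(sum(math.exp(l - top) for l in logits))
    return -log_sum / beta

# Hypothetical stored Gaussians and query, in (mean, std) form.
stored = [(-2.0, 0.4), (0.5, 0.9), (2.5, 0.25)]
query = (0.0, 0.5)
nearest = min(w2_squared(x, query) for x in stored)

betas = (0.1, 1.0, 10.0, 1000.0)
energies = [energy(query, stored, b) for b in betas]
for b, e in zip(betas, energies):
    # The gap to the nearest-pattern distance is at most log(N) / beta.
    print(f"beta={b:g}  E={e:.6f}  gap={nearest - e:.6f}")
```

At small $\beta$ all stored patterns contribute and the landscape is smooth; at large $\beta$ only the nearest pattern matters, which matches the well-separated minima in the caption.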

Theorems & Definitions (21)

  • Definition 1: Storage of a Gaussian measure
  • Theorem 1
  • Remark 1
  • Remark 2
  • Theorem 2: Retrieval Guarantee
  • Theorem 3: Retrieval Error
  • Lemma 1
  • Proof of Lemma \ref{['Lemma:DistanceBetweenGaussians']}
  • Lemma 2
  • Proof of Lemma \ref{['Lemma:PhiOfXi']}
  • ...and 11 more