
Saddle Hierarchy in Dense Associative Memory

Robin Thériault, Daniele Tantari

TL;DR

This work develops a Dense Associative Memory (DAM) framework, built as a three-layer Boltzmann machine with Potts hidden units, to study stationary points both when the model is trained on real data and when it is trained on data generated in a teacher-student setting. The authors derive saddle-point equations that characterize both cases, then introduce an effective loss whose regularization parameter ς stabilizes training via an effective inverse temperature β_eff = ς(2υ)β. They show a saddle-point hierarchy in which memories learned by narrower DAMs appear as saddle points in wider DAMs. This enables a network-growth strategy based on splitting steepest descent that substantially reduces training cost, with runtime often scaling close to log P_max. Empirically, the DAM learns interpretable prototypes (w^μ) with meaningful soft labels (p^γ_y) and classifies accurately: a 1-NN classifier on the memories reaches approximately 98% accuracy, and splitting-based training matches this accuracy with far fewer computations on MNIST-like data. The results connect energy-based DAMs to modern mechanisms such as attention and generative diffusion, offering a principled path toward scalable, interpretable, and robust pattern classification.
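The network-growth idea can be illustrated with a minimal NumPy sketch. It follows the duplication layout described for Figure 2: to grow from P to P + R memories, the first R memories are copied to the end of the memory matrix, and each copied pair is then nudged apart so that gradient descent can leave the saddle. The helper names (`duplicate_memories`, `growth_schedule`), the random perturbation of size `eps`, and the doubling schedule are hypothetical choices for illustration. Splitting steepest descent, as generally formulated, picks the splitting direction from second-order information about the loss rather than at random, and this sketch ignores the class weights and the hidden-unit prior, which a full implementation would also have to handle.

```python
import numpy as np

def duplicate_memories(W, R, eps=1e-3, rng=None):
    """Grow a memory matrix from P to P + R rows (hypothetical helper).

    W   : memories w^mu as rows, shape (P, N)
    R   : number of memories to duplicate, 0 < R <= P
    eps : scale of the symmetric perturbation that separates each pair
    """
    rng = np.random.default_rng() if rng is None else rng
    P, N = W.shape
    if not 0 < R <= P:
        raise ValueError("R must satisfy 0 < R <= P")
    # Copy the first R memories to the end, as in the layout of Figure 2.
    W_big = np.concatenate([W, W[:R]], axis=0)
    # Push each original/copy pair in opposite directions. Assumption: a random
    # direction, not the direction splitting steepest descent would compute.
    d = eps * rng.standard_normal((R, N))
    W_big[:R] += d
    W_big[P:] -= d
    return W_big

def growth_schedule(P0, P_max):
    """Doubling schedule P0, 2*P0, ..., capped at P_max (assumed schedule)."""
    sizes = [P0]
    while sizes[-1] < P_max:
        sizes.append(min(2 * sizes[-1], P_max))
    return sizes

# Usage: grow from 25 to 1000 memories of dimension 784.
rng = np.random.default_rng(0)
W = rng.standard_normal((25, 784))
for P_next in growth_schedule(25, 1000)[1:]:
    W = duplicate_memories(W, P_next - W.shape[0], rng=rng)
    # In practice, W would be trained for a while here before the next split.
print(W.shape)  # (1000, 784)
```

With a doubling schedule, the number of growth stages is ⌈log2(P_max / P0)⌉, here six stages from 25 to 1000 memories.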

Abstract

Dense Associative Memory (DAM) models have been attracting renewed attention since they were shown to be robust to adversarial examples and closely related to cutting-edge machine learning paradigms, such as the attention mechanism and generative diffusion. We study a DAM built upon a three-layer Boltzmann machine with Potts hidden units, which represent data clusters and classes. Through a statistical mechanics analysis, we derive saddle-point equations that characterize both the stationary points of DAMs trained on real data and the fixed points of DAMs trained on synthetic data within a teacher-student framework. Based on these results, we propose a novel regularization scheme that makes training significantly more stable. Moreover, we show empirically that our DAM learns interpretable solutions to both supervised and unsupervised classification problems. Pushing our theoretical analysis further, we find that the weights learned by relatively small DAMs correspond to unstable saddle points in larger DAMs. We implement a network-growing algorithm that leverages this saddle-point hierarchy to drastically reduce the computational cost of training dense associative memories.
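The link to attention mentioned above can be made concrete with a generic DAM retrieval step, sketched here in NumPy under stated assumptions. A Potts hidden unit that selects one of P memories induces a soft assignment over memories; if that assignment is taken to be a softmax of β times the overlap between the input and each memory w^μ, one update reads like single-query attention with the memories serving as both keys and values. The function names, the dot-product similarity, and the omission of the class layer and of the hidden-unit prior p_h are assumptions for illustration, not the paper's exact conditional distributions.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(z - z.max())
    return e / e.sum()

def retrieve(x, W, beta):
    """One attention-style DAM update (generic sketch, not the paper's model).

    x    : query pattern, shape (N,)
    W    : memories w^mu as rows, shape (P, N)
    beta : inverse temperature
    Returns the soft assignment q over memories and the retrieved pattern.
    """
    q = softmax(beta * (W @ x))  # query = x, keys = memories
    return q, q @ W              # values = memories

# Usage: three orthogonal memories and a noisy version of the first one.
W = np.eye(3)
x = np.array([0.9, 0.1, 0.0])
q, x_new = retrieve(x, W, beta=20.0)
print(q.argmax())  # 0
```

Raising β concentrates q on the memory with the largest overlap; in the β → ∞ limit the update returns that memory, while β = 0 gives a uniform average of all memories.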

Paper Structure

This paper contains 23 sections, 90 equations, 11 figures, 2 tables, 2 algorithms.

Figures (11)

  • Figure 1: All of the $P = 25$ memories $\left\{ \mathbf{w}^\mu \right\}_{\mu = 1}^{25}$ learned by an instance of our model with $\beta = 16$ when it is trained on the MNIST dataset of handwritten digits lecun1998gradient using constrained stochastic gradient descent (SGD) of the negative log-likelihood loss (Eq. \ref{eq:loss}). The hidden units are indexed using pairs of letters from A to E.
  • Figure 2: Illustration of the relationship between $\bar{x}^{\text{fixed}, \mu}_i$ and $\bar{x}^{\text{dupli}, \mu}_i$ stated in Eq. (\ref{eq:fixed_point}). The left panel represents the fixed-point parameters $\bar{x}^{\text{fixed}, \mu}_i$ of Eq. (\ref{eq:saddle-point}) with $P = 10$, and the right panel represents the fixed-point parameters $\bar{x}^{\text{dupli}, \mu}_i$ of Eq. (\ref{eq:saddle-point}) with $P = 15$. In this example, $R = 5$, which means that the first $5$ entries of $\bar{x}^{\text{fixed}, \mu}_i$ appear twice in $\bar{x}^{\text{dupli}, \mu}_i$, while the remaining ones appear only once. The first $10$ entries of $\bar{x}^{\text{dupli}, \mu}_i$ are identical to $\bar{x}^{\text{fixed}, \mu}_i$, and the dashed red lines highlight that the first $5$ entries of $\bar{x}^{\text{fixed}, \mu}_i$ are repeated a second time at the end of $\bar{x}^{\text{dupli}, \mu}_i$.
  • Figure 3: $25$ of the $P = 1000$ memories $\mathbf{w}^\mu$ learned by two instances of our dense associative memory (DAM) model with different values of $\beta$. Both networks are trained on the MNIST dataset of handwritten digits lecun1998gradient using constrained stochastic gradient descent (SGD) of the negative log-likelihood loss (Eq. \ref{eq:loss}). The left-panel model has $\beta = 18$, and the right-panel one $\beta = 6$. DAMs with $6 < \beta < 18$ learn memories that interpolate between these two pictures. The hidden units are indexed using pairs of letters from A to E.
  • Figure 4: The top panel shows $25$ of the $P = 1000$ memories $\mathbf{w}^\mu$ learned by an instance of our dense associative memory (DAM) model trained on the MNIST dataset of handwritten digits lecun1998gradient using constrained stochastic gradient descent (SGD) of the effective loss (Eq. \ref{eq:loss}) with $\varsigma = 0.25$. The bottom panel shows the corresponding rescaled class weights $\mathbf{p}^\mu / p_{\mathbf{h}} \left( \mu \right)$, where $p_{\mathbf{h}} \left( \gamma \right) = \frac{1}{P + 1}$ for all $0 \leq \gamma \leq P$. The hidden units are indexed using pairs of letters from A to E, and the column-wise maxima of the class weights are the classes of the memories with the corresponding letter indices. Rescaled class weights learned with $p_{\mathbf{h}} \left( \gamma \right) \neq \frac{1}{P + 1}$ are qualitatively similar to the ones shown in this figure. Approximately 98% of test digits fed to the DAM are given the class of the memory that resembles them the most. For example, a digit that looks like memory #AA is given the class 8.
  • Figure 5: The top panel shows $25$ of the $P = 100$ memories $\mathbf{w}^\mu$ learned by an instance of our dense associative memory (DAM) model trained in an unsupervised way (Eq. \ref{eq:unsupervised_loss}) on $6 \times 6$ patches of the MNIST dataset of handwritten digits lecun1998gradient while assuming $C = 10$ latent classes and $\varsigma = 0.6$. The bottom panel shows the corresponding rescaled class weights $\mathbf{p}^\mu / p_{\mathbf{h}} \left( \mu \right)$, where $p_{\mathbf{h}} \left( \gamma \right) = \frac{1}{P + 1}$ for all $0 \leq \gamma \leq P$. The hidden units are indexed using pairs of letters from A to E, and the column-wise maxima of the class weights are the classes of the memories with the corresponding letter indices.
  • ...and 6 more figures
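The classification rule described for Figure 4 can be sketched in NumPy as follows. Each test input receives the class of the memory that resembles it most, and a memory's class is the column-wise maximum of its class weights. Measuring resemblance by Euclidean distance, storing the class weights as a C × P array, and the helper names `memory_classes` and `nn_classify` are all assumptions made for illustration.

```python
import numpy as np

def memory_classes(Pc):
    """Class of each memory: the column-wise argmax of the class weights.

    Pc : class weights, shape (C, P); column mu holds p^mu over the C classes.
    """
    return Pc.argmax(axis=0)

def nn_classify(X, W, Pc):
    """Give each input the class of its nearest memory (1-NN on memories).

    X  : inputs, shape (M, N)
    W  : memories w^mu as rows, shape (P, N)
    Assumption: "resembles the most" means smallest Euclidean distance.
    """
    # Squared distances ||x - w||^2 for every (input, memory) pair, shape (M, P).
    d2 = (X**2).sum(axis=1)[:, None] - 2 * X @ W.T + (W**2).sum(axis=1)[None, :]
    nearest = d2.argmin(axis=1)
    return memory_classes(Pc)[nearest]

# Usage: two memories in 2-D, three classes.
W = np.array([[1.0, 0.0], [0.0, 1.0]])
Pc = np.array([[0.1, 0.7],
               [0.8, 0.2],
               [0.1, 0.1]])   # memory 0 -> class 1, memory 1 -> class 0
X = np.array([[0.9, 0.2], [0.1, 1.1]])
labels = nn_classify(X, W, Pc)
print(labels)  # [1 0]
```

Because dividing column μ by the positive scalar p_h(μ) does not change its argmax, the rescaled class weights shown in Figures 4 and 5 assign each memory the same class as the raw weights.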