
B'MOJO: Hybrid State Space Realizations of Foundation Models with Eidetic and Fading Memory

Luca Zancato, Arjun Seshadri, Yonatan Dukler, Aditya Golatkar, Yantao Shen, Benjamin Bowman, Matthew Trager, Alessandro Achille, Stefano Soatto

TL;DR

B'MOJO addresses transductive inference under memory constraints by unifying fading memory (state-space based) with an externally growing eidetic memory via Innovation Selection, within a Stochastic Realization framework. The model family generalizes Transformers, Mamba, and Jamba, achieving efficient computation through chunked, interleaved processing and a fixed-predictor retrieval mechanism. Empirically, B'MOJO and its fading-eidetic variant outperform fading-only SSMs and hybrid baselines on associative recall tasks, attain perplexities comparable to similarly-sized Transformers up to 1.4B parameters with up to 10% faster training, and demonstrate robust length generalization up to 32K tokens when pre-trained on longer contexts. The work suggests a scalable, hardware-friendly approach to memory-aware foundation models, while acknowledging scaling and societal considerations and outlining directions for further improvement in long-context inference and controllability.

Abstract

We describe a family of architectures to support transductive inference by allowing memory to grow to a finite but a priori unknown bound while making efficient use of finite resources for inference. Current architectures use such resources to represent data either eidetically over a finite span ("context" in Transformers), or fading over an infinite span (in State Space Models, or SSMs). Recent hybrid architectures have combined eidetic and fading memory, but with limitations that do not allow the designer or the learning process to seamlessly modulate the two, nor to extend the eidetic memory span. We leverage ideas from Stochastic Realization Theory to develop a class of models called B'MOJO to seamlessly combine eidetic and fading memory within an elementary composable module. The overall architecture can be used to implement models that can access short-term eidetic memory "in-context," permanent structural memory "in-weights," fading memory "in-state," and long-term eidetic memory "in-storage" by natively incorporating retrieval from an asynchronously updated memory. We show that Transformers, existing SSMs such as Mamba, and hybrid architectures such as Jamba are special cases of B'MOJO and describe a basic implementation, to be open-sourced, that can be stacked and scaled efficiently in hardware. We test B'MOJO on transductive inference tasks, such as associative recall, where it outperforms existing SSMs and hybrid models; as a baseline, we test ordinary language modeling where B'MOJO achieves perplexity comparable to similarly-sized Transformers and SSMs up to 1.4B parameters, while being up to 10% faster to train. Finally, we show that B'MOJO's ability to modulate eidetic and fading memory results in better inference on longer sequences tested up to 32K tokens, four times the length of the longest sequences seen during training.
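
To make the contrast between eidetic and fading memory concrete, here is a minimal sketch. It assumes a scalar linear state-space recurrence x_t = a*x_{t-1} + b*u_t with 0 < a < 1, which is a simplification (Mamba-style SSMs use multi-dimensional, input-dependent parameters). The sliding window keeps the last `window` inputs exactly (eidetic, finite span), while the state keeps every input but scaled by a^k after k steps (fading, infinite span). The function names `fading_state` and `eidetic_window` are hypothetical and not taken from the paper.

```python
from collections import deque


def fading_state(inputs, a=0.5, b=1.0):
    """Scalar linear SSM x_t = a * x_{t-1} + b * u_t, starting from x_0 = 0.

    Every past input u_{t-k} survives in the state, but scaled by a**k,
    so its influence fades geometrically (fading memory, infinite span).
    """
    x = 0.0
    for u in inputs:
        x = a * x + b * u
    return x


def eidetic_window(inputs, window=3):
    """Keep the last `window` inputs verbatim (eidetic memory, finite span)."""
    buf = deque(maxlen=window)  # oldest items are evicted once the buffer is full
    for u in inputs:
        buf.append(u)
    return list(buf)


if __name__ == "__main__":
    seq = [8.0, 0.0, 0.0, 0.0, 0.0]
    # The first input is still represented in the state, attenuated by 0.5**4.
    print(fading_state(seq))    # 0.5
    # The window has forgotten it entirely, but recalls recent inputs exactly.
    print(eidetic_window(seq))  # [0.0, 0.0, 0.0]
```

The trade-off is visible in the output: the state still carries a trace of the first input but cannot reproduce it exactly, while the window reproduces recent inputs exactly but has no trace of anything older than three steps.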

Paper Structure (29 sections, 26 equations, 10 figures, 3 tables, 1 algorithm)

Figures (10)

  • Figure 1: B'MOJO's memory management. (Left) Illustration of the B'MOJO layer. (Right) B'MOJO's Realization. B'MOJO's fading memory is computed by an SSM that represents long-range dependencies through its state (a fixed-dimensional representation), which is later aggregated with the most recent past. B'MOJO's eidetic memory stores tokens selected from the past using an innovation test on the SSM's state and appends them to the current sliding window. The innovation test measures how difficult it is to predict the next token from the SSM's state. If a token is difficult to predict, we store it in the eidetic memory and pass it to the attention module together with the state (a compressed summary of the past) and the most recent tokens.
  • Figure 2: (Panels 1-3) B'MOJO has high memory efficiency on associative recall tasks. For various models, we plot accuracy on the Multi-Query Associative Recall (MQAR) task as a function of the model dimension (totaling the SSM state, eidetic memory, and KV cache where applicable). The Transformer paragon attains 100% accuracy because it operates on the full context. While all models benefit strongly from increased memory, B'MOJO and B'MOJO-F consistently achieve the best accuracies for a given memory budget. Panels 1-3 report MQAR tasks of increasing difficulty, on which the performance gap between B'MOJO and other models widens, showcasing the value of eidetic memory. (Panel 4) Marginal increases in eidetic memory size correspond to gains in recall. We probe the role eidetic memory plays in recall by growing the number of eidetic memory tokens in B'MOJO. Each added token increases recall accuracy, until the model dimension grows large enough that the gains from additional memory saturate.
  • Figure 3: B'MOJO language modeling scaling laws. We plot the perplexity reached by models at different scales against the number of parameters and the wall-clock training time. B'MOJO trains faster than Mamba and Mistral while achieving better perplexity than Mamba and perplexity comparable to Mistral. The plot also exhibits a non-saturating scaling law: increasing the amount of resources leads to increasingly better B'MOJO models.
  • Figure 4: B'MOJO's throughput. Time in ms to process 2k sequences for different batch sizes and model dimensions. B'MOJO is faster than other efficient implementations of Mamba gu2023mamba and Transformers gu2023mamba at all scales.
  • Figure 5: Length generalization. (Left) We pre-train B'MOJO 1.4B and Mamba 1.4B with a 2k context length, and a 1.4B Transformer baseline with a 1k context length. (Right) We pre-train B'MOJO 790M and Mamba 790M with an 8k context length. We compare models on length generalization by evaluating perplexity on longer sequences of up to 32k tokens from the PG-19 dataset. Transformers fail to length-generalize (a known failure mode); B'MOJO, on the other hand, preserves or improves its perplexity on longer sequences better than Mamba does. We also observe that increasing the pre-training context length (from 2k to 8k) leads to better length generalization. This showcases better use of information from the remote past, thanks to the combination of fading and eidetic memory.
  • ...and 5 more figures
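
The innovation test described in Figure 1 can be illustrated with a toy model. This is a hypothetical sketch, not the authors' implementation: tokens are scalars, the SSM is assumed to be the scalar recurrence x_t = a*x_{t-1} + b*u_t, the predictor is assumed to be a fixed linear readout c*x_{t-1}, and "difficult to predict" is assumed to mean an absolute prediction error above a fixed `threshold`. The cap `max_eidetic` stands in for a finite eidetic memory budget. All function and parameter names are illustrative.

```python
from collections import deque


def bmojo_style_memory(tokens, a=0.5, b=1.0, c=0.5, threshold=5.0,
                       window=2, max_eidetic=4):
    """Toy fading + eidetic memory over a stream of scalar tokens.

    Hypothetical sketch: returns (eidetic, state, recent), the three inputs
    that would be passed to the attention module.
    """
    x = 0.0                        # fading memory: scalar SSM state
    recent = deque(maxlen=window)  # short-term eidetic memory: sliding window
    eidetic = []                   # long-term eidetic memory: selected tokens
    for u in tokens:
        prediction = c * x                # assumed fixed linear predictor on the previous state
        innovation = abs(u - prediction)  # surprise: how badly the state predicted u
        if innovation > threshold and len(eidetic) < max_eidetic:
            eidetic.append(u)             # hard-to-predict token is stored verbatim
        x = a * x + b * u                 # fold u into the fading state
        recent.append(u)                  # oldest token drops out of the window
    return eidetic, x, list(recent)


def attention_inputs(tokens, **kwargs):
    """Concatenate eidetic tokens, the state, and the recent window."""
    eidetic, state, recent = bmojo_style_memory(tokens, **kwargs)
    return eidetic + [state] + recent


if __name__ == "__main__":
    stream = [1.0, 1.0, 10.0, 1.0, 1.0]
    print(bmojo_style_memory(stream))  # ([10.0], 4.1875, [1.0, 1.0])
    print(attention_inputs(stream))    # [10.0, 4.1875, 1.0, 1.0]
```

In this run the surprising token 10.0 has already left the two-token sliding window, yet it remains available verbatim through eidetic memory, while the state retains only a faded summary of it.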

Theorems & Definitions (4)

  • Example A.1: Biology
  • Example A.2: CNN Classifiers, VAEs and GANs
  • Example A.3: Diffusion Models
  • Example A.4: The Sage and the Savant