Learning to Remember, Learn, and Forget in Attention-Based Models

Djohan Bonnet, Jamie Lohoff, Jan Finkbeiner, Elidona Skhikerujah, Emre Neftci

TL;DR

This work treats In-Context Learning in transformers as a continual-learning problem with fixed-size memories and interference risks. It proposes Palimpsa, a Bayesian metaplasticity-based attention mechanism that adapts the plasticity of each memory state via a per-state importance $I_t$ and a forgetting gate tied to a memory window $N_t$, enabling the model both to forget and to preserve critical past information. The authors derive Palimpsa from a variational Bayesian objective, show that Mamba2 is a special case of Palimpsa, and provide a continuum that allows metaplastic fine-tuning of pre-trained models. Empirically, Palimpsa improves performance on the MQAR benchmark and on Commonsense Reasoning tasks, with larger gains as sequence length grows and with fine-tuning at scale, highlighting practical memory improvements for edge-friendly, fixed-memory transformers.
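
The summary does not state Palimpsa's update equations, so the sketch below is only a hypothetical toy of the general idea of Bayesian metaplasticity, not the authors' method. It assumes one scalar memory entry with a Gaussian posterior whose precision plays the role of the importance: each observation adds to the importance, the effective learning rate is the new evidence divided by the accumulated importance, and a hypothetical memory window N sets a forgetting gate $\gamma = 1 - 1/N$ that relaxes the importance back toward the prior. The names `ToyMetaplasticCell`, `window`, and `evidence` are illustrative inventions.

```python
from dataclasses import dataclass


@dataclass
class ToyMetaplasticCell:
    """Hypothetical scalar memory entry with a Gaussian posterior (mean, precision).

    Assumptions (not Palimpsa's published equations): the importance is the
    posterior precision, and a memory window N sets a forgetting gate
    gamma = 1 - 1/N that relaxes the importance back toward the prior.
    """
    mean: float = 0.0
    importance: float = 1.0        # current precision (starts at the prior)
    prior_importance: float = 1.0  # precision of the prior
    window: float = float("inf")   # hypothetical memory window N >= 1; inf = never forget

    def step(self, target: float, evidence: float = 1.0) -> float:
        """Absorb one observation and return the effective learning rate used."""
        # Forgetting: decay the importance toward the prior. This makes the
        # entry plastic again without erasing its current mean.
        gamma = 1.0 - 1.0 / self.window
        self.importance = gamma * self.importance + (1.0 - gamma) * self.prior_importance
        # Conjugate Gaussian update: precisions add, and the learning rate is
        # the new evidence divided by the accumulated importance.
        self.importance += evidence
        lr = evidence / self.importance
        self.mean += lr * (target - self.mean)
        return lr


if __name__ == "__main__":
    stable = ToyMetaplasticCell()             # no forgetting
    leaky = ToyMetaplasticCell(window=2.0)    # short hypothetical memory window
    for _ in range(10):
        stable.step(1.0)
        leaky.step(1.0)
    # A conflicting observation barely moves the consolidated entry,
    # while the leaky entry remains plastic.
    print("stable lr:", stable.step(-1.0), "mean:", round(stable.mean, 3))
    print("leaky  lr:", leaky.step(-1.0), "mean:", round(leaky.mean, 3))
```

Without forgetting, the learning rate shrinks as $1/(n+1)$ and the entry consolidates; with `window=2.0` the importance settles at a fixed point of 3, so the learning rate stays near 1/3. This is the stability-plasticity trade-off in miniature.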

Abstract

In-Context Learning (ICL) in transformers acts as an online associative memory and is believed to underpin their high performance on complex sequence processing tasks. However, in gated linear attention models, this memory has a fixed capacity and is prone to interference, especially for long sequences. We propose Palimpsa, a self-attention model that views ICL as a continual learning problem that must address a stability-plasticity dilemma. Palimpsa uses Bayesian metaplasticity, where the plasticity of each attention state is tied to an importance state grounded by a prior distribution that captures accumulated knowledge. We demonstrate that various gated linear attention models emerge as specific architecture choices and posterior approximations, and that Mamba2 is a special case of Palimpsa where forgetting dominates. This theoretical link enables the transformation of any non-metaplastic model into a metaplastic one, significantly expanding its memory capacity. Our experiments show that Palimpsa consistently outperforms baselines on the Multi-Query Associative Recall (MQAR) benchmark and on Commonsense Reasoning tasks.
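
The abstract's view of a gated linear attention layer as a fixed-capacity online associative memory can be made concrete with the standard scalar-gated recurrence $S_t = a_t S_{t-1} + v_t k_t^\top$, read out as $y_t = S_t q_t$ (comparable to the scalar decay used by Mamba2). The following is a minimal pure-Python sketch of that baseline mechanism, not of Palimpsa. It assumes a 2-D key space and toy values, and the helper names `zeros`, `write`, `read`, and `recall_error` are hypothetical.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def zeros(d_v: int, d_k: int) -> Matrix:
    """Empty memory state of fixed size d_v x d_k."""
    return [[0.0] * d_k for _ in range(d_v)]


def write(state: Matrix, key: Vector, value: Vector, gate: float = 1.0) -> Matrix:
    """Gated linear attention update: S_t = a_t * S_{t-1} + v_t k_t^T."""
    return [
        [gate * state[i][j] + value[i] * key[j] for j in range(len(key))]
        for i in range(len(value))
    ]


def read(state: Matrix, query: Vector) -> Vector:
    """Associative readout: y_t = S_t q_t."""
    return [sum(row[j] * query[j] for j in range(len(query))) for row in state]


def recall_error(state: Matrix, key: Vector, value: Vector) -> float:
    """Euclidean distance between the retrieved value and the stored value."""
    out = read(state, key)
    return math.sqrt(sum((o - v) ** 2 for o, v in zip(out, value)))


if __name__ == "__main__":
    e1, e2 = [1.0, 0.0], [0.0, 1.0]
    v1, v2, v3 = [1.0, 2.0], [3.0, 4.0], [5.0, 6.0]
    # Two orthogonal keys fit in a 2-D key space without interference.
    S = write(write(zeros(2, 2), e1, v1), e2, v2)
    print("orthogonal keys:", recall_error(S, e1, v1))
    # A third key cannot be orthogonal to both, so it corrupts earlier recalls.
    k3 = [1 / math.sqrt(2), 1 / math.sqrt(2)]
    print("after a 3rd pair:", recall_error(write(S, k3, v3), e1, v1))
    # A forgetting gate a < 1 instead decays older associations.
    G = write(write(zeros(2, 2), e1, v1, gate=0.5), e2, v2, gate=0.5)
    print("with gate a=0.5:", recall_error(G, e1, v1))
```

At most $d_k$ keys can be mutually orthogonal, so the state's capacity is set by its size rather than by the sequence length. A gate below 1 trades that interference for forgetting of older pairs, which is the tension the abstract frames as a stability-plasticity dilemma.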

Paper Structure

This paper contains 27 sections, 50 equations, 5 figures, 4 tables.

Figures (5)

  • Figure 1: Bayesian Metaplasticity Attention. Self-attention in autoregressive transformers is inherently a continual learning problem, and as such can suffer from catastrophic forgetting. Metaplasticity dynamically modifies the learning rate to preserve important prior information. (Bottom-left) Illustration of Bayesian metalearning: $q_{\theta_t}$ is the (variational) distribution over memory states $\bm{S}$ at time step $t$. (Right) We derive Palimpsa, a new attention block based on an online Bayesian posterior, which uses metaplasticity to prevent both catastrophic forgetting and catastrophic remembering.
  • Figure 2: Curriculum MQAR experiments. Accuracy averaged over 8 seeds for the best learning rate per model. Individual run accuracies are shown as black dots; error bars represent $\pm 1$ standard deviation. Task difficulty increases with sequence length $L$. "w/o Meta" indicates that metaplasticity was disabled for those models.
  • Figure 3: Palimpsa's Learning Dynamics: memory window $N_t$ averaged over the context length (blue); metaplasticity ratio, defined on the final state importance as $(I_{\max}-I_{\min})/I_{\min}$ (orange); and training loss (pink). A higher ratio indicates stronger differentiation between plastic and consolidated synapses. Shaded regions represent the standard deviation over 8 seeds.
  • Figure 4: Illustration of the Palimpsa-M and Palimpsa-D architectures. While Palimpsa-M adopts the Mamba-2 configuration by relying on the attention layer for channel mixing, Palimpsa-D incorporates an explicit gated MLP. Additionally, Palimpsa-D introduces a dedicated $b_t$ parameter to decorrelate input integration from the forgetting dynamics dictated by $d_t$.
  • Figure 5: Inference Speed Benchmark. Throughput (thousands of tokens/s) comparison between Palimpsa and Simple GLA on an NVIDIA GeForce RTX 3090. Palimpsa matches the baseline's scaling behavior, with a consistent $4\times$ throughput gap caused by the dual-state update overhead.
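
The two quantities plotted in Figure 3 are simple enough to state as code. The following is a minimal sketch, assuming the final importance state is available as a 2-D array of strictly positive values and the memory windows $N_t$ as one value per context position. The function names are hypothetical and not taken from the authors' code.

```python
from itertools import chain
from statistics import mean
from typing import Sequence


def metaplasticity_ratio(final_importance: Sequence[Sequence[float]]) -> float:
    """Figure 3's ratio (I_max - I_min) / I_min over all entries of the final importance state."""
    values = list(chain.from_iterable(final_importance))
    i_min, i_max = min(values), max(values)
    if i_min <= 0:
        # Assumption: importances are positive, otherwise the ratio is undefined.
        raise ValueError("the ratio assumes strictly positive importances")
    return (i_max - i_min) / i_min


def mean_memory_window(windows: Sequence[float]) -> float:
    """Memory window N_t averaged over the context length."""
    return mean(windows)


if __name__ == "__main__":
    print(metaplasticity_ratio([[1.0, 2.0], [4.0, 1.0]]))
    print(mean_memory_window([2.0, 4.0, 6.0]))
```

The ratio is 0 when every entry has the same importance, meaning no differentiation, and grows as the most consolidated entry pulls away from the most plastic one.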