Exponential Family Attention
Kevin Christian Wibisono, Yixin Wang
TL;DR
Exponential Family Attention (EFA) generalizes self-attention from language to high-dimensional, mixed-type data by coupling attention-derived context with exponential-family conditionals. The approach includes latent-factor models as special cases while enabling nonlinear, context-dependent interactions and context sets learned through attention. Theoretical contributions include a linear identifiability result and a generalization bound on excess loss, complemented by strong empirical performance on synthetic data, Instacart baskets, MovieLens ratings, and spatiotemporal temperatures. The results suggest that EFA applies broadly to modeling complex dependencies in non-text domains and can improve the reconstruction of held-out data and recommendations in real-world settings.
Abstract
The self-attention mechanism is the backbone of the transformer neural network underlying most large language models. It can capture complex word patterns and long-range dependencies in natural language. This paper introduces exponential family attention (EFA), a probabilistic generative model that extends self-attention to high-dimensional sequential, spatial, or spatiotemporal data of mixed types, including both discrete and continuous observations. The key idea of EFA is to model each observation conditional on all other observations, called the context, whose relevance is learned in a data-driven way through an attention-based latent factor model. In particular, unlike static latent embeddings, EFA uses self-attention to capture dynamic interactions within the context, so that the relevance of each context observation depends on the other observations. We establish an identifiability result and provide a generalization guarantee on the excess loss of EFA. Across real-world and synthetic data sets (including U.S. city temperatures, Instacart shopping baskets, and MovieLens ratings), we find that EFA consistently outperforms existing models in capturing complex latent structures and reconstructing held-out data.
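The key idea above, predicting each observation from an attention-weighted summary of its context through an exponential-family conditional, can be illustrated with a minimal NumPy sketch. This is not the authors' implementation. The parameterization is an assumption modeled on exponential family embeddings: per-item embedding vectors `rho`, context vectors `alpha`, query/key matrices `W_q` and `W_k`, and the natural-parameter rule η_i = ρ_i · Σ_{j≠i} a_ij α_j x_j. All function names are hypothetical.

```python
import math
import numpy as np

def softmax(z):
    # Numerically stable softmax along the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def efa_natural_params(x, rho, alpha, W_q, W_k):
    """Hypothetical attention-based context -> natural parameters.

    x:        (n,) observed values (counts or reals), n >= 2
    rho:      (n, d) embedding vector for each item/location
    alpha:    (n, d) context vector for each item/location
    W_q, W_k: (d, d) query/key projections (assumed, for illustration)
    """
    ctx = alpha * x[:, None]                  # each context vector scaled by its observed value
    q = rho @ W_q                             # queries depend on the target's identity, not its value
    k = ctx @ W_k                             # keys depend on the observed context values
    scores = q @ k.T / math.sqrt(q.shape[1])
    np.fill_diagonal(scores, -np.inf)         # x_i never appears in its own context
    attn = softmax(scores)                    # (n, n): rows sum to 1, zero diagonal
    eta = np.sum(rho * (attn @ ctx), axis=1)  # eta_i = rho_i . sum_j a_ij alpha_j x_j
    return eta, attn

def poisson_loglik(x, eta):
    # Poisson in natural form (log link): x*eta - exp(eta) - log(x!)
    return x * eta - np.exp(eta) - np.array([math.lgamma(v + 1.0) for v in x])

def gaussian_loglik(x, eta, sigma=1.0):
    # Gaussian with mean eta (identity link) and fixed standard deviation sigma
    return -0.5 * ((x - eta) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d = 5, 3
    x = rng.poisson(2.0, size=n).astype(float)  # e.g. item counts in one basket
    rho = 0.1 * rng.normal(size=(n, d))
    alpha = 0.1 * rng.normal(size=(n, d))
    W_q, W_k = rng.normal(size=(d, d)), rng.normal(size=(d, d))
    eta, attn = efa_natural_params(x, rho, alpha, W_q, W_k)
    # Sum of conditional log-likelihoods (a pseudo-likelihood objective)
    print(poisson_loglik(x, eta).sum())
```

In this sketch, queries are built from `rho` rather than from `x`, so the conditional for observation i never depends on x_i itself, while its attention weights still change with the values of the other observations. For mixed-type data, one would presumably apply `poisson_loglik` to count-valued items and `gaussian_loglik` to real-valued ones and sum the results.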
