Multi-layer Cross-Attention is Provably Optimal for Multi-modal In-context Learning
Nicholas Barnfield, Subhabrata Sen, Pragya Sur
TL;DR
This work studies in-context learning (ICL) for multi-modal data under a tractable latent-factor model in which covariate statistics vary across prompts. It proves that single-layer linear self-attention (LSA) cannot achieve Bayes-optimal ICL in this setting, which motivates a deep cross-attention (CA) architecture with skip connections and a self-attention (SA) readout. By analyzing a linearized CA model trained via gradient flow, the authors establish that, as the depth $T$ grows (with a large test context length $L_{\mathrm{te}}$), the model's prediction converges to the Bayes-optimal predictor, $\hat{y}_q \to \boldsymbol{w}^\top \boldsymbol{x}_q$, with the explicit limiting parameter $\alpha^* = \tfrac{2}{2+\underline{m}+\overline{m}}$ and, in a two-parameter variant, $\beta^* \to -\alpha^*$. These results underscore the value of depth and cross-modal attention for ICL under covariate shift, and they come with a minimax-type justification for the chosen calibration. Numerical experiments corroborate the theory, showing large improvements over LSA and illustrating that both the CA structure and depth are necessary. The framework provides theoretical grounding for multi-modal ICL and suggests practical architectural principles for real-world cross-modal learning tasks.
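The closed form for $\alpha^*$ has the shape of the classical minimax step size for a linear, Richardson-type iteration whose spectrum lies in an interval. The sketch below assumes, for illustration only, that the relevant eigenvalues lie in $[1+\underline{m}, 1+\overline{m}]$; this interpretation and the names `alpha_star`, `worst_case_contraction`, `m_lo`, `m_hi` are hypothetical and not taken from the paper. Under that assumption, $\alpha^*$ minimizes the worst-case contraction factor $\max_\lambda |1-\alpha\lambda|$, which a brute-force grid search can confirm numerically.

```python
import numpy as np


def alpha_star(m_lo: float, m_hi: float) -> float:
    """Limiting parameter alpha* = 2 / (2 + m_lo + m_hi), as quoted in the TL;DR."""
    return 2.0 / (2.0 + m_lo + m_hi)


def worst_case_contraction(alpha: float, m_lo: float, m_hi: float) -> float:
    """max |1 - alpha * lam| over the ASSUMED spectrum lam in [1 + m_lo, 1 + m_hi].

    |1 - alpha * lam| is convex in lam, so the maximum sits at an endpoint.
    """
    return max(abs(1 - alpha * (1 + m_lo)), abs(1 - alpha * (1 + m_hi)))


m_lo, m_hi = 0.5, 3.0  # hypothetical spectral bounds
a_star = alpha_star(m_lo, m_hi)

# Brute-force check: scan step sizes and keep the one with the smallest worst case.
grid = np.linspace(0.01, 1.0, 10_000)
rho = np.array([worst_case_contraction(a, m_lo, m_hi) for a in grid])
a_grid = grid[np.argmin(rho)]

print(f"alpha* = {a_star:.4f}, grid minimizer = {a_grid:.4f}")
print(f"worst-case contraction at alpha* = {worst_case_contraction(a_star, m_lo, m_hi):.4f}")
```

At $\alpha^*$ both endpoints of the assumed interval contract equally, giving the factor $(\overline{m}-\underline{m})/(2+\underline{m}+\overline{m})$; increasing $\alpha$ worsens the upper endpoint and decreasing it worsens the lower one.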
Abstract
Recent progress has rapidly advanced our understanding of the mechanisms underlying in-context learning in modern attention-based neural networks. However, existing results focus exclusively on unimodal data; in contrast, the theoretical underpinnings of in-context learning for multi-modal data remain poorly understood. We introduce a mathematically tractable framework for studying multi-modal learning and explore when transformer-like architectures can recover Bayes-optimal performance in-context. To model multi-modal problems, we assume the observed data arise from a latent factor model. Our first result is a negative one about expressivity: we prove that single-layer linear self-attention fails to recover the Bayes-optimal predictor uniformly over the task distribution. To address this limitation, we introduce a novel, linearized cross-attention mechanism, which we study in the regime where both the number of cross-attention layers and the context length are large. We show that this cross-attention mechanism is provably Bayes optimal when optimized using gradient flow. Our results underscore the benefits of depth for in-context learning and establish the provable utility of cross-attention for multi-modal distributions.
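A toy numerical sketch can make the contrast between one layer and many layers concrete. It swaps in a much simpler setting than the paper's: noiseless unimodal linear regression in which the covariate scale $s$ differs across prompts; a single gradient step with a fixed scalar preconditioner as a stand-in for one LSA layer (a common reduction in the unimodal ICL literature); and $T$ plain gradient-descent steps with step size $\alpha^*$ (using the formula quoted in the TL;DR) as a stand-in for $T$ stacked layers. It does not implement the latent factor model or the linearized cross-attention mechanism, and all names and constants are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d, L, T = 5, 2000, 50               # hypothetical dimension, context length, number of steps
m_lo, m_hi = 0.5, 3.0               # hypothetical: covariate scales 1 + m_lo and 1 + m_hi
alpha = 2.0 / (2.0 + m_lo + m_hi)   # alpha* as quoted in the TL;DR


def make_prompt(scale: float):
    """Toy prompt: x_i ~ N(0, scale * I_d), y_i = w^T x_i with no noise."""
    w = rng.standard_normal(d)
    X = np.sqrt(scale) * rng.standard_normal((L, d))
    return X, X @ w, w


def one_step(X: np.ndarray, y: np.ndarray, gamma: float) -> np.ndarray:
    """One GD step from w = 0 with a fixed scalar preconditioner (one-layer stand-in)."""
    return gamma * (X.T @ y) / len(y)


def iterate(X: np.ndarray, y: np.ndarray, step: float, n_steps: int) -> np.ndarray:
    """n_steps of GD on the least-squares loss (T-layer stand-in)."""
    A = X.T @ X / len(y)  # empirical covariance of this prompt
    b = X.T @ y / len(y)
    w_hat = np.zeros(X.shape[1])
    for _ in range(n_steps):
        w_hat = w_hat + step * (b - A @ w_hat)  # w <- w - step * gradient
    return w_hat


one_step_err, iter_err = [], []
for scale in (1 + m_lo, 1 + m_hi):  # covariate statistics differ across prompts
    X, y, w = make_prompt(scale)
    one_step_err.append(np.linalg.norm(one_step(X, y, alpha) - w) / np.linalg.norm(w))
    iter_err.append(np.linalg.norm(iterate(X, y, alpha, T) - w) / np.linalg.norm(w))

print("relative error, one fixed step:", np.round(one_step_err, 3))
print(f"relative error, {T} steps:", np.round(iter_err, 8))
```

In this toy, $\frac{1}{L}\sum_i y_i \boldsymbol{x}_i \approx s\,\boldsymbol{w}$, so no single fixed preconditioner rescales correctly for both values of $s$ and the one-step error stays near $|\alpha^* s - 1|$; iterating instead drives the estimate toward the least-squares solution for every prompt as $T$ grows.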
