Multi-layer Cross-Attention is Provably Optimal for Multi-modal In-context Learning

Nicholas Barnfield, Subhabrata Sen, Pragya Sur

TL;DR

This work studies in-context learning (ICL) for multi-modal data under a tractable latent-factor model where covariate statistics vary across prompts. It proves that single-layer linear self-attention (LSA) cannot achieve Bayes-optimal ICL in this setting, motivating a deep cross-attention (CA) architecture with skip connections and a self-attention (SA) readout. By analyzing a linearized CA model trained via gradient flow, the authors establish that, as depth $T$ grows (and with large test context length $L_{\mathrm{te}}$), the model converges to the Bayes-optimal predictor $\hat{y}_q \to {\boldsymbol{w}}^\top {\boldsymbol{x}}_q$, with an explicit limiting parameter $\alpha^* = \tfrac{2}{2+\underline{m}+\overline{m}}$ and, in a two-parameter variant, $\beta^*\to -\alpha^*$. These results underscore the value of depth and cross-modal attention for ICL under covariate shifts and offer a minimax-type justification for the chosen calibration. Numerical experiments corroborate the theory, showing drastic improvements over LSA and illustrating the necessity of the CA structure and depth. The framework provides theoretical grounding for multi-modal ICL and suggests practical architectural principles for real-world cross-modal learning tasks.
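The closed form for $\alpha^*$ can be checked numerically. The sketch below illustrates one assumption and is not the paper's derivation. It posits a hypothetical per-layer error factor $1-\alpha(1+m)$, with $m$ ranging over $[\underline{m},\overline{m}]$. Under that assumption, $\alpha^* = 2/(2+\underline{m}+\overline{m})$ is exactly the classical minimax step size that equalizes the two endpoint factors, and the worst-case error after $T$ layers shrinks geometrically in $T$. The function names and the example range $[0.5, 3]$ are hypothetical.

```python
def alpha_star(m_lo: float, m_hi: float) -> float:
    """Closed-form limiting parameter alpha* = 2 / (2 + m_lo + m_hi) from the TL;DR."""
    return 2.0 / (2.0 + m_lo + m_hi)


def worst_case_rate(alpha: float, m_lo: float, m_hi: float) -> float:
    """Hypothetical per-layer factor: max over m in [m_lo, m_hi] of |1 - alpha * (1 + m)|.

    ASSUMPTION: this error model is illustrative, not taken from the paper.
    The expression is convex in m, so the maximum sits at an endpoint.
    """
    return max(abs(1.0 - alpha * (1.0 + m_lo)), abs(1.0 - alpha * (1.0 + m_hi)))


def grid_minimizer(m_lo: float, m_hi: float, n: int = 100_000) -> float:
    """Brute-force the alpha in [0, 2 / (1 + m_lo)] that minimizes the worst-case factor."""
    upper = 2.0 / (1.0 + m_lo)
    grid = (upper * i / n for i in range(n + 1))
    return min(grid, key=lambda a: worst_case_rate(a, m_lo, m_hi))


if __name__ == "__main__":
    m_lo, m_hi = 0.5, 3.0  # hypothetical range of the m-statistic
    a = alpha_star(m_lo, m_hi)
    print(f"closed form alpha* = {a:.6f}, grid search = {grid_minimizer(m_lo, m_hi):.6f}")
    for depth in (1, 5, 10):
        # worst-case error after `depth` layers under the assumed error model
        print(depth, worst_case_rate(a, m_lo, m_hi) ** depth)
```

With these inputs, the closed form and the grid search should agree to within the grid spacing. The printed worst-case factor also falls with depth, which is one way to see why depth helps under this toy error model.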

Abstract

Recent progress has rapidly advanced our understanding of the mechanisms underlying in-context learning in modern attention-based neural networks. However, existing results focus exclusively on unimodal data; in contrast, the theoretical underpinnings of in-context learning for multi-modal data remain poorly understood. We introduce a mathematically tractable framework for studying multi-modal learning and explore when transformer-like architectures can recover Bayes-optimal performance in-context. To model multi-modal problems, we assume the observed data arises from a latent factor model. Our first result comprises a negative take on expressibility: we prove that single-layer, linear self-attention fails to recover the Bayes-optimal predictor uniformly over the task distribution. To address this limitation, we introduce a novel, linearized cross-attention mechanism, which we study in the regime where both the number of cross-attention layers and the context length are large. We show that this cross-attention mechanism is provably Bayes optimal when optimized using gradient flow. Our results underscore the benefits of depth for in-context learning and establish the provable utility of cross-attention for multi-modal distributions.

Paper Structure (37 sections, 18 theorems, 206 equations, 6 figures, 1 table)

Key Result

Theorem 4.1

In the setting of Section sec:data, and assuming $\|\boldsymbol{m}\|$ has an atomless distribution (for instance, any distribution on $\boldsymbol{m}$ that is absolutely continuous with respect to the Lebesgue measure on $\mathbb{R}^d$), no single-layer LSA predictor $\hat{y}_q=\big(\mathsf{LSA}(\boldsymbol{E}_{\ldots}\ldots$ …
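For intuition on Theorem 4.1, consider a one-dimensional toy that is much simpler than the paper's multi-modal setting. The toy makes two assumptions:

- Single-layer LSA takes the reduced form $\hat{y}_q = \gamma\, x_q \frac{1}{L}\sum_i x_i y_i$ with a single trained scalar $\gamma$. This form is common in the linear-regression ICL literature.
- Prompts are drawn from a hypothetical distribution with $x_i \sim \mathcal{N}(0, s^2)$ and $y_i = w x_i$, where the scale $s$ changes from prompt to prompt.

For large $L$, the prediction is close to $\gamma s^2\, w x_q$. It therefore matches the target $w x_q$ only when $\gamma = 1/s^2$, so no fixed $\gamma$ works for two different scales. All names below are illustrative.

```python
import random


def lsa_prediction(gamma, xs, ys, x_q):
    """Reduced single-layer LSA readout (ASSUMED form): gamma * x_q * mean(x_i * y_i)."""
    return gamma * x_q * sum(x * y for x, y in zip(xs, ys)) / len(xs)


def sample_prompt(rng, scale, w, length):
    """Hypothetical prompt: x_i ~ N(0, scale^2) and noiseless labels y_i = w * x_i."""
    xs = [rng.gauss(0.0, scale) for _ in range(length)]
    return xs, [w * x for x in xs]


def ratio_to_target(gamma, scale, w=1.5, length=20_000, x_q=1.0, seed=0):
    """Return y_hat / (w * x_q); a value of 1 means the LSA readout hits the target."""
    rng = random.Random(seed)
    xs, ys = sample_prompt(rng, scale, w, length)
    return lsa_prediction(gamma, xs, ys, x_q) / (w * x_q)


if __name__ == "__main__":
    for gamma in (1.0, 0.25):
        # gamma = 1 suits scale 1, gamma = 0.25 suits scale 2; neither suits both
        print(gamma, [round(ratio_to_target(gamma, s), 3) for s in (1.0, 2.0)])
```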

Figures (6)

  • Figure 1: Visual representation of the re-injection of covariates $\boldsymbol{X}$ throughout the $\mathsf{CA}$ ($\mathsf{LCA}$) embedding, in addition to the skip connection $\boldsymbol{F}_t$ that is standard in transformers (vaswani2023attentionneed). The raw data $\boldsymbol{X}$ propagates both through $\boldsymbol{S}_t$ and through the CA block $\boldsymbol{A}_t$.
  • Figure 2: In-context performance at various $L_{\rm te}$ of one- and two-parameter LCA-based models ($T=10$) and the LSA model of Eq. (eq:LSA_def). Models are optimized on $\ell_{N,L_{\mathrm{tr}}}$ ($L_{\rm tr} = 100, N=2000$) using gradient descent. Error bars represent the standard deviation over $1000$ test prompts.
  • Figure 3: In-context performance for $L_{\rm te} = 64$ at various $T$ of one- and two-parameter LCA models. Models are optimized on $\ell_{N,L_{\mathrm{tr}}}$ ($L_{\rm tr} = 100, N=2000$) using gradient descent. Error bars represent the standard deviation over $1000$ test prompts.
  • Figure 4: In-context performance at various $L_{\rm te}$ of one- and two-parameter LCA-based models, ablations without $\boldsymbol{S}_t$ ($T=10$), and the sample mean $\bar{y}_{L_{\rm te}}$. Models are optimized on $\ell_{N,L_{\mathrm{tr}}}$ ($L_{\rm tr} = 100, N=2000$) using gradient descent. Performance is averaged over $1000$ test prompts; error bars represent the standard deviation over $10$ separate training runs.
  • Figure 5: In-context performance at various $L_{\rm te}$ of one- and two-parameter LCA-based models, a deep two-parameter LSA model ($T=10$), the LSA model of Eq. (eq:LSA_def), and the sample mean $\bar{y}_{L_{\rm te}}$. Models are optimized on $\ell_{N,L_{\mathrm{tr}}}$ ($L_{\rm tr} = 100, N=2000$) using gradient descent. Performance is averaged over $1000$ test prompts; error bars represent the standard deviation over $10$ separate training runs.
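The figure captions use two error-bar conventions. Figures 2 and 3 show the spread over 1000 test prompts, while Figures 4 and 5 show the spread over 10 training runs.

The sketch below implements only the first convention, applied to the sample-mean baseline $\bar{y}_{L_{\rm te}}$. This baseline predicts the average context label and ignores the query. The data generator is a hypothetical Gaussian linear task with $d=2$, not the paper's latent-factor model. All function names are illustrative.

```python
import random
import statistics


def make_prompt(rng, length, d=2, noise=0.1):
    """Hypothetical toy prompt: w ~ N(0, I_d), x ~ N(0, I_d), y = <w, x> + noise * N(0, 1)."""
    w = [rng.gauss(0.0, 1.0) for _ in range(d)]

    def draw():
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        return x, sum(wi * xi for wi, xi in zip(w, x)) + noise * rng.gauss(0.0, 1.0)

    context = [draw() for _ in range(length)]
    return context, draw()  # (context pairs, held-out query pair)


def sample_mean_predictor(context, x_q):
    """Baseline: predict the mean context label; the query x_q is unused."""
    return statistics.fmean(y for _, y in context)


def evaluate(predictor, l_te, n_prompts=1000, seed=0):
    """Squared query error on each test prompt; returns (mean, std) over prompts."""
    rng = random.Random(seed)
    errors = []
    for _ in range(n_prompts):
        context, (x_q, y_q) = make_prompt(rng, l_te)
        errors.append((predictor(context, x_q) - y_q) ** 2)
    return statistics.fmean(errors), statistics.stdev(errors)


if __name__ == "__main__":
    for l_te in (16, 64, 256):
        mean, std = evaluate(sample_mean_predictor, l_te)
        print(f"L_te={l_te}: mse={mean:.3f} +/- {std:.3f}")
```

The training-run convention of Figures 4 and 5 works differently. There, one would record the prompt-averaged error of each independently trained model and then take the standard deviation of those averages across runs.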

Theorems & Definitions (38)

  • Theorem 4.1: Single-layer LSA fails at ICL
  • Theorem 6.2: One-parameter model optimality
  • Proof sketch
  • Theorem 6.3: Two-parameter model optimality
  • Proof sketch
  • Lemma A.1: Joint covariance and Bayes coefficient
  • Proof
  • Lemma A.2: Matrix LLN at fixed dimension
  • Proof
  • Proof of Theorem 3.1
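Lemma A.1 pairs the joint covariance with a Bayes coefficient. Its exact statement is not visible in this excerpt, but a standard fact in this area is Gaussian conditioning. For zero-mean jointly Gaussian $(\boldsymbol{x}, y)$, $\mathbb{E}[y \mid \boldsymbol{x}] = \boldsymbol{\beta}^\top \boldsymbol{x}$ with $\boldsymbol{\beta} = \boldsymbol{\Sigma}_{xx}^{-1}\boldsymbol{\Sigma}_{xy}$.

The sketch below computes $\boldsymbol{\beta}$ with a small pure-Python solver; the covariance values are hypothetical. When $y = \boldsymbol{w}^\top\boldsymbol{x}$ plus independent noise, $\boldsymbol{\Sigma}_{xy} = \boldsymbol{\Sigma}_{xx}\boldsymbol{w}$, so the coefficient recovers $\boldsymbol{w}$. This is consistent with the limit $\boldsymbol{w}^\top\boldsymbol{x}_q$ quoted in the TL;DR.

```python
def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting (small dense A)."""
    n = len(A)
    M = [list(map(float, row)) + [float(bi)] for row, bi in zip(A, b)]
    for col in range(n):
        # pick the row with the largest pivot for numerical stability
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        if abs(M[piv][col]) < 1e-12:
            raise ValueError("singular covariance matrix")
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    # back substitution on the upper-triangular system
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        s = sum(M[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (M[r][n] - s) / M[r][r]
    return x


def bayes_coefficient(sigma_xx, sigma_xy):
    """For zero-mean jointly Gaussian (x, y): E[y | x] = beta^T x, beta = Sigma_xx^{-1} Sigma_xy."""
    return solve(sigma_xx, sigma_xy)


if __name__ == "__main__":
    sigma_xx = [[2.0, 0.5], [0.5, 1.0]]  # hypothetical covariate covariance
    w = [1.0, -2.0]
    # y = w^T x + independent noise  =>  Sigma_xy = Sigma_xx @ w
    sigma_xy = [sum(s * wi for s, wi in zip(row, w)) for row in sigma_xx]
    print(bayes_coefficient(sigma_xx, sigma_xy))
```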