
Distributional Associations vs In-Context Reasoning: A Study of Feed-forward and Attention Layers

Lei Chen, Joan Bruna, Alberto Bietti

TL;DR

The paper studies how Transformer architectures divide distributional knowledge and in-context reasoning between feed-forward and attention layers. The authors train a two-layer Transformer on a synthetic noisy in-context recall task. They show that feed-forward layers capture simple distributional cues (e.g., bigrams) while attention performs in-context reasoning, and they attribute this separation to noise in the gradients. They analyze the training dynamics theoretically. They also show that targeted low-rank weight truncation (LASER) shifts pre-trained models toward better reasoning on IOI (indirect object identification) and factual recall tasks, and that it can improve few-shot GSM8K performance. These findings offer a principled lens for understanding and steering layer-specific roles in LLMs, which can inform fine-tuning and architecture design for reasoning-heavy tasks.
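LASER replaces a chosen weight matrix with a low-rank approximation. The snippet below is a minimal sketch of that operation. It assumes the reduction is a truncated SVD that keeps a fraction $\rho$ of the singular components, as in the $\rho$ sweep of Figure 3. The helper name `truncate_rank` is hypothetical. The snippet does not choose which matrices to truncate (e.g. $U_{in}$ of $F_2$ in the synthetic model, or MLP weights in Pythia), and it is not the authors' code.

```python
import numpy as np


def truncate_rank(W: np.ndarray, rho: float) -> np.ndarray:
    """Best rank-k approximation of W, with k = round(rho * min(W.shape)).

    Hypothetical helper: a LASER-style truncated SVD, not the authors' implementation.
    rho = 1 keeps every singular component; rho = 0 zeroes the matrix out.
    """
    if not 0.0 <= rho <= 1.0:
        raise ValueError("rho must lie in [0, 1]")
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    k = int(round(rho * S.shape[0]))  # number of singular values kept
    # Scale the kept left singular vectors by their singular values, then project back.
    return (U[:, :k] * S[:k]) @ Vt[:k, :]


# Usage: keep half of the rank of a random 8 x 6 matrix.
rng = np.random.default_rng(0)
W = rng.standard_normal((8, 6))
W_half = truncate_rank(W, 0.5)
print(np.linalg.matrix_rank(W_half))  # 3
```

By Eckart–Young, the truncated SVD is the best rank-k approximation in Frobenius norm. That is why it is a natural stand-in for "dropping" the least significant directions of a layer.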

Abstract

Large language models have been successful at tasks involving basic forms of in-context reasoning, such as generating coherent language, as well as storing vast amounts of knowledge. At the core of the Transformer architecture behind such models are feed-forward and attention layers, which are often associated with knowledge and reasoning, respectively. In this paper, we study this distinction empirically and theoretically in a controlled synthetic setting where certain next-token predictions involve both distributional and in-context information. We find that feed-forward layers tend to learn simple distributional associations such as bigrams, while attention layers focus on in-context reasoning. Our theoretical analysis identifies the noise in the gradients as a key factor behind this discrepancy. Finally, we illustrate how similar disparities emerge in pre-trained models through ablations on the Pythia model family on simple reasoning tasks.

Paper Structure (41 sections, 17 theorems, 178 equations, 20 figures, 5 tables)

Key Result

Theorem 1

Assume $N, T \gg 1$ and $\alpha = \Theta(1)$. For the model in Eq. (eq:simplifed_model_archi), consider one gradient step from zero initialization on $m$ i.i.d. samples of $z_{1:T}$, with separate learning rates $\eta_f$ for $\mathbf{W}_F$ and $\eta_v$ for $\mathbf{W}_V$ (note that the gradient on …) … where $\Delta(\xi) = \xi_{N+1} - \max_{j\in [N]} \xi_j$ is the margin of predicting the generic noise token …
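The margin $\Delta(\xi)$ in Theorem 1 compares the logit of the generic noise token against the best regular token. The snippet below is a minimal sketch that computes it. It assumes the logit vector is laid out as in the theorem: the $N$ regular tokens come first and $\tau$ sits at position $N+1$, which is the last entry with 0-based indexing. The name `noise_margin` is hypothetical. Applied to $\mathbf{W}_U F_2(\mathbf{W}_E(q))$, the same quantity is the FF-2 margin plotted in Figure 3.

```python
import numpy as np


def noise_margin(xi) -> float:
    """Delta(xi) = xi_{N+1} - max_{j in [N]} xi_j.

    xi holds N + 1 logits. The last entry (index N+1 in the paper's 1-based
    notation) is assumed to be the generic noise token tau, the first N the
    regular tokens. A positive value means tau beats every regular token.
    """
    xi = np.asarray(xi, dtype=float)
    if xi.ndim != 1 or xi.shape[0] < 2:
        raise ValueError("expected a 1-D vector of N + 1 >= 2 logits")
    return float(xi[-1] - xi[:-1].max())


print(noise_margin([1.0, 3.0, 2.0, 4.5]))  # 1.5
print(noise_margin([0.0, 5.0, 1.0]))       # -4.0
```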

Figures (20)

  • Figure 1: Distributional association vs. in-context reasoning. In this work, we decompose next-token prediction tasks into distributional and in-context ones. We find that MLPs learn distributional associations before attention develops in-context reasoning capabilities. Furthermore, truncating MLPs promotes in-context reasoning by weakening distributional associations. See Figure 5 for an example of this on the Pythia model (Biderman et al., 2023).
  • Figure 2: Noisy in-context recall. Purpose of design: understand the mechanisms of attention and feed-forward layers on tasks that involve in-context reasoning (predict $\bar{y}$) and distributional association (predict $\tau$). Task: predict token $\bar{y}$ vs. $\tau$ from a sentence $[\dots,q,\bar{y},\dots,q,\tau,\dots,q]$. Here $q$ is the trigger, $\bar{y}$ is the target token sampled for each sentence, and $\tau$ is a generic token fixed across sentences. Our findings: in a two-layer transformer, the second-layer attention (Attn-2) attends only to target tuples $[q,\bar{y}]$, while the second-layer feed-forward layer (FF-2) learns to predict $\tau$.
  • Figure 3: Left three: average probability of predicting the correct and noise tokens, and test loss on clean data ($\alpha = 0$), for different fractions $\rho$ of preserved rank in $U_{in}$ of the second-layer MLP $F_2$. The full model learns to predict noise with probability around $\alpha=0.5$, as expected from the training data. When $F_2$ is dropped ($\rho=0$), the model predicts the correct token $\bar{y}$ with probability $\approx 0.98$. Rightmost: the FF-2 margin of $\tau$ vs. all other tokens with input $q$, i.e., $[\mathbf{W}_U F_2(\mathbf{W}_E(q))]_{\tau} - \max_{k\le N}[\mathbf{W}_U F_2(\mathbf{W}_E(q))]_k$. It shows that FF-2 learns the trigger-noise association in the early training steps.
  • Figure 4: Second-layer attention scores on the same input for a model trained with noise (left) and a model fine-tuned with noise (right, initialized from a model pre-trained without noise). Both models learn to attend to the informative structure "[trigger]+$\bar{y}$" rather than "[trigger]+noise". This implies that attention in these models is responsible only for predicting $\bar{y}$, even though the training inputs and outputs contain noise with probability $\alpha=\Theta(1)$. The fine-tuning setting is described in the paper's appendix.
  • Figure 5: Left: average probability of the tokens [IO], [S], and "the" on a 100-sentence IOI task, as predicted by Pythia-1B over the course of training. Right: average probability of the tokens "Spain" and "the" on a factual task with input "Madrid is located in", as predicted by Pythia-1B over the course of training. In both tasks, the full model learns to predict "the" with high probability from about 10 steps onward, and only later learns to solve the tasks. LASER boosts the probability of the correct answers relative to "the" in both tasks. At 14K steps, the average probability ratio of the correct answer to "the" improves from 2.3$\times$ to 12.3$\times$ (IOI) and from 0.16$\times$ to 11.3$\times$ (factual).
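The noisy in-context recall task from Figure 2 can be illustrated with a small sequence sampler. This is a minimal sketch under stated assumptions, not the paper's data pipeline. Background tokens are drawn uniformly, because the actual background distribution is not given in this summary. Token 0 is reserved as a single trigger $q$, and the generic noise token $\tau$ has index $N$. Every trigger, including the final query, is followed by $\tau$ with probability $\alpha$ and by $\bar{y}$ otherwise. The names `sample_noisy_recall` and `n_triggers` are hypothetical.

```python
import numpy as np


def sample_noisy_recall(N, T, alpha, n_triggers=3, rng=None):
    """Sample (tokens, target) for one sequence of a noisy in-context recall task.

    Hypothetical layout (assumptions, not the paper's generator):
      * regular tokens are 0..N-1 and the generic noise token tau is N;
      * token 0 is reserved as the trigger q;
      * background tokens are drawn uniformly from 1..N-1;
      * each trigger is followed by tau with probability alpha, else by y_bar.
    """
    rng = np.random.default_rng() if rng is None else rng
    q, tau = 0, N
    y_bar = int(rng.integers(1, N))               # target token for this sequence
    tokens = rng.integers(1, N, size=T).tolist()  # uniform background filler
    slots = np.arange(0, T - 2, 2)                # even starts keep [q, next] pairs disjoint
    if n_triggers > len(slots):
        raise ValueError("sequence too short for the requested number of triggers")
    for s in rng.choice(slots, size=n_triggers, replace=False):
        tokens[s] = q
        tokens[s + 1] = tau if rng.random() < alpha else y_bar
    tokens[-1] = q                                # final query position
    target = tau if rng.random() < alpha else y_bar
    return tokens, target


rng = np.random.default_rng(0)
tokens, target = sample_noisy_recall(N=20, T=32, alpha=0.5, rng=rng)
```

With $\alpha = 0$ (the clean data used for the test loss in Figure 3), every trigger is followed by the same $\bar{y}$, which is also the target. With few triggers and large $\alpha$, $\bar{y}$ may not appear in the context at all, and this sketch does not guard against that case.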

Theorems & Definitions (33)

  • Theorem 1: Logits after one gradient step
  • Theorem 2: Attention attends to in-context targets
  • Theorem 3
  • Lemma D.1
  • proof
  • Lemma D.2
  • proof
  • Theorem 4: Restatement of Theorem 1
  • proof
  • Lemma E.1: $\bar{y} = q, k=q$