Layer Importance for Mathematical Reasoning is Forged in Pre-Training and Invariant after Post-Training

Aadim Nepal, Safal Shrestha, Anubhav Shrestha, Minwu Kim, Jalal Naghiyev, Ravid Shwartz-Ziv, Keith Ross

TL;DR

We ask whether post-training gains in mathematical reasoning reflect broad changes across the network or a small set of layer-level adaptations, using layer-wise zero ablation across two model families (Qwen-2.5-7B and Llama-3.1-8B) in base, Instruct, Distill, and RLVR variants. We show that mathematical reasoning depends on a compact set of critical layers whose identities persist after post-training: removing them causes large accuracy drops (up to 80%), and post-training does not alter which layers are critical. Representational analysis with Normalized Mutual Information (NMI) reveals elbow-like reductions in cluster similarity around these critical layers, where token representations shift from surface syntactic groupings to semantically structured relations aligned with problem-solving steps. The findings imply that mathematical competence is forged during pre-training and preserved through post-training, suggesting that targeted fine-tuning of a small subset of layers could improve efficiency and interpretability in downstream math tasks.

Abstract

Large language models improve at math after instruction tuning, reinforcement learning, or knowledge distillation. We ask whether these gains come from major changes in the transformer layers or from smaller adjustments that keep the original structure. Using layer-wise ablation on base and post-trained variants, we find that math reasoning depends on a few critical layers, which stay important across all post-training methods. Removing these layers reduces math accuracy by as much as 80%, whereas factual recall tasks show comparatively smaller drops. This suggests that specialized layers for mathematical tasks form during pre-training and remain stable afterward. Using Normalized Mutual Information (NMI), we find that near these critical layers, token representations drift away from their original syntactic clusters and toward tokens that are less syntactically related but potentially more useful for the downstream task.
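To make the NMI comparison concrete, the following is a minimal sketch of how the agreement between two cluster assignments of the same tokens can be scored. It is an illustration only, not the authors' pipeline: the cluster labels are made-up toy data, the function names are hypothetical, and normalizing by the arithmetic mean of the two entropies is an assumption (other normalizations exist).

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (in nats) of a cluster assignment."""
    n = len(labels)
    return -sum((c / n) * math.log(c / n) for c in Counter(labels).values())


def mutual_information(a, b):
    """Mutual information (in nats) between two assignments of the same items."""
    n = len(a)
    count_a, count_b = Counter(a), Counter(b)
    joint = Counter(zip(a, b))
    mi = 0.0
    for (la, lb), c in joint.items():
        # p(a, b) * log( p(a, b) / (p(a) * p(b)) )
        mi += (c / n) * math.log(c * n / (count_a[la] * count_b[lb]))
    return mi


def nmi(a, b):
    """NMI normalized by the arithmetic mean of the entropies (an assumption)."""
    h_a, h_b = entropy(a), entropy(b)
    if h_a == 0.0 and h_b == 0.0:
        return 1.0  # both assignments put every item in a single cluster
    return mutual_information(a, b) / ((h_a + h_b) / 2)


# Toy data: cluster ids for six tokens at Layer 0 and at two later layers.
layer0_clusters = [0, 0, 1, 1, 2, 2]
same_grouping = [2, 2, 0, 0, 1, 1]   # same groups, cluster ids renamed
regrouped = [0, 1, 0, 1, 0, 1]       # groups cut across the Layer 0 clusters

nmi_same = nmi(layer0_clusters, same_grouping)
nmi_regrouped = nmi(layer0_clusters, regrouped)
```

A score near 1 means the later layer still groups tokens as Layer 0 did, and a score near 0 means the grouping has been reorganized. In the paper's setting the labels would come from clustering token representations at each layer rather than from hand-written lists.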

Paper Structure

This paper contains 29 sections, 11 equations, 8 figures.

Figures (8)

  • Figure 1: The plots show model accuracy (Y-axis) on GSM8K and MATH500 when a single transformer layer (X-axis) is zeroed out. The performance of all model variants drops substantially when specific layers are removed (layer 23 for Qwen, layers 15 and 18 for Llama), a pattern that remains consistent across different datasets and post-training methods. Dashed lines indicate the original, un-ablated performance.
  • Figure 2: Layer ablation results on the TriviaQA factual recall task. The left plot shows performance for Qwen 2.5-7B models, and the right plot shows performance for Llama 3.1-8B models when individual layers are zeroed out. The X-axis represents the layer index (0-32), and the Y-axis shows the accuracy.
  • Figure 3: The plots show the NMI score (Y-axis) at each transformer layer (X-axis), calculated relative to the token clusters at Layer 0. The observed trends are robust to the number of clusters (k) used for the analysis, with similar results for k between 10 and 70; the choice of k = 50 shown here is arbitrary. The shaded region denotes the standard deviation over 5 runs. For each run and each model, 20 problems were selected at random, so the analysis covers about 100 math problems per model across 8 models, or about 800 problems in total.
  • Figure 4: Illustration of the zero-ablation technique. (a) Normal transformer layer with active MLP and attention. (b) Ablated layer with MLP and attention parameters set to zero; because of the skip connection, the layer is effectively nullified and passes its input through unchanged.
  • Figure 5: Zero ablation on Qwen2.5-(1.5B,3B,7B)-Instruct Models. We find that all 3 models have critical layers at relatively similar positions.
  • ...and 3 more figures
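
The zero-ablation sweep described in the figure captions above can be sketched on a toy residual stack. This is a simplified illustration under stated assumptions, not the authors' code: each "layer" is a single hypothetical weight matrix standing in for attention plus MLP, and the distance between outputs stands in for the drop in task accuracy that the paper measures.

```python
import math
import random


def make_layer(dim, rng):
    """Hypothetical toy layer: one weight matrix standing in for attention + MLP."""
    return [[rng.gauss(0.0, 0.5) for _ in range(dim)] for _ in range(dim)]


def sublayer(weights, x):
    # f(x) = tanh(W x); with W = 0 this returns the zero vector.
    return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in weights]


def forward(layers, x):
    # Residual stream: x <- x + f(x) at every layer (the skip connection).
    for weights in layers:
        update = sublayer(weights, x)
        x = [xi + ui for xi, ui in zip(x, update)]
    return x


def zero_ablate(layers, index):
    """Return a copy of the stack with the parameters of layer `index` set to zero."""
    dim = len(layers[index])
    ablated = list(layers)
    ablated[index] = [[0.0] * dim for _ in range(dim)]
    return ablated


rng = random.Random(0)
dim, depth = 4, 6
layers = [make_layer(dim, rng) for _ in range(depth)]
x0 = [rng.gauss(0.0, 1.0) for _ in range(dim)]

baseline = forward(layers, x0)

# Sweep: ablate one layer at a time and record how far the output moves.
shifts = []
for i in range(depth):
    ablated_output = forward(zero_ablate(layers, i), x0)
    shifts.append(math.dist(ablated_output, baseline))
```

Because the zeroed sublayer contributes nothing to the residual stream, the ablated layer reduces to the identity, which is why no other part of the stack needs to be rewired. In an actual experiment, each entry of `shifts` would be replaced by accuracy on a benchmark such as GSM8K with that layer ablated, compared against the un-ablated baseline.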