Step-resolved data attribution for looped transformers

Georgios Kaissis, David Mildenberger, Juan Felipe Gomez, Martin J. Menten, Eleni Triantafillou

TL;DR

This work tackles the problem of attributing training data influence to internal looped-transformer steps by introducing Step-Decomposed Influence (SDI), which unrolls gradient dynamics across loop iterations to produce a per-step influence trajectory. It generalizes the TracIn estimator to a step-resolved form, and pairs it with a streaming CountSketch/TensorSketch pipeline to compute per-step influence without materialising full per-example gradients. Empirical results on looped GPT-style models and algorithmic tasks demonstrate that SDI matches full-gradient baselines with low error, scales to large models, and yields per-step interpretability of latent reasoning, including discovery of finite-state-like circuits and late-step dominance in reasoning tasks. The framework enables data-centric interpretability and potential data-curation strategies, and points to promising directions for distributing compute and aligning models via step-aware analyses.

Abstract

We study how individual training examples shape the internal computation of looped transformers, where a shared block is applied for $\tau$ recurrent iterations to enable latent reasoning. Existing training-data influence estimators such as TracIn yield a single scalar score that aggregates over all loop iterations, obscuring when during the recurrent computation a training example matters. We introduce Step-Decomposed Influence (SDI), which decomposes TracIn into a length-$\tau$ influence trajectory by unrolling the recurrent computation graph and attributing influence to specific loop iterations. To make SDI practical at transformer scale, we propose a TensorSketch implementation that never materialises per-example gradients. Experiments on looped GPT-style models and algorithmic reasoning tasks show that SDI scales excellently, matches full-gradient baselines with low error, and supports a broad range of data attribution and interpretability tasks with per-step insights into the latent reasoning process.
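To make the decomposition concrete, here is a minimal NumPy sketch on a toy looped model. It is not the paper's implementation. It assumes a single checkpoint, a recurrence h_t = tanh(W h_{t-1}) with squared-error loss, and that the test-side gradient is the one split by loop step (splitting the training side instead would be symmetric). All names (`step_gradients`, `sdi`, `tracin`) are illustrative.

```python
import numpy as np

def loss(W, x, y, tau):
    """Squared-error loss after tau applications of the shared block."""
    h = x
    for _ in range(tau):
        h = np.tanh(W @ h)
    return 0.5 * np.sum((h - y) ** 2)

def step_gradients(W, x, y, tau):
    """Per-step contributions to d loss / d W, returned with shape (tau, d, d).

    Entry t-1 holds the part of the gradient that flows through the
    t-th use of the shared weights W; the entries sum to the full gradient.
    """
    hs = [x]
    for _ in range(tau):
        hs.append(np.tanh(W @ hs[-1]))
    delta = hs[-1] - y                             # d loss / d h_tau
    grads = np.zeros((tau,) + W.shape)
    for t in range(tau, 0, -1):
        pre = delta * (1.0 - hs[t] ** 2)           # gradient w.r.t. the pre-activation at step t
        grads[t - 1] = np.outer(pre, hs[t - 1])    # direct contribution of step t
        delta = W.T @ pre                          # d loss / d h_{t-1}
    return grads

rng = np.random.default_rng(0)
d, tau, lr = 4, 5, 0.1                             # toy sizes and learning rate (assumed)
W = rng.normal(scale=0.5, size=(d, d))
x_tr, y_tr = rng.normal(size=d), rng.normal(size=d)
x_te, y_te = rng.normal(size=d), rng.normal(size=d)

g_train = step_gradients(W, x_tr, y_tr, tau).sum(axis=0)   # full training gradient
g_test_steps = step_gradients(W, x_te, y_te, tau)          # test gradient, split by step

# Single-checkpoint TracIn score and its step-resolved trajectory.
tracin = lr * np.sum(g_train * g_test_steps.sum(axis=0))
sdi = lr * np.einsum("ij,tij->t", g_train, g_test_steps)
print("SDI trajectory:", sdi)
print("sum of trajectory vs. TracIn:", sdi.sum(), tracin)
```

Because the inner product is linear in the decomposed gradient, the trajectory sums to the scalar TracIn score by construction; the per-step values show where along the loop that score comes from.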

Paper Structure

This paper contains 16 sections, 6 theorems, 64 equations, 4 figures, 1 table, 1 algorithm.

Key Result

Proposition 1

For a looped transformer with $\tau$ steps and sequence length $L$, let $\mathbf{h}_{t,j} \in \mathbb{R}^d$ denote the hidden state of the $j$-th token at step $t$. The total derivative of the loss with respect to $\mathbf{w}_\text{body}$ can be unrolled into a sum over $\tau$ steps and $L$ tokens, where $\frac{\mathrm{d}\ell}{\mathrm{d}\mathbf{h}_{t,j}} \in \mathbb{R}^d$ is the gradient at step ...
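A sketch of the unrolled sum the proposition describes, assuming standard backpropagation through a weight-shared recurrence. The Jacobian notation $\partial \mathbf{h}_{t,j} / \partial \mathbf{w}_\text{body}$ is an assumption introduced here for illustration and may differ from the paper's own statement:

```latex
\frac{\mathrm{d}\ell}{\mathrm{d}\mathbf{w}_\text{body}}
  = \sum_{t=1}^{\tau} \sum_{j=1}^{L}
    \left( \frac{\partial \mathbf{h}_{t,j}}{\partial \mathbf{w}_\text{body}} \right)^{\top}
    \frac{\mathrm{d}\ell}{\mathrm{d}\mathbf{h}_{t,j}}
```

Here $\partial \mathbf{h}_{t,j} / \partial \mathbf{w}_\text{body}$ is taken to be the direct (local) Jacobian of the $t$-th application of the shared block, with that step's inputs held fixed, so each $(t, j)$ term isolates the contribution of one loop iteration and one token.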

Figures (4)

  • Figure 1: SDI as a mechanistic discovery tool. We train a looped transformer on parity and apply SDI to an alternating probe $(0101\dots)$. (A) SDI and Logit Margin at the answer token across loop iterations $t$; light gray guides indicate the SDI peak phase (period $4$) and the dotted gray line at $t{=}40$ marks the model's readout step. (B) PCA phase portrait of the answer-token hidden state across loop iterations, coloured by a $k{=}4$ discretization, revealing a four-state limit cycle.
  • Figure 2: Scaling laws of loop compute on SATNet Sudoku. (A) Test board accuracy (a board is correct iff all blank cells are correct) versus the number of test-time loops, stratified by puzzle difficulty (binned by the number of initial missing cells; more missing implies harder). Harder puzzles are substantially more compute-sensitive: reducing loops sharply degrades accuracy, while increasing loops yields gains up to about the training-mean depth $\tau{\approx}32$ (dotted line) before saturating; we include ${>}32$ loops to show the plateau. (B) SDI energy across loop steps (median and IQR across puzzles in each difficulty bin), where SDI energy at step $t$ is the sum of absolute stepwise SDI scores across training points. Harder puzzles maintain higher SDI energy deeper into the recurrence (slower decay), mirroring their larger marginal gains from additional loop compute.
  • Figure 3: Geometric influence growth across the loop horizon. Signed SDI share per step $t$ for several analysis horizons $\tau$ at a mid-training stage, recomputed with full BPTT through the recurrence from GSM8K train to GSM8K test. Removing truncation reveals a smooth, approximately geometric increase of stepwise influence with depth.
  • Figure 4: Empirical error scaling w.r.t. sketch dimension $m$. Average and standard deviation over 10 trials.
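The sketching idea behind Figure 4 can be illustrated with a generic TensorSketch in NumPy. This is a minimal sketch under stated assumptions, not the paper's streaming pipeline: it treats a per-example gradient of a linear layer as an outer product `outer(u, v)`, sketches it through two CountSketches combined by FFT without ever forming the outer product, and estimates gradient inner products from the sketches. The function names, the dimensions `d_out` and `d_in`, and the synthetic vectors are all hypothetical.

```python
import numpy as np

def make_countsketch(n, m, rng):
    """Random bucket indices and signs defining a CountSketch from R^n to R^m."""
    return rng.integers(0, m, size=n), rng.choice([-1.0, 1.0], size=n)

def countsketch(x, buckets, signs, m):
    """Apply a CountSketch: add each signed entry of x into its bucket."""
    out = np.zeros(m)
    np.add.at(out, buckets, signs * x)
    return out

def tensorsketch(u, v, cs_u, cs_v, m):
    """Sketch of outer(u, v) without forming it.

    The circular convolution of the two CountSketches, computed via FFT,
    equals a CountSketch of the flattened outer product.
    """
    fu = np.fft.rfft(countsketch(u, *cs_u, m))
    fv = np.fft.rfft(countsketch(v, *cs_v, m))
    return np.fft.irfft(fu * fv, n=m)

def mean_relative_error(m, trials, rng, d_out=32, d_in=48):
    """Average relative error of the sketched inner product over fresh sketches."""
    errs = []
    for _ in range(trials):
        u, v = rng.normal(size=d_out), rng.normal(size=d_in)
        # Correlated second pair so the exact inner product is far from zero.
        a = u + 0.5 * rng.normal(size=d_out)
        b = v + 0.5 * rng.normal(size=d_in)
        cs_u = make_countsketch(d_out, m, rng)
        cs_v = make_countsketch(d_in, m, rng)
        est = tensorsketch(u, v, cs_u, cs_v, m) @ tensorsketch(a, b, cs_u, cs_v, m)
        exact = (u @ a) * (v @ b)        # <outer(u, v), outer(a, b)>
        errs.append(abs(est - exact) / abs(exact))
    return float(np.mean(errs))

rng = np.random.default_rng(0)
for m in (64, 256, 1024):
    print(m, mean_relative_error(m, trials=20, rng=rng))
```

The same pair of CountSketches must be shared by both gradients being compared; the estimate is unbiased, and its variance shrinks as the sketch dimension `m` grows, which is the trend Figure 4 reports empirically.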

Theorems & Definitions (17)

  • Proposition 1
  • proof
  • Definition 1: Step-Decomposed Influence (SDI)
  • Lemma 1
  • proof
  • Definition 2
  • Definition 3: pham2025tensor
  • Lemma 2: pham2025tensor
  • proof
  • Remark 1
  • ...and 7 more