Step-resolved data attribution for looped transformers
Georgios Kaissis, David Mildenberger, Juan Felipe Gomez, Martin J. Menten, Eleni Triantafillou
TL;DR
This work addresses the problem of attributing training-data influence to the individual loop iterations of a looped transformer. It introduces Step-Decomposed Influence (SDI), which unrolls the recurrent computation across loop iterations to produce a per-step influence trajectory. SDI generalizes the TracIn estimator to a step-resolved form and pairs it with a streaming CountSketch/TensorSketch pipeline that computes per-step influence without materialising full per-example gradients. Experiments on looped GPT-style models and algorithmic tasks show that SDI matches full-gradient baselines with low error, scales to large models, and provides per-step interpretability of latent reasoning, including the discovery of finite-state-like circuits and late-step dominance in reasoning tasks. The framework enables data-centric interpretability and suggests possible data-curation strategies. It also points to promising directions for distributing compute and aligning models through step-aware analyses.
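The step-resolved form of TracIn can be illustrated on a toy model. The following is a minimal sketch in plain Python, not the authors' implementation, and it rests on assumptions of our own: a scalar recurrent cell $h_t = \tanh(w h_{t-1} + b x)$ with tied weights $(w, b)$ stands in for the shared transformer block; TracIn is taken as the learning-rate-weighted sum over checkpoints of gradient dot products; and the split by loop iteration is applied to the training example's gradient, using the gradient with respect to each step's copy of the tied weights. All function names are hypothetical. Because the gradient of the shared weights is the sum of these per-step gradients, the trajectory in this construction sums back to the ordinary TracIn score.

```python
import math

def forward(theta, x, tau):
    """Apply the shared (toy) cell tau times from h_0 = 0; return [h_0, ..., h_tau]."""
    w, b = theta
    hs = [0.0]
    for _ in range(tau):
        hs.append(math.tanh(w * hs[-1] + b * x))
    return hs

def loss(theta, x, y, tau):
    """Squared error on the final hidden state."""
    return 0.5 * (forward(theta, x, tau)[-1] - y) ** 2

def per_step_grads(theta, x, y, tau):
    """Backprop through time, keeping dL/d(theta_t) for each step's copy of the tied weights."""
    w, _ = theta
    hs = forward(theta, x, tau)
    delta = hs[-1] - y                            # dL/dh_tau
    grads = [None] * tau
    for t in range(tau, 0, -1):
        da = delta * (1.0 - hs[t] ** 2)           # back through tanh at step t
        grads[t - 1] = (da * hs[t - 1], da * x)   # (dL/dw_t, dL/db_t)
        delta = da * w                            # dL/dh_{t-1}
    return grads

def full_grad(theta, x, y, tau):
    """Gradient w.r.t. the shared weights: the sum of the per-step gradients."""
    steps = per_step_grads(theta, x, y, tau)
    return tuple(sum(g[i] for g in steps) for i in range(len(theta)))

def dot(u, v):
    return sum(p * q for p, q in zip(u, v))

def sdi_trajectory(checkpoints, train, test, tau):
    """Entry t: sum over checkpoints of lr * <g_train^(t), g_test> (assumed split)."""
    traj = [0.0] * tau
    for theta, lr in checkpoints:
        g_test = full_grad(theta, *test, tau)
        for t, g_t in enumerate(per_step_grads(theta, *train, tau)):
            traj[t] += lr * dot(g_t, g_test)
    return traj

def tracin(checkpoints, train, test, tau):
    """Plain TracIn: a single scalar per (train, test) pair."""
    return sum(lr * dot(full_grad(theta, *train, tau), full_grad(theta, *test, tau))
               for theta, lr in checkpoints)

# Two hypothetical checkpoints (weights, learning rate) and (x, y) examples.
ckpts = [((0.5, 1.0), 0.1), ((0.8, 0.6), 0.05)]
traj = sdi_trajectory(ckpts, train=(1.0, 0.3), test=(0.7, -0.2), tau=4)
print(traj, sum(traj), tracin(ckpts, (1.0, 0.3), (0.7, -0.2), 4))
```

The printed trajectory has one entry per loop iteration, and its sum agrees with the scalar TracIn score up to rounding. Which iteration dominates here depends only on the toy weights, so nothing about real looped transformers should be read from the numbers.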
Abstract
We study how individual training examples shape the internal computation of looped transformers, in which a shared block is applied for $τ$ recurrent iterations to enable latent reasoning. Existing training-data influence estimators such as TracIn yield a single scalar score that aggregates over all loop iterations, obscuring when during the recurrent computation a training example matters. We introduce Step-Decomposed Influence (SDI), which decomposes TracIn into a length-$τ$ influence trajectory by unrolling the recurrent computation graph and attributing influence to specific loop iterations. To make SDI practical at transformer scale, we propose a TensorSketch implementation that never materialises per-example gradients. Experiments on looped GPT-style models and algorithmic reasoning tasks show that SDI scales to large models, matches full-gradient baselines with low error, and supports a broad range of data-attribution and interpretability tasks, providing per-step insight into the latent reasoning process.
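The idea of sketching gradients without materialising them can be illustrated for a single linear layer, whose per-example weight gradient is a sum over tokens of outer products $a_n \otimes g_n$ of input activations and output gradients. The sketch below uses the standard TensorSketch construction (two CountSketches combined by circular convolution); it is an illustrative assumption, not the authors' pipeline, and every name in it is hypothetical. For clarity the convolution is a direct $O(D^2)$ loop, whereas practical implementations use an FFT. The reference function `direct_sketch` materialises the outer products only to check that both routes produce the same sketch.

```python
import random

def make_countsketch(n, D, rng):
    """Random bucket table h: [n] -> [D] and sign table s: [n] -> {-1, +1}.
    Fully random tables stand in for the hash functions of a real implementation."""
    return [rng.randrange(D) for _ in range(n)], [rng.choice((-1, 1)) for _ in range(n)]

def countsketch(v, table, D):
    """CountSketch of vector v into D buckets."""
    h, s = table
    out = [0.0] * D
    for i, vi in enumerate(v):
        out[h[i]] += s[i] * vi
    return out

def circular_conv(p, q):
    """Circular convolution of two length-D vectors (an FFT would make this O(D log D))."""
    D = len(p)
    out = [0.0] * D
    for i in range(D):
        for j in range(D):
            out[(i + j) % D] += p[i] * q[j]
    return out

def tensorsketch_layer_grad(acts, grads, tab_a, tab_g, D):
    """Streams over tokens, accumulating the sketch of sum_n a_n (x) g_n.
    No d_in x d_out outer product is ever formed."""
    total = [0.0] * D
    for a, g in zip(acts, grads):
        term = circular_conv(countsketch(a, tab_a, D), countsketch(g, tab_g, D))
        total = [u + v for u, v in zip(total, term)]
    return total

def direct_sketch(acts, grads, tab_a, tab_g, D):
    """Reference only: CountSketch of the materialised gradient using the combined
    bucket (h_a(i) + h_g(j)) mod D and sign s_a(i) * s_g(j)."""
    (ha, sa), (hg, sg) = tab_a, tab_g
    out = [0.0] * D
    for a, g in zip(acts, grads):
        for i, ai in enumerate(a):
            for j, gj in enumerate(g):
                out[(ha[i] + hg[j]) % D] += sa[i] * sg[j] * ai * gj
    return out

rng = random.Random(0)
d_in, d_out, D, n_tokens = 6, 5, 16, 3
tab_a, tab_g = make_countsketch(d_in, D, rng), make_countsketch(d_out, D, rng)
acts = [[rng.gauss(0.0, 1.0) for _ in range(d_in)] for _ in range(n_tokens)]
grads = [[rng.gauss(0.0, 1.0) for _ in range(d_out)] for _ in range(n_tokens)]
fast = tensorsketch_layer_grad(acts, grads, tab_a, tab_g, D)
slow = direct_sketch(acts, grads, tab_a, tab_g, D)
print(max(abs(u - v) for u, v in zip(fast, slow)))
```

Because the sketch is linear and its signs are random, the inner product of two sketched gradients is an unbiased estimate of the inner product of the full gradients, which is the quantity a TracIn-style score needs. Accumulating token by token keeps the per-example state at a single length-$D$ vector.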
