Jet Expansions of Residual Computation

Yihong Chen, Xiangxiang Xu, Yao Lu, Pontus Stenetorp, Luca Franceschi

TL;DR

This work introduces a framework for expanding residual computational graphs using jets, operators that generalize truncated Taylor series. The framework grounds and subsumes the logit lens, reveals a (super-)exponential path structure in the recursive residual depth, and opens up several applications.

Abstract

We introduce a framework for expanding residual computational graphs using jets, operators that generalize truncated Taylor series. Our method provides a systematic approach to disentangle contributions of different computational paths to model predictions. In contrast to existing techniques such as distillation, probing, or early decoding, our expansions rely solely on the model itself and require no data, training, or sampling from the model. We demonstrate how our framework grounds and subsumes the logit lens, reveals a (super-)exponential path structure in the recursive residual depth, and opens up several applications. These include sketching a transformer large language model with $n$-gram statistics extracted from its computations, and indexing the models' levels of toxicity knowledge. Our approach enables data-free analysis of residual computation for model interpretability, development, and evaluation.
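As a rough illustration of the "truncated Taylor series" that jets generalize, the sketch below expands a scalar `tanh` non-linearity around a center and evaluates the expansion at a nearby point. It is not the paper's implementation: the helper names `tanh_derivatives` and `jet`, the choice of `tanh`, and the numeric values are assumptions made for this example only.

```python
import math

def tanh_derivatives(x0):
    """Return [tanh(x0), tanh'(x0), tanh''(x0)] in closed form."""
    t = math.tanh(x0)
    return [t, 1.0 - t * t, -2.0 * t * (1.0 - t * t)]

def jet(derivs, center, variate, k):
    """Order-k truncated Taylor expansion around `center`, evaluated at `variate`.

    `derivs[j]` must hold the j-th derivative of the function at `center`.
    """
    if k >= len(derivs):
        raise ValueError("not enough derivatives for the requested order")
    delta = variate - center
    return sum(derivs[j] / math.factorial(j) * delta ** j for j in range(k + 1))

# Hypothetical center and evaluation point, chosen only for illustration.
center, variate = 0.5, 0.6
derivs = tanh_derivatives(center)

# Absolute remainder |tanh(variate) - jet_k| for orders k = 0, 1, 2.
errors = [abs(math.tanh(variate) - jet(derivs, center, variate, k)) for k in range(3)]
```

With the evaluation point close to the center, the remainder shrinks as the order `k` grows. The order-0 case simply returns the function value at the center.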

Paper Structure

This paper contains 35 sections, 2 theorems, 15 equations, 18 figures, 5 tables, 2 algorithms.

Key Result

Lemma 1

Let $f\in C^{\infty}(\mathbb{R}^d, \mathbb{R}^d)$, $k\in\mathbb{N}, N\in\mathbb{N}^+$, $\{\textcolor{red}{x_i}\}_{i\in[N]}$ be a set of jet centers, $w\in\triangle^{N-1} \subset \mathbb{R}^N$ be a set of jet weights, and $r=\max_i \{ w_i\| x_i - \sum_{j} x_j\| \}$. Then

Figures (18)

  • Figure 1: Various equivalent representations of a two-block linear residual network. In particular, (b) highlights the residual stream of \ref{eq:unrolling}; (d) highlights the exponential rewriting of \ref{eq:lin-nets}.
  • Figure 2: Representation of a two-block residual net (a, a-bis) and its exponential expansion steps (b, c).
  • Figure 3: (Top) Example of a joint jet lens on GPT-Neo $2.7$B with $k=1$, visualizing the seven blocks with the highest average jet weights after optimization. Each table cell indicates the most likely token of the jet path related to each block non-linearity. Optimized jet weights are in brackets. We used a diverging blue-to-red color map tracking logit scores, centered around zero. The bottom table shows the model logits and the expansion logits, with cosine similarity in brackets; in this case, all top-$1$ tokens perfectly coincide. (Bottom) Plots of average cosine similarities between original and jet logits of joint (left) and iterative (right) lenses.
  • Figure 4: Analysis of OLMo-$7$B's pretraining dynamics via measuring its jet bi-gram progression.
  • Figure 5: Visualization of OLMo-$7$B's promotion and suppression dynamics of jet bi-grams scores.
  • ...and 13 more figures
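
The exponential rewriting of a two-block linear residual network mentioned in the Figure 1 caption can be checked numerically. The sketch below assumes each block has the form $x \mapsto x + A x$; the matrices `A1`, `A2`, the input `x`, and all helper names are hypothetical and chosen only for illustration. It enumerates the $2^2 = 4$ paths obtained by either skipping or passing through each block and compares their sum with the ordinary forward pass.

```python
from itertools import product

def matvec(m, v):
    """Multiply matrix `m` (list of rows) by vector `v`."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def add(u, v):
    """Element-wise sum of two vectors."""
    return [a + b for a, b in zip(u, v)]

def residual_forward(blocks, x):
    """Standard forward pass: apply x <- x + A x for each block in order."""
    for a in blocks:
        x = add(x, matvec(a, x))
    return x

def path_expansion(blocks, x):
    """Map each skip/use mask (one bit per block) to that path's contribution."""
    paths = {}
    for mask in product([0, 1], repeat=len(blocks)):
        h = x
        for use, a in zip(mask, blocks):
            if use:  # go through the block's linear map; otherwise take the skip
                h = matvec(a, h)
        paths[mask] = h
    return paths

# Hypothetical integer weights so the comparison is exact.
A1 = [[1, 2], [0, 1]]
A2 = [[0, 1], [3, 0]]
x = [1, -1]

output = residual_forward([A1, A2], x)
paths = path_expansion([A1, A2], x)

# Summing all path contributions recovers the network output.
total = [0, 0]
for contribution in paths.values():
    total = add(total, contribution)
```

With $L$ linear blocks the same enumeration yields $2^L$ paths, which is the exponential growth in depth for this linear case; the non-linear case treated in the paper is not covered by this sketch.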

Theorems & Definitions (5)

  • Lemma 1: Convex combinations of jets
  • Remark 1: Jet centers and variates as functions
  • Remark 2: Non-vanishing remainders
  • Remark 3: Jet weights optimization
  • Proposition 1: Jet algebra