
Mechanisms of AI Protein Folding in ESMFold

Kevin Lu, Jannik Brinkmann, Stefan Huber, Aaron Mueller, Yonatan Belinkov, David Bau, Chris Wendler

TL;DR

This work probes how AI protein folding models derive structure from sequence by dissecting ESMFold's folding trunk. Through activation patching, it identifies two sequential computational stages: early blocks propagate sequence-derived biochemical signals into the pairwise representation, while late blocks refine pairwise geometric features that determine the final coordinates. The study shows that the pairwise representation functions as a distance map and that pair2seq biases mediate geometry-to-sequence communication, with causal interventions such as charge steering and distance steering producing expected structural effects. These findings offer a mechanistic, causal understanding of folding in a state-of-the-art model and suggest generalizable stages across secondary-structure motifs. The work advances interpretability in protein structure prediction by localizing computations within the trunk and demonstrating controllable interventions that influence folding outcomes.
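For readers unfamiliar with the term, a distance map is the matrix of pairwise distances between residues, and a contact map is its thresholded version. The sketch below shows both for a hand-made, hairpin-like toy backbone. The coordinates, the function names, and the 8 Å cutoff are illustrative assumptions (8 Å is a common convention, not a value taken from this paper).

```python
import numpy as np

def distance_map(coords):
    """coords: (L, 3) array with one point per residue. Returns (L, L) distances."""
    diff = coords[:, None, :] - coords[None, :, :]
    return np.linalg.norm(diff, axis=-1)

def contact_map(coords, threshold=8.0):
    """Boolean (L, L) matrix; the threshold in angstroms is an assumed convention."""
    return distance_map(coords) < threshold

# Toy backbone (hypothetical): residues 0-2 form one strand, 3-5 run back
# antiparallel to it, 5 angstroms away.
coords = np.array([
    [0.0, 0.0, 0.0], [3.8, 0.0, 0.0], [7.6, 0.0, 0.0],
    [7.6, 5.0, 0.0], [3.8, 5.0, 0.0], [0.0, 5.0, 0.0],
])
D = distance_map(coords)
C = contact_map(coords)
print(np.round(D, 1))
print(C.astype(int))
```

In a hairpin, the cross-strand pairs (here residues 0 and 5, 1 and 4, 2 and 3) show up as a short-distance anti-diagonal band in the map.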

Abstract

How do protein structure prediction models fold proteins? We investigate this question by tracing how ESMFold folds a beta hairpin, a prevalent structural motif. Through counterfactual interventions on model latents, we identify two computational stages in the folding trunk. In the first stage, early blocks initialize pairwise biochemical signals: residue identities and associated biochemical features, such as charge, flow from sequence representations into pairwise representations. In the second stage, late blocks develop pairwise spatial features: distance and contact information accumulates in the pairwise representation. We demonstrate that the mechanisms underlying structural decisions of ESMFold can be localized, traced through interpretable representations, and manipulated with strong causal effects.
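The counterfactual interventions mentioned in the abstract follow the general pattern of activation patching: cache a representation from a donor run, overwrite the same representation in a target run at one block, and measure how the output changes. The following is a minimal sketch of that pattern on a toy two-track trunk in NumPy. `toy_block`, `run_trunk`, the random weights, and the mean-absolute-difference effect measure are all hypothetical stand-ins, not ESMFold code or the authors' metric, and the sketch patches the whole representation rather than a selected region.

```python
import numpy as np

def toy_block(seq, pair, w_seq, w_pair):
    """One toy block: pair biases the sequence update, then seq updates pair."""
    # pair -> seq: average each residue's pairwise row into a per-residue bias
    seq = np.tanh(seq @ w_seq + pair.mean(axis=1))
    # seq -> pair: elementwise product of the features of residues i and j
    pair = np.tanh(pair @ w_pair + seq[:, None, :] * seq[None, :, :])
    return seq, pair

def run_trunk(seq, pair, weights, patch=None):
    """Run all blocks, optionally overwriting one representation at one block.

    patch = (block_index, "seq" or "pair", donor_activation) replaces the
    representation entering that block with the donor's cached activation.
    """
    cache = []
    for i, (w_seq, w_pair) in enumerate(weights):
        if patch is not None and patch[0] == i:
            if patch[1] == "seq":
                seq = patch[2]
            else:
                pair = patch[2]
        cache.append((seq.copy(), pair.copy()))  # inputs to block i
        seq, pair = toy_block(seq, pair, w_seq, w_pair)
    return seq, pair, cache

rng = np.random.default_rng(0)
L, d, n_blocks = 6, 4, 8  # toy sizes, chosen arbitrarily
weights = [(rng.normal(size=(d, d)), rng.normal(size=(d, d)))
           for _ in range(n_blocks)]
donor_seq, donor_pair = rng.normal(size=(L, d)), rng.normal(size=(L, L, d))
target_seq, target_pair = rng.normal(size=(L, d)), rng.normal(size=(L, L, d))

_, _, donor_cache = run_trunk(donor_seq, donor_pair, weights)
_, clean_out, _ = run_trunk(target_seq, target_pair, weights)

# Sweep over blocks and over the two representations, one patch at a time.
for block in range(n_blocks):
    for kind, idx in (("seq", 0), ("pair", 1)):
        donor_act = donor_cache[block][idx]
        _, patched_out, _ = run_trunk(target_seq, target_pair, weights,
                                      patch=(block, kind, donor_act))
        effect = float(np.abs(patched_out - clean_out).mean())
        print(block, kind, round(effect, 3))
```

In a real experiment the effect measure would be a structural readout of the predicted coordinates (for example, whether the output folds as a hairpin), and the sweep over blocks is what localizes where each representation matters.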

Paper Structure (31 sections, 8 equations, 24 figures, 1 table)


Figures (24)

  • Figure 1: Two computational stages in the ESMFold folding trunk. We identify which latent representations in the model influence hairpin formation by patching activations from a hairpin protein into a helical protein at each block of the trunk, then measuring whether the output folds as a hairpin. Sequence patches (orange) induce hairpin formation in early blocks (0–7); pairwise patches (green) are effective in late blocks. Results are aggregated over 2000 experiments. We show that stage 1 propagates biochemical features (e.g., charge) from sequence into pairwise representations, while stage 2 develops pairwise spatial features (distances, contacts) that modulate sequence attention and control output geometry.
  • Figure 2: Sequence view of a $\beta$-hairpin (highlighted) within a protein: two $\beta$-strand segments separated by a loop. Colored braces label the strand–loop–strand decomposition.
  • Figure 3: 3D cartoon diagram of a $\beta$-hairpin, a common secondary-structure motif consisting of two antiparallel $\beta$-strands (side chain V and side chain W) connected by a short turn/loop. Interactions between side chains on opposing strands and backbone hydrogen bonding together determine the folded hairpin geometry.
  • Figure 4: ESMFold. A protein language model (ESM-2) encodes an amino acid sequence into an initial sequence representation; the pairwise representation is initialized with learned positional embeddings. The folding trunk iteratively refines both representations over 48 blocks (each consisting of multiple layers). The structure module converts these into 3D coordinates for each residue. Adapted from Lin et al. (2023).
  • Figure 5: Folding block. A sequence update, consisting of Pair2Seq and a sequence transformer layer (attention plus MLP), is followed by a pairwise update, consisting of Seq2Pair, multiplicative and attention-based updates of the pairwise representation (triangular updates; see Jumper et al., 2021), and an MLP. Adapted from Lin et al. (2023).
  • ...and 19 more figures
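
The data flow of the folding block in Figure 5 can be sketched in a few lines of NumPy. This is a simplified illustration under stated assumptions, not ESMFold's implementation: it uses a single attention head, omits layer normalization, gating, triangular attention, and the pairwise MLP, and every name (`toy_folding_block`, the keys of `params`, the toy dimensions) is hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def toy_folding_block(s, z, p):
    """s: (L, d_s) sequence rep; z: (L, L, d_z) pairwise rep; p: dict of weights."""
    d_s = s.shape[1]
    # Pair2Seq: project each pair vector to a scalar attention bias (one head).
    bias = z @ p["pair2seq"]                      # (L, L)
    # Sequence attention with the pairwise bias added to the logits.
    q, k, v = s @ p["wq"], s @ p["wk"], s @ p["wv"]
    attn = softmax(q @ k.T / np.sqrt(d_s) + bias, axis=-1)
    s = s + attn @ v
    # Sequence MLP (one hidden layer, ReLU).
    s = s + np.maximum(s @ p["mlp_in"], 0.0) @ p["mlp_out"]
    # Seq2Pair: combine projected features of residues i and j into a pair update.
    a, b = s @ p["s2p_a"], s @ p["s2p_b"]         # (L, d_z) each
    z = z + a[:, None, :] * b[None, :, :]
    # Triangular multiplicative update (outgoing edges): z_ij += sum_k l_ik * r_jk.
    left, right = z @ p["tri_l"], z @ p["tri_r"]  # (L, L, d_z) each
    z = z + np.einsum("ikc,jkc->ijc", left, right)
    return s, z

rng = np.random.default_rng(0)
L, d_s, d_z, d_h = 5, 8, 4, 16  # toy sizes, chosen arbitrarily
shapes = {
    "pair2seq": (d_z,), "wq": (d_s, d_s), "wk": (d_s, d_s), "wv": (d_s, d_s),
    "mlp_in": (d_s, d_h), "mlp_out": (d_h, d_s),
    "s2p_a": (d_s, d_z), "s2p_b": (d_s, d_z),
    "tri_l": (d_z, d_z), "tri_r": (d_z, d_z),
}
params = {name: 0.1 * rng.normal(size=shape) for name, shape in shapes.items()}
s = rng.normal(size=(L, d_s))
z = rng.normal(size=(L, L, d_z))

s_new, z_new = toy_folding_block(s, z, params)
print(s_new.shape, z_new.shape)
```

In this sketch the only route from the pairwise representation into the sequence update is the Pair2Seq bias, so setting `params["pair2seq"]` to zeros makes the sequence output independent of `z`. That mirrors, in toy form, the kind of ablation one could use to test whether such a bias carries geometric information back to the sequence track.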