Reverse Distillation: Consistently Scaling Protein Language Model Representations

Darius Catrina, Christian Bepler, Samuel Sledzieski, Rohit Singh

TL;DR

This work introduces Reverse Distillation, a principled framework that decomposes large PLM representations into orthogonal subspaces guided by smaller models of the same family, ensuring that larger reverse-distilled models consistently outperform smaller ones.

Abstract

Unlike the predictable scaling laws in natural language processing and computer vision, protein language models (PLMs) scale poorly: for many tasks, models within the same family plateau or even decrease in performance, with mid-sized models often outperforming the largest in the family. We introduce Reverse Distillation, a principled framework that decomposes large PLM representations into orthogonal subspaces guided by smaller models of the same family. The resulting embeddings have a nested, Matryoshka-style structure: the first k dimensions of a larger model's embedding are exactly the representation from the smaller model. This ensures that larger reverse-distilled models consistently outperform smaller ones. A motivating intuition is that smaller models, constrained by capacity, preferentially encode broadly-shared protein features. Reverse distillation isolates these shared features and orthogonally extracts additional contributions from larger models, preventing interference between the two. On ProteinGym benchmarks, reverse-distilled ESM-2 variants outperform their respective baselines at the same embedding dimensionality, with the reverse-distilled 15 billion parameter model achieving the strongest performance. Our framework is generalizable to any model family where scaling challenges persist. Code and trained models are available at https://github.com/rohitsinghlab/plm_reverse_distillation.
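The abstract describes the construction only at a high level. Below is a minimal NumPy sketch of a single reverse-distillation step, written under our own assumptions and not taken from the released code. The larger model's embeddings are linearly regressed on the smaller model's, optionally through principal component regression (PCR) as in the Figure 3 ablation. The top right-singular directions of the residual then supply the remaining $k_p - k_r$ columns, giving an embedding of the form $[\tilde{H}_r, X]$. The function names `fit_rd_step` and `apply_rd_step`, the mean-centering, and the toy data are hypothetical illustration choices.

```python
import numpy as np

def fit_rd_step(H_r, H_p, n_pcr=None):
    """Fit one hypothetical reverse-distillation step (illustrative sketch).

    H_r   : (L, k_r) smaller-model embeddings, kept verbatim as the prefix.
    H_p   : (L, k_p) larger-model embeddings for the same L residues, k_p > k_r.
    n_pcr : number of principal components of H_r used in the linear map
            (principal component regression); None uses all of them (OLS).
    """
    mu_r, mu_p = H_r.mean(axis=0), H_p.mean(axis=0)
    Xc, Yc = H_r - mu_r, H_p - mu_p
    # Principal directions of the centered smaller-model representation.
    _, _, Vt_r = np.linalg.svd(Xc, full_matrices=False)
    P = Vt_r.T if n_pcr is None else Vt_r[:n_pcr].T
    # Regress the larger representation on the retained principal components.
    B, *_ = np.linalg.lstsq(Xc @ P, Yc, rcond=None)
    W = P @ B                      # (k_r, k_p) linear map from H_r to H_p
    R = Yc - Xc @ W                # part of H_p not explained by the map
    # Top right-singular vectors of the residual define the extra columns.
    _, _, Vt_res = np.linalg.svd(R, full_matrices=False)
    m = H_p.shape[1] - H_r.shape[1]
    return {"mu_r": mu_r, "mu_p": mu_p, "W": W, "V_res": Vt_res[:m].T}

def apply_rd_step(step, H_r, H_p):
    """Return [H_r, residual features]; the first k_r columns are exactly H_r."""
    R = (H_p - step["mu_p"]) - (H_r - step["mu_r"]) @ step["W"]
    return np.hstack([H_r, R @ step["V_res"]])

# Toy data: the "large" model is a noisy linear function of the "small" one.
rng = np.random.default_rng(0)
H_small = rng.standard_normal((200, 8))
H_large = H_small @ rng.standard_normal((8, 20)) + 0.1 * rng.standard_normal((200, 20))

step = fit_rd_step(H_small, H_large)               # OLS variant
emb = apply_rd_step(step, H_small, H_large)        # shape (200, 20)
emb_pcr = apply_rd_step(fit_rd_step(H_small, H_large, n_pcr=4), H_small, H_large)
```

In this sketch, with the OLS map the appended columns are orthogonal (over the fitted residues) to every centered column of `H_r`; with PCR, that orthogonality holds only with respect to the retained principal directions.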

Paper Structure (24 sections, 1 theorem, 5 equations, 4 figures, 4 tables, 3 algorithms)

Key Result

Theorem 1

Let $\tilde{H}_p \in \mathbb{R}^{L \times k_p}$ and $\tilde{H}_r \in \mathbb{R}^{L \times k_r}$ be stacked representations from models $M_p$ and $M_r$, respectively, where $r < p$. Among all representations of the form $[\tilde{H}_r, X]$ where $X \in \mathbb{R}^{L \times (k_p - k_r)}$, the representation …

Figures (4)

  • Figure 1: Overview of Reverse Distillation. Large protein language models (e.g., ESM-2$_{3B}$) entangle representations of diverse features in a single representational space, hindering the performance of downstream linear probes. Reverse distillation constructs a product space by preserving the smaller model's representation (capturing more conserved features) and extracting orthogonal residuals via SVD (capturing features unique to the larger model). Iterating this process across a model family yields Matryoshka-style embeddings in which each prefix is a valid reverse-distilled representation at that scale. Figure created using biorender.io.
  • Figure 2: Reverse Distillation embeddings capture more GO terms. (a) Sparse autoencoder (SAE) features from the rd.35M model are enriched for more GO terms than those from the base model, indicating that they contain more functionally relevant information. (b) The sets of GO terms for each model are equally compact, as measured by pairwise shortest-path distance on the GO tree. (c) The sets of GO terms for rd.35M SAE features are significantly less general, as measured by the depth of the pairwise least common ancestor on the GO tree. Whereas features from the base model capture high-level GO terms, reverse distillation enables model features to represent distinct functions.
  • Figure 3: Ablation on linear mapping. The plots compare the rd.650M, rd.3B, and rd.15B models built with principal component regression (PCR) versus ordinary least squares (OLS), evaluated by Spearman correlation ($\rho$) across 37 datasets from the ProteinGym benchmark. PCR outperforms OLS for all three models, with overall win rates of 58.8%, 60.3%, and 63.2%, respectively. This suggests that isolating signal from noise in the lower-scale representations is critical for effective cross-scale distillation.
  • Figure 4: Ablation on the Matryoshka property. (Top) Comparison of Spearman correlation $\rho$ on ProteinGym datasets between Reverse Distillation and a PCA + concatenation baseline. Although this baseline provides a theoretical upper bound for linear subspace decomposition, it has a critical limitation. (Bottom) Unlike our method, the baseline relies on fresh projections at each scale and therefore cannot preserve structural consistency or scalability across the 650M, 3B, and 15B models.
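Figures 1 and 4 describe iterating the construction across a model family so that the embeddings nest. A minimal sketch of that iteration, under the assumption that each step regresses the next model's embeddings on the current reverse-distilled embedding and appends residual SVD directions, is shown below. Each output is fed back in as the prefix for the next larger model. The names `rd_step` and `reverse_distill_chain`, the toy widths (4, 8, 16) standing in for a family such as ESM-2, and the synthetic data are all hypothetical.

```python
import numpy as np

def rd_step(prefix, H_big, n_pcr=None):
    """Hypothetical step: keep `prefix`, append (d_big - d_prefix) residual directions."""
    Xc = prefix - prefix.mean(axis=0)
    Yc = H_big - H_big.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    P = Vt.T if n_pcr is None else Vt[:n_pcr].T      # OLS (all PCs) or PCR basis
    B, *_ = np.linalg.lstsq(Xc @ P, Yc, rcond=None)
    R = Yc - Xc @ P @ B                               # residual of the linear map
    _, _, Vt_res = np.linalg.svd(R, full_matrices=False)
    m = H_big.shape[1] - prefix.shape[1]
    return np.hstack([prefix, R @ Vt_res[:m].T])

def reverse_distill_chain(reps, n_pcr=None):
    """reps: per-model (L, d_i) embeddings ordered from smallest to largest.

    Each reverse-distilled embedding becomes the prefix for the next model,
    so outs[i] equals the first d_i columns of every later output.
    """
    outs = [reps[0]]
    for H in reps[1:]:
        outs.append(rd_step(outs[-1], H, n_pcr))
    return outs

# Toy "model family" sharing latent features, with growing widths.
rng = np.random.default_rng(1)
L, dims = 300, [4, 8, 16]
Z = rng.standard_normal((L, max(dims)))
reps = [Z[:, :d] @ rng.standard_normal((d, d)) + 0.05 * rng.standard_normal((L, d))
        for d in dims]

outs = reverse_distill_chain(reps, n_pcr=3)
final = outs[-1]
```

Because each step only appends columns, truncating `final` to its first 4 or 8 dimensions recovers the smaller-scale embeddings exactly. A baseline that recomputes its projection independently at each scale does not have this guarantee.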

Theorems & Definitions (3)

  • Definition 1: Reverse Distillation Decomposition
  • Theorem 1: Optimal Constrained Approximation
  • Proof