Training Compute-Optimal Protein Language Models

Xingyi Cheng, Bo Chen, Pan Li, Jing Gong, Jie Tang, Le Song

TL;DR

This work studies how to train protein language models compute-optimally, an area of significant interest in biological research where guidance on best practices is limited, and derives scaling laws for the Causal Language Model (CLM) and Masked Language Model (MLM) objectives on Transformers, tailored to the specific characteristics of protein sequence data.

Abstract

We explore optimally training protein language models, an area of significant interest in biological research where guidance on best practices is limited. Most models are trained with extensive compute resources until performance gains plateau, focusing primarily on increasing model size rather than on the efficient compute frontier that balances performance against compute budget. Our investigation is grounded in a massive dataset of 939 million protein sequences. We trained over 300 models, ranging from 3.5 million to 10.7 billion parameters, on 5 to 200 billion unique tokens to investigate the relationships among model size, number of training tokens, and training objective. First, we observed diminishing returns for the Causal Language Model (CLM) and overfitting for the Masked Language Model (MLM) when repeating the commonly used Uniref database. To address this, we included metagenomic protein sequences in the training set to increase diversity and avoid the plateau and overfitting effects. Second, we obtained scaling laws for CLM and MLM on Transformers, tailored to the specific characteristics of protein sequence data. Third, we observed a transfer scaling phenomenon from CLM to MLM, and further demonstrated the effectiveness of transfer through scaling behaviors based on estimated Effectively Transferred Tokens. Finally, to validate our scaling laws, we compared against the large-scale versions of ESM-2 and PROGEN2 on downstream tasks, encompassing evaluations of protein generation as well as structure- and function-related tasks, all within equal or smaller pre-training compute budgets.

Paper Structure

This paper contains 30 sections, 12 equations, 15 figures, 10 tables.

Figures (15)

  • Figure 1: Learning curves for UR50/S and UniMeta200B. Training loss, validation PPL, and OOD test PPL were tracked over 200 billion training tokens for both the 150M and 3B models. As we scaled the model from 150M to 3B, we observed diminishing returns on CLM (first row) and a tendency to overfit on MLM (second row) when repeating the Uniref50 (UR50/S) dataset. In total, we evaluated three repetition methods on MLM 3B models, all of which exhibit overfitting (see the appendix on repetition).
  • Figure 2: IsoFLOPs curves and parametric fit for CLM and MLM. We selected training tokens to ensure a uniform final FLOP count for different model sizes. The lowest loss of each curve revealed an optimal model size for a FLOP budget (above). We use these rainbow points at the valley to plot the efficient frontier for estimating the optimal model size and training tokens for scaling models (below). The interval range was estimated by model points with similar loss.
  • Figure 3: Compute allocation for two objectives with the same model size.
  • Figure 4: Left: The upper graph compares the validation loss of CLM trained from scratch with that of CLM transferred from MLM, showing diminishing transfer benefits as model size increases. The lower graph shows that MLM benefits more from a pre-trained CLM as model size grows, indicating scale-dependent efficiency gains. Right: Loss curves for CLM and MLM across different FLOPs, emphasizing the efficient frontiers (or Pareto frontiers) of the various transfer strategies. The benefits of transferring from CLM to MLM grow with model size, reflecting a scale-dependent synergy between training objectives.
  • Figure 5: Left: Validation perplexity as a function of the % of compute allocated to CLM pre-training. A given % compute means first training with CLM on that share of the budget and then spending the remaining compute on MLM fine-tuning. The optimal CLM pre-training share lies in the range [10, 20]%, and the fitted $D_t / (D_t + D_f)$ falls within the optimal loss range. Right: Comparison of validation perplexity for models trained from scratch (red) and those fine-tuned from a pre-trained CLM (green), demonstrating that fine-tuning from a CLM reduces perplexity with similar or even fewer tokens.
  • ...and 10 more figures
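
The IsoFLOPs procedure described in the Figure 2 caption (pick training tokens so every model size lands on the same FLOP budget, find the loss valley per budget, then fit a frontier through the valleys) can be sketched in a few lines. This is a minimal illustration, not the authors' code. It assumes the common approximation C ≈ 6·N·D for training FLOPs, a quadratic fit in log10(N) to locate each valley, and a power law N_opt = k·C^alpha for the frontier; the function names and all numbers below are hypothetical and synthetic.

```python
import numpy as np

def tokens_for_budget(flops, n_params):
    # Assumption: training FLOPs C ~= 6 * N * D, so D = C / (6 * N).
    return flops / (6.0 * n_params)

def isoflop_optimum(param_counts, losses):
    # Fit loss ~= a*x^2 + b*x + c with x = log10(N); the valley is at x = -b / (2a).
    x = np.log10(np.asarray(param_counts, dtype=float))
    y = np.asarray(losses, dtype=float)
    a, b, c = np.polyfit(x, y, 2)
    if a <= 0:
        raise ValueError("fitted curve has no interior minimum")
    x_opt = -b / (2.0 * a)
    return 10.0 ** x_opt, float(np.polyval([a, b, c], x_opt))

def fit_power_law(flops, optimal_sizes):
    # Fit N_opt = k * C^alpha by linear regression in log-log space.
    alpha, log_k = np.polyfit(np.log10(flops), np.log10(optimal_sizes), 1)
    return 10.0 ** log_k, float(alpha)

# Synthetic IsoFLOP curves (made-up data): the "true" optimum is 0.1 * C^0.5
# and the loss is an exact parabola in log10(N) around that optimum.
budgets = np.array([1e18, 1e19, 1e20])
valleys = []
for c_budget in budgets:
    true_opt = 0.1 * c_budget ** 0.5
    sizes = true_opt * np.array([0.25, 0.5, 1.0, 2.0, 4.0])
    losses = 2.0 + 0.05 * (np.log10(sizes) - np.log10(true_opt)) ** 2
    n_opt, _ = isoflop_optimum(sizes, losses)
    valleys.append(n_opt)

# Frontier through the valley points: recovers the exponent used to build the data.
k, alpha = fit_power_law(budgets, np.array(valleys))
```

With real runs, `sizes` and `losses` would come from the trained models at each FLOP budget, and the fitted `k` and `alpha` would give the model size to target for a new budget; the matching token count then follows from `tokens_for_budget`.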