Fine-tuning protein language models boosts predictions across diverse tasks

Fine-tuning is mostly successful

We trained 615 individual prediction methods (295 for fine-tuning, 320 using frozen embeddings from pre-trained pLMs, protein Language Models) comprising eight models (Table 1), each trained on eight different data sets (Table 2). We trained each model-task combination multiple times with different random seeds; all reported results are averages over those runs, and the corresponding validation set selected the best training stop. For each prediction task (PT), we compared the performance between fine-tuning and pre-training (1). For ProtT5-XL-U5014 (labeled ProtT5) and all five tested ESM2 versions16 (differing in parameter size between 8 M (8*10^6) and 3B (3*10^9)), not all improvements were statistically significant within the 95% confidence interval (CI: Methods). Nevertheless, supervised fine-tuning numerically increased performance for almost all combinations (Fig. 1, detailed results in supplementary online material (SOM) Tables S1–S6). The exceptions were ESM2-150M applied to stability prediction, and both Ankh15 models. Ankh gained significantly by fine-tuning only for the mutational landscape data (GFP, AAV, and GB1: blue in Fig. 1).

Table 2 Task-specific datasets*

Fig. 1: Fine-tuning improved for most pLMs and tasks. Asterisks (*) mark fully fine-tuned models; the others were LoRA-optimized (SOM Fig. S1). Values reflect percentage differences between the fine-tuned and pre-trained models (1) for the eight prediction tasks (x-axis). We had to use different performance measures, namely the Spearman rank correlation (GFP, AAV, GB1, stability, meltome, and disorder), 10-class accuracy (Q10: sub-cellular location), and 3-class per-residue accuracy (Q3: secondary structure). Each tile compares fine-tuning to raw embeddings for one task. Blue tiles mark statistically significant increases (>1.96 standard errors; fine-tuning better), yellow tiles mark statistically insignificant changes (0 lies within the error margins of ±1.96 stderr), and for red tiles supervised fine-tuning significantly decreased performance. Error estimates (±percentage values) represent the 95% confidence intervals (CI, Methods). Source data are provided as a Source Data file.

For the mutational landscape data, performance relied less on transfer from model pre-training (Fig. S7) and mainly depended on the underlying transformer architecture. This might explain why Ankh performed similarly to ProtT5 and the ESM2 models on these tasks, but not on the diverse data sets. Two major factors differentiate Ankh from the other pLMs. Firstly, the T549 masked-span pre-training differs from the BERT-like50 objective used for the other models. Secondly, the training procedure and architecture of Ankh were optimized using data (GFP, GB1, subcellular location, and secondary structure) also utilized in this work15. This might have reduced the ability to fine-tune these models.

For five of the 64 pLM/task combinations (tiles in Fig. 1), fine-tuning performed worse. The drop for ESM2-150M on stability (Fig. 1 red tile) originated from instability in training, which led to picking a suboptimal model (Fig. S5). The other four originated from the Ankh pLM family on disorder and secondary structure. We were not able to track down a root cause here but suspect that the different nature of the pre-training plays a role.
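The comparison underlying Fig. 1 boils down to a percentage difference between fine-tuned and pre-trained performance and a check of whether that difference exceeds 1.96 standard errors. The following minimal sketch illustrates this criterion on toy numbers; it is not the exact Methods implementation, and all values are placeholders.

```python
import numpy as np

def pct_difference(finetuned, pretrained):
    """Percentage difference between fine-tuned and pre-trained performance
    (toy re-statement of the comparison plotted in Fig. 1)."""
    return 100.0 * (finetuned - pretrained) / pretrained

def classify(diffs):
    """Colour a Fig. 1-style tile: significant if 0 lies outside mean +/- 1.96 stderr."""
    diffs = np.asarray(diffs, dtype=float)
    mean = diffs.mean()
    stderr = diffs.std(ddof=1) / np.sqrt(len(diffs))
    if mean - 1.96 * stderr > 0:
        return "blue (significant gain)"
    if mean + 1.96 * stderr < 0:
        return "red (significant loss)"
    return "yellow (insignificant)"

# toy example: percentage differences from three re-runs with different random seeds
runs = [pct_difference(0.74, 0.72), pct_difference(0.75, 0.72), pct_difference(0.73, 0.72)]
print(classify(runs))
```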
LoRA was competitive with alternative PEFT methods

For ProtT5 and sub-cellular location prediction, we compared three parameter-efficient fine-tuning methods to LoRA46. Not having sufficient resources to do this analysis for all prediction tasks/pLMs, we chose this problem because of its limited size and because of the success of fine-tuning on it (configuration in Methods and Fig. 2). The fraction of trained model parameters was 0.25% for LoRA, 0.28% for DoRA51, 0.12% for IA352, and 0.5% for Prefix tuning53. Despite these differences, runtimes for training and testing (inference) were within ±10% between methods, except for DoRA, which was about 30% slower than the other three. In terms of prediction performance, LoRA and DoRA outperformed IA3 and Prefix-tuning (Fig. 2). Overall, all fine-tuning methods improved, on average, over pre-trained embeddings (61.3% from Table S5). As no method improved significantly over the well-established LoRA, we used it throughout our experiments. Of course, these results for a single model and dataset need not hold true in general. We encourage exploring parameter-efficient fine-tuning of pLMs with new combinations of high-quality datasets, state-of-the-art models, and PEFT methods in future work, and we hope the notebooks we made available help to pursue this research more easily.

Fig. 2: Comparison of different PEFT methods. ProtT5 model assessed on the prediction of sub-cellular location (x-axis: 10-state per-protein accuracy Q10). Mean values as well as 95% confidence intervals are computed from three training re-runs for each of the four PEFT methods: LoRA46, DoRA51, IA352, and Prefix-tuning53. We used the same configuration for LoRA and DoRA. The IA3 target modules were the key, value, and feed-forward layers. Prefix-tuning used 20 virtual tokens with 1024 dimensions to fit the ProtT5 dimensions. Circles represent individual training results. Differences between methods are mostly insignificant, with all four numerically outperforming the pre-trained embedding predictor on average (dashed grey line). Source data are provided as a Source Data file.

Insignificant gain for secondary structure prediction

For per-residue, three-class secondary structure prediction (helix, strand, other), fine-tuning improved only slightly (Fig. 3a; up to 1.2 percentage points for CASP1254 and NEW36414). We confirmed this for the general-purpose ProtT514 and the bilingual, structure-tuned ProstT536. Two effects might have hindered substantial improvement. Firstly, secondary structure might already have been captured during unsupervised pre-training; in fact, embeddings already capture some aspects of inter-residue contact formation10,20. Secondly, performance may have reached an upper limit55. One limitation of the benchmark is highlighted by the two data sets (CASP1254 and NEW36414). Both were introduced to gauge the performance for unknown proteins. Other than CASP12 being much smaller (12 proteins vs. 364), which implies higher statistical errors, there seems to be no a priori reason for choosing one over the other, and no method compared here is expected to have any systematic bias toward either of the two. Thus, the difference between both should estimate the statistical error; in other words, bootstrapping error estimates should be similar to the difference between the two sets. This was not the case at all (Fig. 3a: differences between CASP12 and NEW364 exceeded the standard errors marked by the distributions). Arguably, secondary structure prediction has been the best-solved task in protein structure prediction for decades33,55. Even for this relatively trivial problem, such a simple dichotomy seems not easily resolvable. In fact, the standard mantra, larger data sets without redundancy, appears not to resolve it either. These results underscore how difficult it is to simply plug in standard data sets to assess the performance of prediction methods without updating the data and adapting it to advancing methods.
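To make the consistency check described above concrete: if two test sets were interchangeable samples of the same problem, the difference between their mean Q3 scores should lie within the bootstrap error of either set. The sketch below runs this check on synthetic per-protein Q3 values; all numbers are invented for illustration, and this is not the evaluation code used here.

```python
import numpy as np

rng = np.random.default_rng(0)

def bootstrap_ci(per_protein_q3, n_boot=1000):
    """95% CI of mean Q3 via bootstrapping over proteins (illustrative only)."""
    per_protein_q3 = np.asarray(per_protein_q3, dtype=float)
    means = [rng.choice(per_protein_q3, size=len(per_protein_q3), replace=True).mean()
             for _ in range(n_boot)]
    return np.percentile(means, [2.5, 97.5])

# hypothetical per-protein Q3 values for two test sets
casp12_q3 = rng.normal(0.84, 0.05, size=12)    # small set -> wide CI
new364_q3 = rng.normal(0.86, 0.05, size=364)   # larger set -> narrow CI

lo, hi = bootstrap_ci(casp12_q3)
gap = abs(casp12_q3.mean() - new364_q3.mean())
# if the two sets were interchangeable, the gap should fall within the bootstrap error
print(f"CASP12 95% CI: [{lo:.3f}, {hi:.3f}]  observed CASP12-NEW364 gap: {gap:.3f}")
```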
Fig. 3: Disorder prediction better, secondary structure prediction not. Mean values and 95% confidence intervals (CI) were estimated through bootstrapping (n = 10 for a, n = 25 for b); violin plots reflect the data distribution. Source data are provided as a Source Data file. a Values for the pre-trained models (ProtT514 and ProstT536) were taken from the literature36 (no CI available for CASP12) and are marked by asterisks (*); fine-tuning in green, pre-trained embeddings in orange. We included two previously used data sets (CASP1254 and NEW36414) to highlight the limitation of benchmarks. b Intrinsically disordered residues can be proxied by CheZOD scores56. The x-axis shows the Spearman correlation between experimental and predicted CheZOD scores for six methods. Values marked by asterisks (*) were taken from the literature19. Fine-tuning results in green, pLM-based without MSA (SETH19) in orange, MSA-based SOTA in gray56,72, and MSA-based AlphaFold276 in blue.

Fine-tuning boosted disorder prediction

The pLM-based SETH19 reached the level of MSA-based SOTA methods, such as ODiNPred56, in the prediction of per-residue protein disorder as described by CheZOD scores56. SETH inputs ProtT5 embeddings into a two-layer CNN. Keeping those hyper-parameters and adding LoRA fine-tuning (dubbed SETH-LoRA) improved performance by 2.2% (from Spearman 0.72 to 0.736, Fig. 3b). Fine-tuning the much smaller 150M-parameter ESM2 model (Spearman: 0.742) improved over all solutions compared (Fig. 3b), including its larger counterparts (ESM2 with 650M/3B parameters, Table S4). While SETH-LoRA trains only 2.5 million of its 1.2 billion parameters, all parameters of ESM2-150M were fine-tuned. Both approaches (2.5 M trained parameters for ProtT5 vs. 150 M for ESM2) performed similarly (Fig. 3b).

LoRA topped pooling for subcellular location

Most methods predicting subcellular location input signals averaged over entire proteins (e.g., amino acid composition). Embedding-based solutions do this through pooling, i.e., through embeddings derived by averaging over all residue-level embeddings14. Light Attention (LA) substantially improves over such coarse-grained averaging by learning the optimal per-residue signal and combining this with the average34. LoRA fine-tuning combined the advantage of a small model (fewer free parameters) with the learned, weighted averaging of LA. Thereby, LoRA fine-tuning numerically surpassed LA, although the difference was statistically significant only at an 88% confidence interval (CI), not at the more common 95% CI (Table S9).

Fine-tuning better captured effects of mutations

For predicting mutational landscapes (Fig. 1, leftmost three columns), fine-tuning any pLM succeeded substantially. As differences between fine-tuned models were small (Fig. S3), we averaged performance across all fine-tuned pLMs (Fig. 4; for individual values refer to Table S2) and compared to homology-based inference (HBI, using MMseqs257 search) and to reference-free analysis (RFA58). RFA fits a decent first-order model of the fitness landscape, reflecting some mutations, for GB1 (protein G domain B159; all possible variants at four residue positions). For AAV260 (adeno-associated virus 2), for which a much larger 28-residue window was mutated, RFA performed less well. For GFP (green fluorescent protein61), the RFA analyses failed because some specific substitutions XnY (amino acid X at position n mutated to Y) occurred only in the test set. The fact that smaller and larger models performed alike on these tasks raised the prospect of using small, fine-tuned pLMs as computationally affordable, high-quality solutions for protein engineering.
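For illustration, a first-order model in the spirit of RFA treats fitness as a sum of independent per-position amino-acid effects. The sketch below fits such an additive model by least squares on a handful of invented GB1-like variants; it is not the published RFA implementation, and all sequences and fitness values are hypothetical.

```python
import numpy as np

AA = "ACDEFGHIKLMNPQRSTVWY"
POSITIONS = 4  # e.g., the four mutated GB1 positions

def one_hot(variant):
    """Encode a k-mer of mutated positions as position-wise one-hot features."""
    x = np.zeros(POSITIONS * len(AA))
    for i, aa in enumerate(variant):
        x[i * len(AA) + AA.index(aa)] = 1.0
    return x

# hypothetical training data: (variant at the mutated positions, measured fitness)
train = [("VDGV", 1.0), ("ADGV", 0.7), ("VAGV", 0.4), ("VDAV", 0.9), ("VDGA", 0.2)]
X = np.stack([one_hot(v) for v, _ in train])
y = np.array([f for _, f in train])

# least-squares fit of additive per-position effects (first-order model)
coef, *_ = np.linalg.lstsq(X, y, rcond=None)

# predicted fitness of an unseen variant = sum of its single-position effects
print(one_hot("AAGV") @ coef)
```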
Fig. 4: Simple methods limited for mutational effects. Blue: average performance across all fine-tuned pLMs (mean values with 95% CI, n = 24), with violin plots providing the underlying distribution. Gray: two simple baseline methods. Homology (HBI): MMseqs255 inferred the fitness value of test-set proteins from the most sequence-similar homolog in the training set. RFA (reference-free analysis56) fits models based on the assumption that most of the mutational effects can be described as the sum of low-order effects. Source data are provided as a Source Data file.

LoRA was substantially faster for larger models

The main drivers of the computational resources required for model training were the parameter sizes of the pLMs along with the quadratic scaling of the attention mechanism (more resources for longer proteins). More recent GPUs used for LLM training (anything beyond 40GB of memory) have sufficient memory for all pLMs tested here. For less powerful hardware (Fig. 5b), mixed precision training nearly halved the required GPU memory without performance loss (both Ankh models were the exception, as they do not support mixed precision training). Where GPU memory remained a bottleneck, we applied gradient accumulation to reduce the actual on-device batch size as far as needed. When even an on-device batch size of 1 was insufficient, we used DeepSpeed to offload the optimizer and, if necessary, model parameters to the CPU, reducing GPU memory requirements further. As a trade-off, both gradient accumulation and CPU offloading slowed down training; hence, both should be used cautiously. Implementing all these measures, we could fine-tune most pLMs tested here even on older GPUs with as little as 8GB of memory (Fig. 5b). Unintuitively, when CPU offloading was utilized, full model fine-tuning and parameter-efficient LoRA fine-tuning required the same amount of GPU memory and only differed in training speed (Fig. 5a). Embedding creation required much less GPU memory, rendering it feasible even for datasets with very long sequences (Fig. S1).
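For illustration, the measures above (mixed precision, gradient accumulation with a small on-device batch size, and DeepSpeed CPU offloading) can be combined through the Hugging Face Trainer. The sketch below is one possible setup rather than the exact configuration used here; output paths, batch sizes, and accumulation steps are placeholders, and the deepspeed and accelerate packages are assumed to be installed.

```python
import json
from transformers import TrainingArguments

# DeepSpeed ZeRO stage 2 with the optimizer offloaded to CPU (slower, but fits small GPUs)
ds_config = {
    "zero_optimization": {"stage": 2, "offload_optimizer": {"device": "cpu"}},
    "fp16": {"enabled": True},
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 16,
}
with open("ds_config.json", "w") as fh:
    json.dump(ds_config, fh)

args = TrainingArguments(
    output_dir="finetune_out",
    fp16=True,                       # mixed precision: roughly halves GPU memory
    per_device_train_batch_size=1,   # on-device batch size of 1 ...
    gradient_accumulation_steps=16,  # ... accumulated to an effective batch of 16
    deepspeed="ds_config.json",      # CPU offloading only if memory is still insufficient
)
```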
Fig. 5: Fine-tuning training speed and GPU requirements. a Relative training speed of full fine-tuning (blue) and LoRA (red) is shown on a logarithmic scale; ProtT5 LoRA fine-tuning served as the reference speed with a value of 1 (x). The resulting speed-up for each model (olive) is shown on a linear scale. Experiments were performed with arbitrary sequences of length 1024 in a per-protein setting. For the smallest model (ESM2 8M), LoRA fine-tuning was marginally slower than training the entire model. The larger the model, the more advantageous LoRA became. For the largest model (ESM2 3B), LoRA was about 4.5-fold faster. b Maximum sequence length before the GPU runs out of memory (for 8, 16, and 24GB GPUs). All values were obtained with memory-efficient training (mixed precision training, gradient accumulation with on-device batch size 1, and DeepSpeed CPU offloading). Experiments were done for per-protein predictions, but memory requirements for per-residue training will be similar. Results are valid for full-model and LoRA fine-tuning. Source data are provided as a Source Data file.

Fine-tuning recipe

To simplify fine-tuning pLMs on your own data set, we offer the following recommendations. Before starting model training, dataset splits that measure model generalization and prevent over-estimating performance23,28 are essential. First off, you need at least three data sets: training (optimizing weights), cross-training/validation (optimizing hyper-parameters, e.g., to decide between CNN and ANN), and testing (only touched to estimate performance). Typically, all entities in the test set (e.g., proteins) should adhere to the same split criteria required between training/validation and testing. In particular, proteins have to be made non-redundant. This requires clustering by sequence identity using standard alignment methods such as MMseqs257 (simpler solutions are more likely to lead to information leakage). For structure-related tasks, redundancy is best removed through 3D clustering as realized by Foldseek62. To optimize the prediction of mutational landscapes for a single protein, it might be best to train on k-mers with k = 1 (single amino acid variants) and test on k-mers with k > 122,23 (although this approach might focus more on avoiding over-fitting than on obtaining the best model).

To predict landscapes of mutational effects for specific proteins, a challenge encountered in protein engineering, we recommend first fine-tuning a smaller pLM (pre-trained embeddings were limited: Fig. S3) and optimizing hyperparameters and head architectures on this smaller model. Once done, you could explore additional improvements from larger pLMs. For fine-tuning on diverse tasks, larger models mostly outperformed smaller ones (Figs. S3 & S4). Therefore, starting with raw embedding-based solutions to identify the best model, and then investigating different prediction heads, appeared better than optimizing the fine-tuning directly. Applying parameter-efficient LoRA fine-tuning and optimizing hyperparameters for the selected model afterward will probably lead to an even better solution.

For our tasks, over-fitting mostly originated from data set characteristics. On the one hand, given a data set prone to over-fitting (e.g., too small, uninformative, or complex), neither hyperparameter nor model optimization could fully avoid the trap. On the other hand, for data sets not prone to over-fitting, fine-tuning trained stably regardless of other factors. These factors affected raw embedding-based and fine-tuned models alike. Avoiding imbalanced datasets, providing sufficient high-quality training data, and choosing smaller models for limited data sets could mitigate over-fitting (SOM Sections 9 and 10).

On the computational side, we recommend mixed precision training; gradient accumulation and DeepSpeed's CPU-offloading should be reserved for cases where memory remains a constraint. With all these measures in place, a single 16 GB GPU enables fine-tuning in many cases (Fig. 5). Comparing different common PEFT methods (Fig. 2) did not suggest a clear winner. The established LoRA46 method was among the best solutions and was stable across our experiments. The codebase we provide simplifies experimenting with different PEFT approaches, as it utilizes the Hugging Face PEFT63 framework. We encourage you to compare different PEFT methods for your specific use cases.
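As a starting point for such comparisons, the following sketch attaches a PEFT configuration to a small pLM via the Hugging Face PEFT framework. The checkpoint, rank, and target-module names are illustrative assumptions, not the exact settings of our experiments.

```python
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, IA3Config, TaskType, get_peft_model

# a small ESM2 checkpoint with a 10-class head (e.g., sub-cellular location)
base = AutoModelForSequenceClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D", num_labels=10
)

lora_cfg = LoraConfig(
    task_type=TaskType.SEQ_CLS, r=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=["query", "key", "value"],               # attention projections
)
ia3_cfg = IA3Config(
    task_type=TaskType.SEQ_CLS,
    target_modules=["key", "value", "intermediate.dense"],  # key, value, feed-forward
    feedforward_modules=["intermediate.dense"],
)

# pick one configuration; the rest of the training loop stays unchanged
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()  # typically well below 1% trainable parameters
```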
PEFT is memory-efficient, but CPU-offloading could achieve the same. However, PEFT is also the most compute-efficient option for larger pLMs (Fig. 5a); it stabilizes training (SOM Section 5), and it renders saving model checkpoints orders of magnitude more efficient, as only the trained parameters need to be stored. Thus, we recommend LoRA fine-tuning for all models larger than ESM2 150M (Fig. S2). We see little reason not to fully fine-tune smaller models.

In our hands, different random initialization seeds significantly altered results. These random variations reached the magnitude of the effects of hyperparameter or even model selection.

Concluding thoughts

We applied fine-tuning38,46 to a diversity of prediction tasks, clearly showing improvements on average. The extent of this improvement varied by task and pLM and was impacted by the amount of training data (Fig. S10), dataset balance (Table S8), model size (Fig. S3), and initial representation quality (Fig. S4).

Overall, our results revealed that the gains initially observed in NLP from supervised task-specific fine-tuning of LLMs39,40 also apply to large protein LMs (pLMs). Supervised fine-tuning unlocks additional degrees of freedom in the predictor models. The last hidden layer has been optimized for the unsupervised pre-training objective (learning to reproduce masked sequences); this optimization might be suboptimal for any downstream task41,42. PEFT (or fine-tuning in general) enables information from middle layers to flow to the last layer, making it accessible to downstream tasks. Additionally, for per-protein predictions, the LoRA optimization may have learned weighted pooling of the last hidden layer, which improved significantly over average pooling34. Lastly, the transformer models might extract additional information directly from the task-specific training. Randomly initialized smaller ESM2 models supported this view (Fig. S7, Table S17).

Therefore, we suggest adding supervised fine-tuning whenever applying transfer learning, i.e., when inputting pLM embeddings into subsequent supervised prediction tasks. Our results suggested that you will most often benefit from this. To ease this additional step, we provided all our resources and added step-by-step recommendations.
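To make the pooling argument concrete, the toy sketch below contrasts plain average pooling of per-residue embeddings with a learned, weighted pooling in the spirit of Light Attention; the single linear scorer and the dimensions are simplifications, not the published LA architecture or our exact prediction heads.

```python
import torch
import torch.nn as nn

class WeightedPool(nn.Module):
    """Learn per-residue weights and pool embeddings as a weighted average."""
    def __init__(self, dim=1024):
        super().__init__()
        self.scorer = nn.Linear(dim, 1)  # one attention score per residue

    def forward(self, emb, mask):
        # emb: [batch, length, dim]; mask: [batch, length] with 1 for real residues
        scores = self.scorer(emb).squeeze(-1).masked_fill(mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1).unsqueeze(-1)
        return (weights * emb).sum(dim=1)

emb = torch.randn(2, 100, 1024)  # per-residue embeddings for two toy proteins
mask = torch.ones(2, 100)
mean_pooled = (emb * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)
weighted_pooled = WeightedPool()(emb, mask)
print(mean_pooled.shape, weighted_pooled.shape)  # both [2, 1024]
```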
