Adding ProtBERT (deepchem#3985)
* Added sample protein classification dataset

* Modified feature name from smiles to protein

* Added ProtBERT and unit tests

* Added ProtBERT to __init__.py

* Moved ProtBERT into HF dependency

* Fixed formatting

* Added docstring and type annotations

* Fixed formatting

* Added sample protein classification dataset and overfit test case for ProtBERT

* Removed comments

* Fixed formatting

* Added type annotations to new variables

* Added ProtBERT to docs

* Added detailed model description
Shiva-sankaran authored Jun 19, 2024
1 parent d3491a0 commit fe3b2f5
Showing 6 changed files with 288 additions and 0 deletions.
11 changes: 11 additions & 0 deletions deepchem/models/tests/assets/example_protein_classification.csv
@@ -0,0 +1,11 @@
Compound ID,outcome,protein
Q9H400,0,M G L P V S W A P P A L W V L G C C A L L L S L W A L C T
Q86XT9,0,M E V L E E P A P G P G G A D A A E R R G L R R L L L S
P83456,0,M M K T L S S G N C T L N V P A K N S Y R M V V L G A
P33527,0,M D P S K Q G T L N R V E N S V Y R T A F K L R S V Q T L C Q L D L M
Q8VCH2,0,M E A C S S K T S L L L H S P L R T I P K L R
Q95QW4,0,M S D Y F T F P K Q E N G G I S K Q P A T P G S T R S S S R N L
O97148,0,M S G E D G P A A G P G A A A
Q8WNR0,1,M A A F T G T T D K C K A C D K T V Y V M D L M T L E G M
P43144,1,M F S G K V R A F I D E E L F H S N R N N S S D G L S L D T
Q8JZM4,1,M K R G I R R D P F R K R K L G
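Note that the protein column stores each sequence with spaces between residues because ProtBERT's vocabulary treats each amino acid as a single token. A minimal sketch of how one such row is tokenized (assuming the transformers package and the Rostlab/prot_bert checkpoint are available):

from transformers import BertTokenizer

# ProtBERT expects space-separated amino acids, exactly as stored in the CSV above.
tokenizer = BertTokenizer.from_pretrained('Rostlab/prot_bert', do_lower_case=False)
encoded = tokenizer("M G L P V S W A P P A L W V L G C C A L L L S L W A L C T")
print(encoded['input_ids'])  # a [CLS] id, one id per residue, then a [SEP] id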
2 changes: 2 additions & 0 deletions deepchem/models/torch_models/__init__.py
@@ -45,5 +45,7 @@
try:
    from deepchem.models.torch_models.hf_models import HuggingFaceModel
    from deepchem.models.torch_models.chemberta import Chemberta
    from deepchem.models.torch_models.prot_bert import ProtBERT
except ModuleNotFoundError as e:
    logger.warning(f'Skipped loading modules with transformers dependency. {e}')
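With this registration, ProtBERT is importable from the package namespace whenever the optional transformers dependency is present; otherwise the import is skipped with the warning above. A quick sanity check (a sketch, assuming transformers is installed):

import deepchem as dc  # noqa: F401
from deepchem.models.torch_models import ProtBERT

print(ProtBERT.__bases__[0].__name__)  # HuggingFaceModel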
115 changes: 115 additions & 0 deletions deepchem/models/torch_models/prot_bert.py
@@ -0,0 +1,115 @@
from transformers import BertForMaskedLM, BertForSequenceClassification, BertTokenizer, BertConfig
import torch.nn as nn
from deepchem.models.torch_models.hf_models import HuggingFaceModel
from typing import Union


class ProtBERT(HuggingFaceModel):
    """
    ProtBERT model[1].

    The ProtBERT model is based on the BERT architecture, and the current
    implementation supports only MLM pretraining and classification, as
    described by the authors on HuggingFace[2]. For classification, we
    currently support only logistic regression and a simple feed-forward
    neural network.

    The model converts the input protein sequence into a vector through a
    trained BERT tokenizer, which is then processed by the corresponding model
    based on the task. `BertForMaskedLM` is used for the MLM pretraining task.
    For the sequence classification task, we follow
    `BertForSequenceClassification` but replace the classifier with either a
    logistic regression (LogReg) or a feed-forward neural network (FFN),
    depending on the specified `cls_name`. The FFN is a simple 2-layer network
    with a hidden dimension of 512.

    Examples
    --------
    >>> import os
    >>> import tempfile
    >>> tempdir = tempfile.mkdtemp()
    >>> # preparing dataset
    >>> import pandas as pd
    >>> import deepchem as dc
    >>> protein = ["MPCTTYLPLLLLLFLLPPPSVQSKV", "SSGLFWMELLTQFVLTWPLVVIAFL"]
    >>> labels = [0, 1]
    >>> df = pd.DataFrame(list(zip(protein, labels)), columns=["protein", "task1"])
    >>> with dc.utils.UniversalNamedTemporaryFile(mode='w') as tmpfile:
    ...     df.to_csv(tmpfile.name)
    ...     loader = dc.data.CSVLoader(["task1"], feature_field="protein", featurizer=dc.feat.DummyFeaturizer())
    ...     dataset = loader.create_dataset(tmpfile.name)
    >>> # pretraining
    >>> from deepchem.models.torch_models.prot_bert import ProtBERT
    >>> pretrain_model_dir = os.path.join(tempdir, 'pretrain-model')
    >>> model_path = 'Rostlab/prot_bert'
    >>> pretrain_model = ProtBERT(task='mlm', model_path=model_path, n_tasks=1, model_dir=pretrain_model_dir)  # MLM pretraining
    >>> pretraining_loss = pretrain_model.fit(dataset, nb_epoch=1)
    >>> del pretrain_model
    >>> # finetuning
    >>> finetune_model_dir = os.path.join(tempdir, 'finetune-model')
    >>> finetune_model = ProtBERT(task='classification', model_path=model_path, n_tasks=1, model_dir=finetune_model_dir)
    >>> finetune_model.load_from_pretrained(pretrain_model_dir)
    >>> finetuning_loss = finetune_model.fit(dataset, nb_epoch=1)
    >>> # prediction and evaluation
    >>> result = finetune_model.predict(dataset)
    >>> eval_results = finetune_model.evaluate(dataset, metrics=dc.metrics.Metric(dc.metrics.accuracy_score))

    References
    ----------
    .. [1] Elnaggar, Ahmed, et al. "ProtTrans: Toward Understanding the Language of Life Through Self-Supervised Learning." IEEE Transactions on Pattern Analysis and Machine Intelligence 44.10 (2021): 7112-7127.
    .. [2] https://huggingface.co/Rostlab/prot_bert
    """

    def __init__(self,
                 task: str,
                 model_path: str = 'Rostlab/prot_bert',
                 n_tasks: int = 1,
                 cls_name: str = "LogReg",
                 **kwargs) -> None:
        """
        Parameters
        ----------
        task: str
            The task defines the type of learning task in the model. The supported tasks are
            - `mlm` - masked language modeling commonly used in pretraining
            - `classification` - use it for classification tasks
        model_path: str
            Path to the HuggingFace model
        n_tasks: int, default 1
            Number of prediction targets for a multitask learning model
        cls_name: str
            The classifier head to use for classification mode. Currently only "LogReg" and "FFN" are supported.
        """
        self.n_tasks: int = n_tasks
        tokenizer: BertTokenizer = BertTokenizer.from_pretrained(
            model_path, do_lower_case=False)
        protbert_config: BertConfig = BertConfig.from_pretrained(
            pretrained_model_name_or_path=model_path,
            vocab_size=tokenizer.vocab_size)
        model: Union[BertForMaskedLM, BertForSequenceClassification]
        if task == "mlm":
            model = BertForMaskedLM.from_pretrained(model_path)
        elif task == "classification":
            cls_head: Union[nn.Linear, nn.Sequential]
            if n_tasks == 1:
                protbert_config.problem_type = 'single_label_classification'
            else:
                protbert_config.problem_type = 'multi_label_classification'

            # ProtBERT's hidden size is 1024; both heads map the pooled
            # embedding to n_tasks + 1 logits.
            if cls_name == "LogReg":
                cls_head = nn.Linear(1024, n_tasks + 1)
            elif cls_name == "FFN":
                cls_head = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(),
                                         nn.Linear(512, n_tasks + 1))
            else:
                raise ValueError('Invalid classifier: {}.'.format(cls_name))

            model = BertForSequenceClassification.from_pretrained(
                model_path, config=protbert_config)
            model.classifier = cls_head
        else:
            raise ValueError('Invalid task specification: {}'.format(task))
        super().__init__(model=model, task=task, tokenizer=tokenizer, **kwargs)
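With n_tasks = 1, both classifier heads output n_tasks + 1 = 2 logits, which is why the tests below expect predictions of shape (N, 2). A standalone sketch of the FFN head's shape behavior, using hypothetical tensors rather than DeepChem code:

import torch
import torch.nn as nn

n_tasks = 1
# Same FFN head as above: 1024-dim pooled ProtBERT output -> 512 hidden -> n_tasks + 1 logits.
cls_head = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, n_tasks + 1))
pooled = torch.randn(4, 1024)  # stand-in for pooled [CLS] embeddings of 4 sequences
print(cls_head(pooled).shape)  # torch.Size([4, 2])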
18 changes: 18 additions & 0 deletions deepchem/models/torch_models/tests/conftest.py
@@ -46,3 +46,21 @@ def smiles_multitask_regression_dataset():
featurizer=dc.feat.DummyFeaturizer())
dataset = loader.create_dataset(input_file)
return dataset


@pytest.fixture
def protein_classification_dataset(tmpdir):
    protein = [
        "MGLPVSWAPPALWVLGCCALLLSLWA",
        "MEVLEEPAPGPGGADAAERRGLRRL",
    ]
    labels = [0, 1]
    df = pd.DataFrame(list(zip(protein, labels)), columns=["protein", "task1"])
    filepath = os.path.join(tmpdir, 'protein.csv')
    df.to_csv(filepath)

    loader = dc.data.CSVLoader(["task1"],
                               feature_field="protein",
                               featurizer=dc.feat.DummyFeaturizer())
    dataset = loader.create_dataset(filepath)
    return dataset
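Pytest injects this fixture by name: any test that declares a protein_classification_dataset argument receives the dataset built here. For example (a sketch):

def test_fixture_shape(protein_classification_dataset):
    # two sequences, one binary task
    assert protein_classification_dataset.y.shape == (2, 1)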
136 changes: 136 additions & 0 deletions deepchem/models/torch_models/tests/test_prot_bert.py
@@ -0,0 +1,136 @@
import os

import deepchem as dc
import pytest

try:
    import torch
    from deepchem.models.torch_models.prot_bert import ProtBERT
except ModuleNotFoundError:
    pass


@pytest.mark.torch
def test_prot_bert_pretraining_mlm(protein_classification_dataset):
    model_path = 'Rostlab/prot_bert'
    model = ProtBERT(task='mlm', model_path=model_path, n_tasks=1)
    loss = model.fit(protein_classification_dataset, nb_epoch=1)
    assert loss


@pytest.mark.torch
def test_prot_bert_finetuning(protein_classification_dataset):
    model_path = 'Rostlab/prot_bert'

    model = ProtBERT(task='classification',
                     model_path=model_path,
                     n_tasks=1,
                     cls_name="LogReg")
    loss = model.fit(protein_classification_dataset, nb_epoch=1)
    eval_score = model.evaluate(protein_classification_dataset,
                                metrics=dc.metrics.Metric(
                                    dc.metrics.accuracy_score))
    assert eval_score, loss
    prediction = model.predict(protein_classification_dataset)
    # The classifier head emits n_tasks + 1 = 2 logits per sequence.
    assert prediction.shape == (protein_classification_dataset.y.shape[0], 2)

    model = ProtBERT(task='classification',
                     model_path=model_path,
                     n_tasks=1,
                     cls_name="FFN")
    loss = model.fit(protein_classification_dataset, nb_epoch=1)
    eval_score = model.evaluate(protein_classification_dataset,
                                metrics=dc.metrics.Metric(
                                    dc.metrics.accuracy_score))
    assert eval_score, loss
    prediction = model.predict(protein_classification_dataset)
    assert prediction.shape == (protein_classification_dataset.y.shape[0], 2)


@pytest.mark.torch
def test_protbert_load_from_pretrained(tmpdir):
    pretrain_model_dir = os.path.join(tmpdir, 'pretrain')
    finetune_model_dir = os.path.join(tmpdir, 'finetune')
    model_path = 'Rostlab/prot_bert'
    pretrain_model = ProtBERT(task='mlm',
                              model_path=model_path,
                              n_tasks=1,
                              model_dir=pretrain_model_dir)
    pretrain_model.save_checkpoint()

    finetune_model = ProtBERT(task='classification',
                              model_path=model_path,
                              n_tasks=1,
                              cls_name="LogReg",
                              model_dir=finetune_model_dir)
    finetune_model.load_from_pretrained(pretrain_model_dir)

    # check that the BERT backbone weights match; the classifier head
    # differs between the MLM and classification models.
    pretrain_model_state_dict = pretrain_model.model.state_dict()
    finetune_model_state_dict = finetune_model.model.state_dict()

    pretrain_base_model_keys = [
        key for key in pretrain_model_state_dict.keys() if 'bert' in key
    ]
    matches = [
        torch.allclose(pretrain_model_state_dict[key],
                       finetune_model_state_dict[key])
        for key in pretrain_base_model_keys
    ]

    assert all(matches)


@pytest.mark.torch
def test_protbert_save_reload(tmpdir):
    model_path = 'Rostlab/prot_bert'
    model = ProtBERT(task='classification',
                     model_path=model_path,
                     n_tasks=1,
                     cls_name="FFN",
                     model_dir=tmpdir)
    model._ensure_built()
    model.save_checkpoint()

    model_new = ProtBERT(task='classification',
                         model_path=model_path,
                         n_tasks=1,
                         cls_name="FFN",
                         model_dir=tmpdir)
    model_new.restore()

    old_state = model.model.state_dict()
    new_state = model_new.model.state_dict()
    matches = [
        torch.allclose(old_state[key], new_state[key])
        for key in old_state.keys()
    ]

    # all key values should match
    assert all(matches)


@pytest.mark.torch
def test_protbert_overfit():
    current_dir = os.path.dirname(os.path.abspath(__file__))

    featurizer = dc.feat.DummyFeaturizer()
    tasks = ["outcome"]
    loader = dc.data.CSVLoader(tasks=tasks,
                               feature_field="protein",
                               featurizer=featurizer)
    dataset = loader.create_dataset(
        os.path.join(current_dir,
                     "../../tests/assets/example_protein_classification.csv"))
    model_path = 'Rostlab/prot_bert'
    finetune_model = ProtBERT(task='classification',
                              model_path=model_path,
                              n_tasks=1,
                              cls_name="FFN",
                              batch_size=1,
                              learning_rate=1e-5)
    classification_metric = dc.metrics.Metric(dc.metrics.accuracy_score)
    finetune_model.fit(dataset, nb_epoch=10)
    eval_score = finetune_model.evaluate(dataset, [classification_metric])
    assert eval_score[classification_metric.name] > 0.9
6 changes: 6 additions & 0 deletions docs/source/api_reference/models.rst
@@ -609,6 +609,12 @@ Chemberta
.. autoclass:: deepchem.models.torch_models.chemberta.Chemberta
:members:

ProtBERT
---------

.. autoclass:: deepchem.models.torch_models.prot_bert.ProtBERT
:members:

Trainer
=======

