forked from deepchem/deepchem
* Added sample protein classification dataset
* Modified feature name from smiles to protein
* Added ProtBERT and unit tests
* Added ProtBERT to __init__.py
* Moved ProtBERT into HF dependency
* Fixed formatting
* Added docstring and type annotations
* Fixed formatting
* Added sample protein classification dataset and overfit test case for ProtBERT
* Removed comments
* Fixed formatting
* Added type annotations to new variables
* Added ProtBERT to docs
* Added detailed model description
1 parent d3491a0 · commit fe3b2f5

Showing 6 changed files with 288 additions and 0 deletions.
deepchem/models/tests/assets/example_protein_classification.csv (11 additions, 0 deletions)
@@ -0,0 +1,11 @@
Compound ID,outcome,protein
Q9H400,0,M G L P V S W A P P A L W V L G C C A L L L S L W A L C T
Q86XT9,0,M E V L E E P A P G P G G A D A A E R R G L R R L L L S
P83456,0,M M K T L S S G N C T L N V P A K N S Y R M V V L G A
P33527,0,M D P S K Q G T L N R V E N S V Y R T A F K L R S V Q T L C Q L D L M
Q8VCH2,0,M E A C S S K T S L L L H S P L R T I P K L R
Q95QW4,0,M S D Y F T F P K Q E N G G I S K Q P A T P G S T R S S S R N L
O97148,0,M S G E D G P A A G P G A A A
Q8WNR0,1,M A A F T G T T D K C K A C D K T V Y V M D L M T L E G M
P43144,1,M F S G K V R A F I D E E L F H S N R N N S S D G L S L D T
Q8JZM4,1,M K R G I R R D P F R K R K L G
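Note that the protein column stores each sequence with spaces between residues: the pretrained ProtBERT tokenizer treats individual amino acids as tokens and expects whitespace-separated, uppercase sequences. A minimal sketch of tokenizing one row from this asset, assuming the transformers package and the Rostlab/prot_bert checkpoint are available:

from transformers import BertTokenizer

# The same tokenizer configuration used by the model code below
tokenizer = BertTokenizer.from_pretrained('Rostlab/prot_bert',
                                          do_lower_case=False)
# One sequence from the asset above; each residue is a separate token
encoding = tokenizer("M G L P V S W A P P A L W V L G C C A L L L S L W A L C T")
print(encoding["input_ids"])  # token ids, including [CLS]/[SEP] specials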
deepchem/models/torch_models/prot_bert.py

@@ -0,0 +1,115 @@
from transformers import BertForMaskedLM, BertForSequenceClassification, BertTokenizer, BertConfig
import torch.nn as nn
from deepchem.models.torch_models.hf_models import HuggingFaceModel
from typing import Union


class ProtBERT(HuggingFaceModel):
    """
    ProtBERT model[1].

    ProtBERT is based on the BERT architecture, and the current implementation
    supports only MLM pretraining and a classification mode, as described by
    the authors on HuggingFace[2]. For classification we currently support
    only logistic regression and a simple feed-forward neural network.

    The model converts the input protein sequence into a vector through a
    trained BERT tokenizer, which is then processed by the corresponding model
    based on the task. `BertForMaskedLM` is used to facilitate the MLM
    pretraining task. For the sequence classification task, we follow
    `BertForSequenceClassification` but change the classifier to either a
    logistic regression (LogReg) or a feed-forward neural network (FFN),
    depending on the specified `cls_name`. The FFN is a simple 2-layer network
    with 512 as the hidden dimension.

    Examples
    --------
    >>> import os
    >>> import tempfile
    >>> tempdir = tempfile.mkdtemp()
    >>> # preparing dataset
    >>> import pandas as pd
    >>> import deepchem as dc
    >>> protein = ["MPCTTYLPLLLLLFLLPPPSVQSKV", "SSGLFWMELLTQFVLTWPLVVIAFL"]
    >>> labels = [0, 1]
    >>> df = pd.DataFrame(list(zip(protein, labels)), columns=["protein", "task1"])
    >>> with dc.utils.UniversalNamedTemporaryFile(mode='w') as tmpfile:
    ...     df.to_csv(tmpfile.name)
    ...     loader = dc.data.CSVLoader(["task1"], feature_field="protein", featurizer=dc.feat.DummyFeaturizer())
    ...     dataset = loader.create_dataset(tmpfile.name)
    >>> # pretraining
    >>> from deepchem.models.torch_models.prot_bert import ProtBERT
    >>> pretrain_model_dir = os.path.join(tempdir, 'pretrain-model')
    >>> model_path = 'Rostlab/prot_bert'
    >>> pretrain_model = ProtBERT(task='mlm', model_path=model_path, n_tasks=1, model_dir=pretrain_model_dir)  # mlm pretraining
    >>> pretraining_loss = pretrain_model.fit(dataset, nb_epoch=1)
    >>> del pretrain_model
    >>> # finetuning in classification mode
    >>> finetune_model_dir = os.path.join(tempdir, 'finetune-model')
    >>> finetune_model = ProtBERT(task='classification', model_path=model_path, n_tasks=1, model_dir=finetune_model_dir)
    >>> finetune_model.load_from_pretrained(pretrain_model_dir)
    >>> finetuning_loss = finetune_model.fit(dataset, nb_epoch=1)
    >>> # prediction and evaluation
    >>> result = finetune_model.predict(dataset)
    >>> eval_results = finetune_model.evaluate(dataset, metrics=dc.metrics.Metric(dc.metrics.accuracy_score))

    References
    ----------
    .. [1] Elnaggar, Ahmed, et al. "ProtTrans: Toward understanding the language of life through self-supervised learning." IEEE Transactions on Pattern Analysis and Machine Intelligence 44.10 (2021): 7112-7127.
    .. [2] https://huggingface.co/Rostlab/prot_bert
    """

    def __init__(self,
                 task: str,
                 model_path: str = 'Rostlab/prot_bert',
                 n_tasks: int = 1,
                 cls_name: str = "LogReg",
                 **kwargs) -> None:
        """
        Parameters
        ----------
        task: str
            The task defines the type of learning task in the model. The supported tasks are
            - `mlm` - masked language modeling, commonly used for pretraining
            - `classification` - use it for classification tasks
        model_path: str
            Path to the HuggingFace model
        n_tasks: int, default 1
            Number of prediction targets for a multitask learning model
        cls_name: str
            The classifier head to use in classification mode. Currently only "FFN" and "LogReg" are supported.
        """
        self.n_tasks: int = n_tasks
        tokenizer: BertTokenizer = BertTokenizer.from_pretrained(
            model_path, do_lower_case=False)
        protbert_config: BertConfig = BertConfig.from_pretrained(
            pretrained_model_name_or_path=model_path,
            vocab_size=tokenizer.vocab_size)
        model: Union[BertForMaskedLM, BertForSequenceClassification]
        if task == "mlm":
            model = BertForMaskedLM.from_pretrained(model_path)
        elif task == "classification":
            cls_head: Union[nn.Linear, nn.Sequential]
            if n_tasks == 1:
                protbert_config.problem_type = 'single_label_classification'
            else:
                protbert_config.problem_type = 'multi_label_classification'

            if cls_name == "LogReg":
                cls_head = nn.Linear(1024, n_tasks + 1)
            elif cls_name == "FFN":
                cls_head = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(),
                                         nn.Linear(512, n_tasks + 1))
            else:
                raise ValueError('Invalid classifier: {}.'.format(cls_name))

            model = BertForSequenceClassification.from_pretrained(
                model_path, config=protbert_config)
            model.classifier = cls_head

        else:
            raise ValueError('Invalid task specification: {}.'.format(task))
        super().__init__(model=model, task=task, tokenizer=tokenizer, **kwargs)
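As a quick sanity check on the classifier heads defined in __init__ above, here is a minimal standalone sketch (plain torch, no DeepChem required) showing that both heads map ProtBERT's 1024-dimensional pooled output to n_tasks + 1 logits:

import torch
import torch.nn as nn

n_tasks = 1
# "LogReg" head: a single linear layer over the pooled embedding
logreg_head = nn.Linear(1024, n_tasks + 1)
# "FFN" head: two layers with a 512-dimensional hidden layer
ffn_head = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(),
                         nn.Linear(512, n_tasks + 1))

pooled = torch.randn(4, 1024)  # a mock batch of 4 pooled [CLS] embeddings
assert logreg_head(pooled).shape == (4, 2)
assert ffn_head(pooled).shape == (4, 2)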
deepchem/models/torch_models/tests/test_prot_bert.py

@@ -0,0 +1,136 @@
import os

import deepchem as dc
import pytest

try:
    import torch
    from deepchem.models.torch_models.prot_bert import ProtBERT
except ModuleNotFoundError:
    pass


@pytest.mark.torch
def test_prot_bert_pretraining_mlm(protein_classification_dataset):
    model_path = 'Rostlab/prot_bert'
    model = ProtBERT(task='mlm', model_path=model_path, n_tasks=1)
    loss = model.fit(protein_classification_dataset, nb_epoch=1)
    assert loss


@pytest.mark.torch
def test_prot_bert_finetuning(protein_classification_dataset):
    model_path = 'Rostlab/prot_bert'

    model = ProtBERT(task='classification',
                     model_path=model_path,
                     n_tasks=1,
                     cls_name="LogReg")
    loss = model.fit(protein_classification_dataset, nb_epoch=1)
    eval_score = model.evaluate(protein_classification_dataset,
                                metrics=dc.metrics.Metric(
                                    dc.metrics.accuracy_score))
    assert eval_score, loss
    prediction = model.predict(protein_classification_dataset)
    assert prediction.shape == (protein_classification_dataset.y.shape[0], 2)

    model = ProtBERT(task='classification',
                     model_path=model_path,
                     n_tasks=1,
                     cls_name="FFN")
    loss = model.fit(protein_classification_dataset, nb_epoch=1)
    eval_score = model.evaluate(protein_classification_dataset,
                                metrics=dc.metrics.Metric(
                                    dc.metrics.accuracy_score))
    assert eval_score, loss
    prediction = model.predict(protein_classification_dataset)
    assert prediction.shape == (protein_classification_dataset.y.shape[0], 2)


@pytest.mark.torch
def test_protbert_load_from_pretrained(tmpdir):
    pretrain_model_dir = os.path.join(tmpdir, 'pretrain')
    finetune_model_dir = os.path.join(tmpdir, 'finetune')
    model_path = 'Rostlab/prot_bert'
    pretrain_model = ProtBERT(task='mlm',
                              model_path=model_path,
                              n_tasks=1,
                              model_dir=pretrain_model_dir)
    pretrain_model.save_checkpoint()

    finetune_model = ProtBERT(task='classification',
                              model_path=model_path,
                              n_tasks=1,
                              cls_name="LogReg",
                              model_dir=finetune_model_dir)
    finetune_model.load_from_pretrained(pretrain_model_dir)

    # check that the shared base-model weights match
    pretrain_model_state_dict = pretrain_model.model.state_dict()
    finetune_model_state_dict = finetune_model.model.state_dict()

    pretrain_base_model_keys = [
        key for key in pretrain_model_state_dict.keys() if 'bert' in key
    ]
    matches = [
        torch.allclose(pretrain_model_state_dict[key],
                       finetune_model_state_dict[key])
        for key in pretrain_base_model_keys
    ]

    assert all(matches)


@pytest.mark.torch
def test_protbert_save_reload(tmpdir):
    model_path = 'Rostlab/prot_bert'
    model = ProtBERT(task='classification',
                     model_path=model_path,
                     n_tasks=1,
                     cls_name="FFN",
                     model_dir=tmpdir)
    model._ensure_built()
    model.save_checkpoint()

    model_new = ProtBERT(task='classification',
                         model_path=model_path,
                         n_tasks=1,
                         cls_name="FFN",
                         model_dir=tmpdir)
    model_new.restore()

    old_state = model.model.state_dict()
    new_state = model_new.model.state_dict()
    matches = [
        torch.allclose(old_state[key], new_state[key])
        for key in old_state.keys()
    ]

    # all parameter values should match after restoring
    assert all(matches)


@pytest.mark.torch
def test_protbert_overfit():
    current_dir = os.path.dirname(os.path.abspath(__file__))

    featurizer = dc.feat.DummyFeaturizer()
    tasks = ["outcome"]
    loader = dc.data.CSVLoader(tasks=tasks,
                               feature_field="protein",
                               featurizer=featurizer)
    dataset = loader.create_dataset(
        os.path.join(current_dir,
                     "../../tests/assets/example_protein_classification.csv"))
    model_path = 'Rostlab/prot_bert'
    finetune_model = ProtBERT(task='classification',
                              model_path=model_path,
                              n_tasks=1,
                              cls_name="FFN",
                              batch_size=1,
                              learning_rate=1e-5)
    classification_metric = dc.metrics.Metric(dc.metrics.accuracy_score)
    finetune_model.fit(dataset, nb_epoch=10)
    eval_score = finetune_model.evaluate(dataset, [classification_metric])
    assert eval_score[classification_metric.name] > 0.9
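The tests above rely on a protein_classification_dataset fixture that is not rendered in this view (the commit touches 6 files, but only 3 appear here). A plausible sketch of such a fixture (hypothetical, assuming it lives in the tests' conftest.py and simply loads the asset added in this commit):

import os

import deepchem as dc
import pytest


@pytest.fixture
def protein_classification_dataset():
    # Hypothetical reconstruction: load the sample CSV added above
    current_dir = os.path.dirname(os.path.abspath(__file__))
    loader = dc.data.CSVLoader(tasks=["outcome"],
                               feature_field="protein",
                               featurizer=dc.feat.DummyFeaturizer())
    return loader.create_dataset(
        os.path.join(current_dir,
                     "../../tests/assets/example_protein_classification.csv"))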