#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
#
# AlphaFold2 Step 2 -- Run models 1-5 to produce the unrelaxed models
# Usage: run_af2_step2.py [--params_parent_dir AF2_db_dir] /path/to/features.pkl /path/to/output_dir
#
#
import os
import pathlib
import pickle
import sys
import gzip
import configparser
import argparse
import inspect
# Enable CUDA unified memory and let JAX allocate beyond physical GPU RAM
# (a MEM_FRACTION above 1.0 is intentional when unified memory is on).
os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '4.0'
cur_path = pathlib.Path(__file__).parent.resolve()
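# Expected config.ini layout (inferred from the keys read below; values are placeholders):
#   [ALPHAFOLD2]
#   alphafold_path = /path/to/alphafold
#   [DATABASE]
#   params_parent_dir = /path/to/af2_database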
ini_config = configparser.ConfigParser(allow_no_value=True)
assert len(ini_config.read(os.path.join(cur_path, 'config.ini'))) > 0, "Failed to read config.ini"
sys.path.insert(0, ini_config['ALPHAFOLD2']['alphafold_path'])
from alphafold.common import protein
from alphafold.common import residue_constants
from alphafold.data import pipeline
from alphafold.model import config
from alphafold.model import model
from alphafold.model import data
import numpy as np
parser = argparse.ArgumentParser(description='AlphaFold2 Step 2 -- Run models 1-5 to produce the unrelaxed models')
parser.add_argument('input_file', metavar='input_file', type=str, help='The features.pkl file generated by AlphaFold2 step 1')
parser.add_argument('output_dir', metavar='output_dir', type=str, help='Path to a directory that will store the results.')
parser.add_argument('--params_parent_dir', default=ini_config['DATABASE']['params_parent_dir'], type=str, help='Path to the AlphaFold database; must contain the params/ subdirectory')
parser.add_argument('--models', default='1,2,3,4,5', type=str, help='Models to run, separated by commas (e.g. 1,2,3 or 1_ptm,2_ptm)')
parser.add_argument('--num_recycle', default=3, type=int, help='Number of recycles')
parser.add_argument('--num_ensemble', default=1, type=int, help='Number of ensembles')
args = parser.parse_args()
######################
## Util functions
######################
def func_has_arg(func, arg):
    """Return True if callable `func` accepts a parameter named `arg`."""
    param_keys = list(inspect.signature(func).parameters.keys())
    return arg in param_keys
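# Example: func_has_arg(print, 'sep') is True. Used below to stay compatible
# with AlphaFold versions whose call signatures differ.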
model_names = [f'model_{i}' for i in range(1, 6)]
def get_model_runner(i, ptm=False):
    """Build a RunModel for model index i (0-based); use the pTM variant if ptm=True."""
    model_name = model_names[i]
    if ptm:
        model_name += "_ptm"
    model_config = config.model_config(model_name)
    model_config.data.eval.num_ensemble = args.num_ensemble
    model_config.model.num_recycle = args.num_recycle
    # Optionally set model_config.model.global_config.subbatch_size = 1 to save memory
    model_params = data.get_model_haiku_params(model_name=model_name, data_dir=args.params_parent_dir)
    model_runner = model.RunModel(model_config, model_params)
    return model_runner, model_params
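# Example (illustrative): a runner for model_3 with the pTM head would be
#   model_runner, _ = get_model_runner(2, ptm=True)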
######################
## Read features.pkl file
######################
if args.input_file.endswith('.gz'):
    with gzip.open(args.input_file, 'rb') as fh:
        feature_dict = pickle.load(fh)
else:
    with open(args.input_file, 'rb') as fh:
        feature_dict = pickle.load(fh)
print("Input length:", feature_dict['aatype'].shape[0], flush=True)
######################
## Run models 1-5 separately
######################
output_dir = args.output_dir
assert os.path.exists(output_dir), "Error: output_dir does not exist"
models_to_run = []
models_are_ptm = []
for item in args.models.split(','):
    if '_' in item:
        id_, ptm_ = item.split('_')
    else:
        id_ = item
        ptm_ = ''
    id_ = int(id_) - 1
    assert 0 <= id_ <= 4, "Error: --models entries must satisfy 1 <= model <= 5"
    models_to_run.append(id_)
    models_are_ptm.append(ptm_ == 'ptm')
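# Example: --models 1,3_ptm yields models_to_run=[0, 2], models_are_ptm=[False, True]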
for i, ptm in zip(models_to_run, models_are_ptm):
    ptm_token = '_ptm' if ptm else ''
    unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_names[i]}{ptm_token}.pdb')
    result_output_path = os.path.join(output_dir, f'result_{model_names[i]}{ptm_token}.pkl.gz')
    if os.path.exists(unrelaxed_pdb_path) and os.path.exists(result_output_path):
        print(f"Info: {unrelaxed_pdb_path} and {result_output_path} already exist; delete them to re-run", flush=True)
        continue
    print(f"Start to run model_{i+1}{ptm_token}", flush=True)
    ###########################
    ### Get the runner
    ###########################
    model_runner, model_params = get_model_runner(i, ptm)
    processed_feature_dict = model_runner.process_features(feature_dict, random_seed=None)
    # Newer AlphaFold versions require a random_seed argument to predict()
    if func_has_arg(model_runner.predict, 'random_seed'):
        prediction_result = model_runner.predict(processed_feature_dict, random_seed=0)
    else:
        prediction_result = model_runner.predict(processed_feature_dict)
    # Free the model to save memory before the next iteration
    del model_runner
    del model_params
    ###########################
    ### Build the Protein object
    ###########################
    # Broadcast the per-residue pLDDT to every atom so it lands in the PDB B-factor column
    plddt = prediction_result['plddt']
    plddt_b_factors = np.repeat(plddt[:, None], residue_constants.atom_type_num, axis=-1)
    params = {
        'features': processed_feature_dict,
        'result': prediction_result,
        'b_factors': plddt_b_factors,
    }
    # Newer AlphaFold versions take a remove_leading_feature_dimension argument
    if func_has_arg(protein.from_prediction, 'remove_leading_feature_dimension'):
        params['remove_leading_feature_dimension'] = True
    unrelaxed_protein = protein.from_prediction(**params)
    ###########################
    ### Save as PDB file
    ###########################
    unrelaxed_pdb = protein.to_pdb(unrelaxed_protein)
    with open(unrelaxed_pdb_path, 'w') as fh:
        fh.write(unrelaxed_pdb)
    ###########################
    ### Save as pkl.gz file
    ###########################
    with gzip.open(result_output_path, 'wb') as fh:
        pickle.dump(prediction_result, fh, protocol=4)
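
# Example invocation (paths are placeholders):
#   python run_af2_step2.py --models 1,2_ptm /path/to/features.pkl /path/to/output_dir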