-
Notifications
You must be signed in to change notification settings - Fork 5
/
run_af_multimer_step2.py
137 lines (116 loc) · 4.95 KB
/
run_af_multimer_step2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/env python
#-*- coding:utf-8 -*-
###############################################
#
#
# Run AlphaFold-Multimer step by step
# (https://github.com/deepmind/alphafold)
# Author: Pan Li ([email protected])
# @ Shuimu BioScience
# https://www.shuimubio.com/
#
#
################################################
#
#
# AlphaFold-Multimer Step 2 -- Run models 1-5 to produce the unrelaxed models
# Usage: run_af_multimer_step2.py [--params_parent_dir AF2_params_parent_dir] [--models 1,2,3,4,5] [--num_recycle 3] /path/to/features.pkl /path/to/output_dir
#
#
import json
import os
import pathlib
import pickle
import random
import shutil
import sys
import time
import gzip
from typing import Dict, Union, Optional
import configparser
import numpy as np
import argparse
# GPU memory settings. They are assigned here, ahead of the alphafold imports
# that follow, so they are already in the environment when those modules load.
# NOTE(review): presumably these enable unified memory / GPU memory
# oversubscription for large inputs -- confirm against the JAX GPU memory docs.
os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '4.0'
#os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# config.ini is looked up in the directory that contains this script.
cur_path = pathlib.Path(__file__).parent.resolve()
ini_config = configparser.ConfigParser(allow_no_value=True)
config_ini_path = os.path.join(cur_path, 'config.ini')
# The read must not live inside an `assert`: under `python -O` assert
# statements are removed entirely, which would skip the read itself and leave
# the configuration empty. ConfigParser.read() returns the list of files it
# parsed successfully, so an empty list means config.ini was not loaded.
if not ini_config.read(config_ini_path):
    raise FileNotFoundError(f"Read config.ini failed: {config_ini_path}")
# Make the AlphaFold checkout named in config.ini importable.
sys.path.insert(0, ini_config['ALPHAFOLD2']['alphafold_path'])
from alphafold.common import protein
from alphafold.common import residue_constants
from alphafold.data import pipeline
from alphafold.data import pipeline_multimer
from alphafold.model import config
from alphafold.model import model
from alphafold.model import data
# Command-line interface: two positional paths (features file, output
# directory) plus optional overrides for the parameter location, the subset of
# models to run and the recycle count.
parser = argparse.ArgumentParser(
    description='AlphaFold-Multimer Step 2 -- Run models 1-5 to produce the unrelaxed models',
)
parser.add_argument(
    'input_file',
    metavar='input_file',
    type=str,
    help='The features.pkl file generated by AlphaFold-Multimer step 1',
)
parser.add_argument(
    'output_dir',
    metavar='output_dir',
    type=str,
    help='Path to a directory that will store the results.',
)
parser.add_argument(
    '--params_parent_dir',
    # The default comes from the [DATABASE] section of config.ini.
    default=ini_config['DATABASE']['params_parent_dir'],
    type=str,
    help="Path to the AlphaFold database, must contain params path",
)
parser.add_argument(
    '--models',
    default='1,2,3,4,5',
    type=str,
    help="Models to run, seperated by comma.",
)
parser.add_argument(
    '--num_recycle',
    default=3,
    type=int,
    help="Number of recycles",
)
args = parser.parse_args()
######################
## Util functions
######################
# Names of the five multimer models; list index k holds the name of model k+1.
model_names = ['model_{}_multimer'.format(number) for number in range(1, 6)]


def get_model_runner(i):
    """Build the runner for the model at 0-based index ``i`` of ``model_names``.

    The "_v2" suffix is appended to the model name for both the config lookup
    and the parameter loading. Parameters are read from
    ``args.params_parent_dir``; the recycle count comes from
    ``args.num_recycle`` and the number of evaluation ensembles is fixed to 1.

    Returns:
        A ``(model_runner, model_params)`` tuple.
    """
    versioned_name = model_names[i] + "_v2"

    runner_config = config.model_config(versioned_name)
    runner_config.model.num_ensemble_eval = 1
    runner_config.model.num_recycle = args.num_recycle
    #runner_config.model.global_config.subbatch_size = 1 # To save memory

    haiku_params = data.get_model_haiku_params(
        model_name=versioned_name,
        data_dir=args.params_parent_dir,
    )
    return model.RunModel(runner_config, haiku_params), haiku_params
######################
## Read features.pkl file
######################
def _load_features(path):
    """Unpickle and return the feature dict stored at ``path``.

    A path ending in '.gz' is opened with gzip; anything else is opened as a
    plain binary file. The file handle is closed before returning.
    """
    opener = gzip.open if path.endswith('.gz') else open
    # NOTE: unpickling can execute arbitrary code. Only pass features files
    # that come from a trusted source (e.g. your own step-1 run).
    with opener(path, 'rb') as handle:
        return pickle.load(handle)


feature_dict = _load_features(args.input_file)
# First dimension of 'aatype' is reported as the input length.
print("Input length:", feature_dict['aatype'].shape[0], flush=True)
######################
## Run models 1-5 separately
######################
output_dir = args.output_dir
# Validate with explicit raises rather than `assert`: assert statements are
# removed under `python -O`, which would let bad arguments through. Checking
# for a directory (not mere existence) also catches a regular file passed by
# mistake before any model is run.
if not os.path.isdir(output_dir):
    raise NotADirectoryError(f"Error: {output_dir} does not exist or is not a directory")

# --models is a comma-separated list of 1-based model numbers; convert it to
# 0-based indices into model_names.
try:
    models_to_run = [int(token) - 1 for token in args.models.split(',')]
except ValueError as err:
    raise ValueError(
        f"Error: --models must be a comma-separated list of integers, got {args.models!r}"
    ) from err
if any(not 0 <= model_index <= 4 for model_index in models_to_run):
    raise ValueError("Error: --models should be 1<=model<=5")
# Run each requested model in turn and write two files per model into
# output_dir: unrelaxed_<model>.pdb and result_<model>.pkl.gz.
for i in models_to_run:
    unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_names[i]}.pdb')
    result_output_path = os.path.join(output_dir, f'result_{model_names[i]}.pkl.gz')

    # Skip a model only when BOTH of its output files are already present.
    if os.path.exists(unrelaxed_pdb_path) and os.path.exists(result_output_path):
        print(f"Info: {unrelaxed_pdb_path} and {result_output_path} exists, please delete and try again", flush=True)
        continue

    print(f"Start to run model_{i+1}", flush=True)

    ###########################
    ### Get the Runner
    ############################
    model_runner, model_params = get_model_runner(i)

    # Need not to process the features for AlphaFold-Multimer
    prediction_result = model_runner.predict(feature_dict, random_seed=0)

    # Save memory: drop the runner and its parameters before the next model
    # is loaded.
    del model_runner
    del model_params

    ###########################
    ### Save as Protein object
    ############################
    # Repeat each plddt value across atom_type_num columns so it can be
    # passed as the b-factor array of the Protein object.
    plddt = prediction_result['plddt']
    plddt_b_factors = np.repeat(plddt[:, None], residue_constants.atom_type_num, axis=-1)
    unrelaxed_protein = protein.from_prediction(
        features=feature_dict,
        result=prediction_result,
        b_factors=plddt_b_factors,
        remove_leading_feature_dimension=False)

    ###########################
    ### Save as PDB file
    ###########################
    unrelaxed_pdb = protein.to_pdb(unrelaxed_protein)
    # Context manager guarantees the file is flushed and closed; print() is
    # kept so the written bytes (including the trailing newline) are unchanged.
    with open(unrelaxed_pdb_path, 'w') as pdb_file:
        print(unrelaxed_pdb, file=pdb_file)

    ###########################
    ### Save as pkl file
    ###########################
    # Closing the gzip stream explicitly ensures the gzip trailer is written
    # instead of relying on garbage collection of the file object.
    with gzip.open(result_output_path, 'wb') as result_file:
        pickle.dump(prediction_result, result_file, protocol=4)