-
Notifications
You must be signed in to change notification settings - Fork 5
/
filter.py
219 lines (197 loc) · 9.73 KB
/
filter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import os
import argparse
import pandas as pd
from src.util.pdb import get_pdb_numbering_from_residue_indices
from src.hallucination.utils.util\
import get_indices_from_different_methods,\
comma_separated_chain_indices_to_dict
from src.hallucination.utils.interfacemetrics_plotting_utils \
import iam_score_df_from_pdbs, scatter_hist,\
select_best_designs_by_sum
from src.hallucination.utils.sequence_utils import sequences_to_logo_without_weblogo
from src.hallucination.utils.rmsd_plotting_utils import \
threshold_by_rmsd_filters, write_fastas_for_alphafold2,\
plt_ff_publication_for_run
def output_filtered_designs(csv_dg, csv_rmsd,
                            target_pdb,
                            indices_hal=None,
                            rmsd_filter='H3,1.8',
                            rmsd_filter_json='',
                            outdir='.',
                            suffix='DeepAb',
                            n_best=50,
                            ):
    """Filter designs by RMSD thresholds and write summary outputs.

    Reads the binding-energy table (``csv_dg``) and the forward-folding RMSD
    table (``csv_rmsd``). Designs are thresholded with
    ``threshold_by_rmsd_filters`` and the two tables are joined on
    ``design_id``. Everything is written under ``outdir``: CSVs, RMSD
    histograms, sequence logos, AlphaFold2 FASTA inputs, and interface-metric
    plots/tables.

    Args:
        csv_dg: CSV with at least a ``filename`` column. The integer design id
            is parsed from the second-to-last ``_``-separated field of each
            filename's basename (text before ``.pdb``).
        csv_rmsd: forward-folding RMSD CSV. After thresholding it must carry a
            ``design_id`` column for the merge.
        target_pdb: reference structure. Used for residue numbering of the
            logo labels and as the interface-metrics reference.
        indices_hal: residue indices used as the logo ``reslist``. ``None``
            (the default) means an empty list.
        rmsd_filter: ``"<metric>,<threshold>"`` string, e.g. ``"H3,1.8"``.
            Its metric name is used as the histogram x-axis. ``''`` skips the
            histograms.
        rmsd_filter_json: forwarded to ``threshold_by_rmsd_filters``.
        outdir: output directory (created if missing).
        suffix: tag (model name) embedded in every output file name.
        n_best: maximum number of best decoys selected by interface metrics.

    Returns:
        None. All results are written to disk.
    """
    if indices_hal is None:
        indices_hal = []
    os.makedirs(outdir, exist_ok=True)

    df_dg = pd.read_csv(csv_dg, delimiter=',')
    # The design id is the second-to-last '_'-separated field of the basename.
    df_dg['design_id'] = [
        int(os.path.basename(fname).split('.pdb')[0].split('_')[-2])
        for fname in df_dg['filename']
    ]

    df_ff = pd.read_csv(csv_rmsd)
    # Metric name of the single-metric filter, e.g. 'H3' from 'H3,1.8'.
    x = rmsd_filter.split(',')[0] if rmsd_filter != '' else None
    if x is not None:
        plt_ff_publication_for_run(
            csv_rmsd, x=x,
            outfile=os.path.join(outdir, 'histrmsdff-{}.png'.format(suffix)))

    # '{{}}' survives this format() as a literal '{}' placeholder, which
    # threshold_by_rmsd_filters is expected to fill (presumably with the
    # rmsd suffix it returns -- confirm against that helper).
    outfile_ff_template = os.path.join(
        outdir, 'df_ff-{}_thresholded_{{}}.csv'.format(suffix))
    df_ff_thr, rmsd_suffix = threshold_by_rmsd_filters(
        df_ff, rmsd_filter=rmsd_filter,
        rmsd_filter_json=rmsd_filter_json,
        outfile=outfile_ff_template)

    df_dg_ff_thr = pd.merge(df_dg, df_ff_thr, on=['design_id'],
                            suffixes=['', '_ff'])
    csv_dg_ff_thr = os.path.join(
        outdir, 'df_ff-{}_dg_thresholded_{}.csv'.format(suffix, rmsd_suffix))
    df_dg_ff_thr.to_csv(csv_dg_ff_thr)

    if x is not None:
        # Histogram over the designs that passed the filter. The CSV path is
        # already fully formatted, so it must not be format()-ed again (that
        # would break on directory names containing braces).
        plt_ff_publication_for_run(
            csv_dg_ff_thr, x=x,
            outfile=os.path.join(
                outdir,
                'histrmsdff-{}_thresholded_{}.png'.format(suffix, rmsd_suffix)))

    sequences_thresholded = list(df_dg_ff_thr['seq'])
    print('{} sequences meet the thresholds.'.format(len(sequences_thresholded)))
    if not sequences_thresholded:
        # Nothing survived the thresholds: no logos, FASTAs or metrics.
        return

    dict_residues = {
        'reslist': indices_hal,
        'labellist': get_pdb_numbering_from_residue_indices(target_pdb,
                                                            indices_hal),
    }
    outfile_logo = os.path.join(
        outdir,
        'logo_ff-{}_dg_thresholded_rmsd{}.png'.format(suffix, rmsd_suffix))
    sequences_to_logo_without_weblogo(sequences_thresholded,
                                      dict_residues=dict_residues,
                                      outfile_logo=outfile_logo)

    # Write inputs for running AlphaFold2 on the surviving designs.
    outdir_af2 = os.path.join(
        outdir, 'ff-{}_ddg_thresholded_rmsd{}'.format(suffix, rmsd_suffix))
    os.makedirs(outdir_af2, exist_ok=True)
    write_fastas_for_alphafold2(list(df_dg_ff_thr['filename']), outdir_af2)

    # Interface metrics for the surviving designs and for the reference.
    select_by = ['dG_separated']
    # Order-preserving de-duplication: set() iteration order for strings
    # changes between interpreter runs, which made results non-reproducible.
    design_pdbs = list(dict.fromkeys(df_dg_ff_thr['filename']))
    df_iam_mutants = iam_score_df_from_pdbs(design_pdbs)
    print('iam: ', df_iam_mutants)
    df_iam_ref = iam_score_df_from_pdbs([target_pdb])

    n_all = min(n_best, len(design_pdbs))
    pdb_dir = os.path.join(outdir, 'interface_metrics_pdbs')
    os.makedirs(pdb_dir, exist_ok=True)
    best_decoys = select_best_designs_by_sum(df_iam_mutants, by=select_by,
                                             n=n_all, pdb_dir=pdb_dir,
                                             out_path=pdb_dir)

    selected_decoys_dir = os.path.join(outdir, 'selected_decoys_iam')
    os.makedirs(selected_decoys_dir, exist_ok=True)
    outfile_scatter = os.path.join(selected_decoys_dir, "scatterplot_dgneg.png")
    # Only designs with a negative separated binding energy are plotted.
    df_iam_mutants_neg = df_iam_mutants[df_iam_mutants['dG_separated'] < 0.0]
    if 'dG_separated' in df_iam_ref.columns:
        scatter_hist(df_iam_mutants_neg, ref=df_iam_ref,
                     out=outfile_scatter, highlight=best_decoys, by=select_by)
        df_iam_ref.to_csv(os.path.join(outdir, 'df_ref_iam.csv'))
    else:
        scatter_hist(df_iam_mutants_neg, out=outfile_scatter,
                     highlight=best_decoys, by=select_by)

    df_combined = pd.merge(df_dg_ff_thr, df_iam_mutants, on=['filename'])
    df_combined.to_csv(os.path.join(
        outdir,
        'df_ff-{}_dg_iam_thresholded_rmsd{}.csv'.format(suffix, rmsd_suffix)))

    df_best_indices = df_iam_mutants.loc[best_decoys]
    df_combined_best = pd.merge(df_dg_ff_thr, df_best_indices,
                                on=['filename'])
    df_combined_best.to_csv(os.path.join(
        selected_decoys_dir,
        'df_ff-{}_dg_thresholded_rmsd{}_bestdecoys.csv'.format(suffix,
                                                               rmsd_suffix)))

    sequences_iam = list(df_combined_best['seq'])
    outfile_logo = os.path.join(
        outdir,
        'logo_ff-{}_dg_thresholded_rmsd{}_iam-top{}.png'.format(
            suffix, rmsd_suffix, n_all))
    sequences_to_logo_without_weblogo(sequences_iam,
                                      dict_residues=dict_residues,
                                      outfile_logo=outfile_logo)
def get_args():
    """Parse and validate the command-line options for this script.

    Returns:
        argparse.Namespace with the options defined below.

    Raises:
        SystemExit: through ``parser.error`` when ``--csv_complexes`` or
            ``--csv_forward_folded`` is missing. Both CSVs are always read by
            the filtering step, so an empty path would otherwise fail later
            inside pandas with a far less helpful message.
    """
    desc = ('''
    Filter designs that meet RMSD threshold and have improved binding energies.
    Usage: python3 filter.py <options>
    ''')
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('target_pdb',
                        type=str,
                        help='path to target structure chothia numbered pdb '
                             'file. Provide pdb for the target structure of '
                             'the Fv with the antigen.')
    parser.add_argument('--rmsd_filter',
                        default='H3,1.8',
                        help='specify metric and threshold separated by a '
                             'comma. Metric list: OCD, H1, H2, H3, L1, L2, '
                             'L3, HFr, LFr')
    parser.add_argument('--rmsd_filter_json',
                        default='',
                        help='specify multiple metrics and threshold as a '
                             'json dictionary. Metric list: OCD, H1, H2, H3, '
                             'L1, L2, L3, HFr, LFr')
    parser.add_argument('--csv_forward_folded',
                        default='',
                        help='(required) csv file generated by '
                             '--plot_consolidated_funnels')
    parser.add_argument('--csv_complexes',
                        default='',
                        help='(required) csv file generated by '
                             '--plot_consolidated_dG')
    parser.add_argument('--model_suffix',
                        default='DeepAb',
                        help='suffix used in output file names; set it if '
                             'the --csv_forward_folded file was generated '
                             'with a different model')
    parser.add_argument('--outdir',
                        type=str,
                        default='./',
                        help='output directory for filtered csv files, plots, '
                             'logos and AlphaFold2 inputs')
    parser.add_argument('--cdr_list',
                        type=str,
                        default='',
                        help='comma separated list of cdrs: l1,h2')
    parser.add_argument('--framework',
                        action='store_true',
                        default=False,
                        help='design framework residues. Default: false')
    parser.add_argument('--indices',
                        type=str,
                        default='',
                        help='comma separated list of chothia numbered '
                             'residues to design: h:12,20,31A/l:56,57')
    parser.add_argument('--exclude',
                        type=str,
                        default='',
                        help='comma separated list of chothia numbered '
                             'residues to exclude from design: '
                             'h:31A,52,53/l:97,99')
    parser.add_argument('--hl_interface',
                        action='store_true',
                        default=False,
                        help='hallucinate hl interface')
    args = parser.parse_args()

    # Both input tables are mandatory for the filtering step; fail early.
    missing = [flag for flag, value in
               (('--csv_complexes', args.csv_complexes),
                ('--csv_forward_folded', args.csv_forward_folded))
               if not value]
    if missing:
        parser.error('the following arguments are required: '
                     + ', '.join(missing))
    return args
def get_hal_indices(args):
    """Resolve which residues of ``args.target_pdb`` were designed.

    Converts the optional ``--indices`` / ``--exclude`` strings into
    per-chain dictionaries with ``comma_separated_chain_indices_to_dict``. An
    empty option yields an empty dict. The dictionaries are combined with the
    CDR, framework and HL-interface selections by
    ``get_indices_from_different_methods``.

    Args:
        args: namespace from ``get_args``. Reads ``indices``, ``exclude``,
            ``target_pdb``, ``cdr_list``, ``framework`` and ``hl_interface``.

    Returns:
        Whatever ``get_indices_from_different_methods`` returns (presumably a
        list of residue indices -- confirm against that helper). The value is
        also printed.
    """
    include = {}
    if args.indices != '':
        # Echo the raw include specification before parsing it.
        print(args.indices)
        include = comma_separated_chain_indices_to_dict(args.indices)

    exclude = (comma_separated_chain_indices_to_dict(args.exclude)
               if args.exclude != '' else {})

    hal_indices = get_indices_from_different_methods(
        args.target_pdb,
        cdr_list=args.cdr_list,
        framework=args.framework,
        hl_interface=args.hl_interface,
        include_indices=include,
        exclude_indices=exclude)
    print("Indices hallucinated: ", hal_indices)
    return hal_indices
if __name__ == '__main__':
    # Script entry point: parse options, resolve the designed residues, then
    # run the filtering/report pipeline. Every argument is passed by keyword.
    args = get_args()
    output_filtered_designs(
        csv_dg=args.csv_complexes,
        csv_rmsd=args.csv_forward_folded,
        target_pdb=args.target_pdb,
        indices_hal=get_hal_indices(args),
        rmsd_filter=args.rmsd_filter,
        rmsd_filter_json=args.rmsd_filter_json,
        outdir=args.outdir,
        suffix=args.model_suffix,
    )