Skip to content

Commit

Permalink
generate tsv arg
Browse files (browse the repository at this point in the history)
  • Loading branch information
minh-biocommons committed Sep 26, 2024
1 parent 5251c21 commit 56cc23b
Showing 1 changed file with 29 additions and 26 deletions.
55 changes: 29 additions & 26 deletions bin/generat_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import re
#from Bio import PDB

def generate_output_images(msa_path, plddt_paths, name, out_dir, in_type):
def generate_output_images(msa_path, plddt_data, name, out_dir, in_type, generate_tsv):
msa = []
if not msa_path.endswith("NO_FILE"):
with open(msa_path, 'r') as in_file:
Expand Down Expand Up @@ -65,20 +65,16 @@ def generate_output_images(msa_path, plddt_paths, name, out_dir, in_type):
# ##################################################################

plddt_per_model = OrderedDict()
plddt_paths_srt = plddt_paths
plddt_paths_srt.sort()
for plddt_path in plddt_paths_srt:
with open(plddt_path, 'r') as in_file:
if in_type == "ESM-FOLD":
plddt_per_model[os.path.basename(plddt_path)[:-4]] = []
in_file.readline()
for line in in_file:
vals = line.strip().split()
#print(vals)
if len(vals) == 5:
plddt_per_model[os.path.basename(plddt_path)[:-4]].append(float(vals[-1].strip()))
else:
output_data = plddt_data

if generate_tsv == "y":
for plddt_path in output_data:
with open(plddt_path, 'r') as in_file:
plddt_per_model[os.path.basename(plddt_path)[:-4]] = [float(x) for x in in_file.read().strip().split()]
else:
for i, plddt_values_str in enumerate(output_data):
plddt_per_model[i] = []
plddt_per_model[i] = [float(x) for x in plddt_values_str.strip().split()]

# plt.figure(figsize=(14, 14), dpi=100)
# plt.title("Predicted LDDT per position")
Expand Down Expand Up @@ -261,14 +257,14 @@ def align_structures(structures):
return aligned_structures
"""

def pdb_to_lddt(pdb_files):
output_files = []
averages = []
def pdb_to_lddt(pdb_files, generate_tsv):
pdb_files_sorted = pdb_files
pdb_files_sorted.sort()

for pdb_file in pdb_files:
output_file = f"{pdb_file.replace('.pdb', '')}_plddt.tsv"
output_files.append(output_file)
output_lddt = []
averages = []

for pdb_file in pdb_files_sorted:
plddt_values = []
seen_lines = set()

Expand All @@ -288,14 +284,21 @@ def pdb_to_lddt(pdb_files):
else:
averages.append(0.0)

with open(output_file, 'w') as outfile:
outfile.write("\t".join(map(str, plddt_values)) + "\n")
if generate_tsv == "y":
output_file = f"{pdb_file.replace('.pdb', '')}_plddt.tsv"
with open(output_file, 'w') as outfile:
outfile.write(" ".join(map(str, plddt_values)) + "\n")
output_lddt.append(output_file)
else:
plddt_values_string = " ".join(map(str, plddt_values))
output_lddt.append(plddt_values_string)

return output_files, averages
return output_lddt, averages

print("Starting..")
print("Starting...")
parser = argparse.ArgumentParser()
parser.add_argument('--type', dest='in_type')
parser.add_argument('--generate_tsv', choices=['y', 'n'], required=True, dest='generate_tsv')
parser.add_argument('--msa', dest='msa',required=True)
parser.add_argument('--pdb', dest='pdb',required=True, nargs="+")
parser.add_argument('--name', dest='name')
Expand All @@ -306,9 +309,9 @@ def pdb_to_lddt(pdb_files):
parser.set_defaults(name='')
args = parser.parse_args()

lddt_files, lddt_averages = pdb_to_lddt(args.pdb)
lddt_data, lddt_averages = pdb_to_lddt(args.pdb, args.generate_tsv)

generate_output_images(args.msa, lddt_files, args.name, args.output_dir, args.in_type)
generate_output_images(args.msa, lddt_data, args.name, args.output_dir, args.in_type, args.generate_tsv)
#generate_plots(args.msa, args.plddt, args.name, args.output_dir)

print("generating html report...")
Expand Down

0 comments on commit 56cc23b

Please sign in to comment.