From 94167c20e70563ace164d36e12b09e9eca6e7c70 Mon Sep 17 00:00:00 2001 From: Niklas Piet Doering Date: Wed, 5 Jun 2024 16:12:10 +0200 Subject: [PATCH] add flag for figure type selection default png --- .../openmmdl_analysis/barcode_generation.py | 10 +++---- .../markov_state_figure_generation.py | 6 ++--- .../openmmdl_analysis/openmmdlanalysis.py | 27 +++++++++++++------ .../openmmdl_analysis/rmsd_calculation.py | 8 +++--- 4 files changed, 31 insertions(+), 20 deletions(-) diff --git a/openmmdl/openmmdl_analysis/barcode_generation.py b/openmmdl/openmmdl_analysis/barcode_generation.py index 2cb8dab0..6ad434a7 100644 --- a/openmmdl/openmmdl_analysis/barcode_generation.py +++ b/openmmdl/openmmdl_analysis/barcode_generation.py @@ -105,7 +105,7 @@ def plot_barcodes(barcodes, save_path): plt.savefig(f"./Barcodes/{save_path}", dpi=300, bbox_inches="tight") -def plot_waterbridge_piechart(df_all, waterbridge_barcodes, waterbridge_interactions): +def plot_waterbridge_piechart(df_all, waterbridge_barcodes, waterbridge_interactions, fig_type): """Generates piecharts for each waterbridge interaction with the water ids of the interacting waters. Args: @@ -179,13 +179,13 @@ def plot_waterbridge_piechart(df_all, waterbridge_barcodes, waterbridge_interact # Adjust the position of the subplots within the figure plt.subplots_adjust(top=0.99, bottom=0.01) # You can change the value as needed plt.savefig( - f"Barcodes/Waterbridge_Piecharts/{waterbridge_interaction}.svg", + f"Barcodes/Waterbridge_Piecharts/{waterbridge_interaction}.{fig_type}", bbox_inches="tight", dpi=300, ) -def plot_barcodes_grouped(interactions, df_all, interaction_type): +def plot_barcodes_grouped(interactions, df_all, interaction_type, fig_type): """generates barcode figures and groups them by ligandatom, aswell as total interaction barcode for a giveen lingenatom. Args: @@ -225,7 +225,7 @@ def plot_barcodes_grouped(interactions, df_all, interaction_type): os.makedirs(f"./Barcodes/{ligatom}", exist_ok=True) plot_barcodes( ligatom_interaction_barcodes, - f"{ligatom}/{ligatom}_{interaction_type}_barcodes.svg", + f"{ligatom}/{ligatom}_{interaction_type}_barcodes.{fig_type}", ) barcodes_list = list(ligatom_interaction_barcodes.values()) @@ -234,4 +234,4 @@ def plot_barcodes_grouped(interactions, df_all, interaction_type): grouped_array = grouped_array.astype(int) total_interactions[ligatom] = grouped_array - plot_barcodes(total_interactions, f"{interaction_type}_interactions.svg") + plot_barcodes(total_interactions, f"{interaction_type}_interactions.{fig_type}") diff --git a/openmmdl/openmmdl_analysis/markov_state_figure_generation.py b/openmmdl/openmmdl_analysis/markov_state_figure_generation.py index a6141089..615d26ca 100644 --- a/openmmdl/openmmdl_analysis/markov_state_figure_generation.py +++ b/openmmdl/openmmdl_analysis/markov_state_figure_generation.py @@ -24,7 +24,7 @@ def min_transition_calculation(min_transition): def binding_site_markov_network( - total_frames, min_transitions, combined_dict, font_size=36, size_node=200 + total_frames, min_transitions, combined_dict, fig_type, font_size=36, size_node=200 ): """Generate Markov Chain plots based on transition probabilities. @@ -317,8 +317,8 @@ def binding_site_markov_network( plt.axis("off") plt.tight_layout() - # Save the plot as a SVG file - plot_filename = f"markov_chain_plot_{min_transition_percent}.svg" + # Save the plot + plot_filename = f"markov_chain_plot_{min_transition_percent}.{fig_type}" plot_path = os.path.join("Binding_Modes_Markov_States", plot_filename) os.makedirs( "Binding_Modes_Markov_States", exist_ok=True diff --git a/openmmdl/openmmdl_analysis/openmmdlanalysis.py b/openmmdl/openmmdl_analysis/openmmdlanalysis.py index 0b5c7504..d8d83a95 100644 --- a/openmmdl/openmmdl_analysis/openmmdlanalysis.py +++ b/openmmdl/openmmdl_analysis/openmmdlanalysis.py @@ -205,6 +205,13 @@ def main(): help="Set the Eps for clustering, this defines how big clusters can be spatially in Angstrom", default=1.0, ) + + parser.add_argument( + "--figure", + dest="figure_type", + help="File type for the figures, default is png. Can be changed to all file types supported by matplotlib.", + default="png", + ) pdb_md = None input_formats = [ @@ -280,6 +287,7 @@ def main(): special_ligand = args.special_ligand reference = args.reference peptide = args.peptide + fig_type = args.figure_type generate_representative_frame = args.representative_frame @@ -364,33 +372,36 @@ def main(): rmsd_for_atomgroups( f"{topology}", f"{trajectory}", + fig_type, selection1="nucleicbackbone", selection2=["nucleic", f"resname {ligand}"], ) if frame_rmsd != "No": RMSD_dist_frames( - f"{topology}", f"{trajectory}", lig=f"{ligand}", nucleic=True + f"{topology}", f"{trajectory}", fig_type, lig=f"{ligand}", nucleic=True ) print("\033[1mRMSD calculated\033[0m") elif peptide != None: rmsd_for_atomgroups( f"{topology}", f"{trajectory}", + fig_type, selection1="backbone", selection2=["protein", f"chainID {peptide}"], ) if frame_rmsd != "No": - RMSD_dist_frames(f"{topology}", f"{trajectory}", lig=f"chainID {peptide}") + RMSD_dist_frames(f"{topology}", f"{trajectory}", fig_type, lig=f"chainID {peptide}") print("\033[1mRMSD calculated\033[0m") else: rmsd_for_atomgroups( f"{topology}", f"{trajectory}", + fig_type, selection1="backbone", selection2=["protein", f"resname {ligand}"], ) if frame_rmsd != "No": - RMSD_dist_frames(f"{topology}", f"{trajectory}", lig=f"{ligand}") + RMSD_dist_frames(f"{topology}", f"{trajectory}", fig_type, lig=f"{ligand}") print("\033[1mRMSD calculated\033[0m") if receptor_nucleic: @@ -538,7 +549,7 @@ def main(): # Generate Markov state figures of the binding modes total_frames = len(pdb_md.trajectory) - 1 min_transitions = min_transition_calculation(min_transition) - binding_site_markov_network(total_frames, min_transitions, combined_dict) + binding_site_markov_network(total_frames, min_transitions, combined_dict, fig_type) print("\033[1mMarkov State Figure generated\033[0m") # Get the top 10 nodes with the most occurrences @@ -692,7 +703,7 @@ def main(): # Convert the svg to an png cairosvg.svg2png( - url=f"{binding_mode}.svg", write_to=f"{binding_mode}.png" + url=f"{binding_mode}.{fig_type}", write_to=f"{binding_mode}.png" ) # Generate the interactions legend and combine it with the ligand png @@ -705,7 +716,7 @@ def main(): merged_image_paths, "all_binding_modes_arranged.png" ) generate_ligand_image( - ligand, "complex.pdb", "lig_no_h.pdb", "lig.smi", "ligand_numbering.svg" + ligand, "complex.pdb", "lig_no_h.pdb", "lig.smi", f"ligand_numbering.{fig_type}" ) print("\033[1mBinding mode figure generated\033[0m") except Exception as e: @@ -855,9 +866,9 @@ def main(): } for interaction_type, interaction_data in interaction_types.items(): - plot_barcodes_grouped(interaction_data, df_all, interaction_type) + plot_barcodes_grouped(interaction_data, df_all, interaction_type, fig_type) - plot_waterbridge_piechart(df_all, waterbridge_barcodes, waterbridge_interactions) + plot_waterbridge_piechart(df_all, waterbridge_barcodes, waterbridge_interactions, fig_type) print("\033[1mBarcodes generated\033[0m") interacting_water_id_list = interacting_water_ids(df_all, waterbridge_interactions) diff --git a/openmmdl/openmmdl_analysis/rmsd_calculation.py b/openmmdl/openmmdl_analysis/rmsd_calculation.py index 846be729..51ee76d6 100644 --- a/openmmdl/openmmdl_analysis/rmsd_calculation.py +++ b/openmmdl/openmmdl_analysis/rmsd_calculation.py @@ -10,7 +10,7 @@ def rmsd_for_atomgroups( - prot_lig_top_file, prot_lig_traj_file, selection1, selection2=None + prot_lig_top_file, prot_lig_traj_file, fig_type, selection1, selection2=None ): """Calulate the RMSD for selected atom groups, and save the csv file and plot. @@ -44,12 +44,12 @@ def rmsd_for_atomgroups( # Plot and save the RMSD over time as a PNG file rmsd_df.plot(title="RMSD of protein and ligand") plt.ylabel("RMSD (Å)") - plt.savefig("./RMSD/RMSD_over_time.svg") + plt.savefig(f"./RMSD/RMSD_over_time.{fig_type}") return rmsd_df -def RMSD_dist_frames(prot_lig_top_file, prot_lig_traj_file, lig, nucleic=False): +def RMSD_dist_frames(prot_lig_top_file, prot_lig_traj_file, fig_type, lig, nucleic=False): """Calculate the RMSD between all frames in a matrix. Args: @@ -96,5 +96,5 @@ def RMSD_dist_frames(prot_lig_top_file, prot_lig_traj_file, lig, nucleic=False): fig.colorbar(img1, ax=ax, orientation="horizontal", fraction=0.1, label="RMSD (Å)") - plt.savefig("./RMSD/RMSD_between_the_frames.svg") + plt.savefig(f"./RMSD/RMSD_between_the_frames.{fig_type}") return pairwise_rmsd_prot, pairwise_rmsd_lig