Skip to content

Commit

Permalink
update plot
Browse files Browse the repository at this point in the history
  • Loading branch information
anuprulez committed Aug 12, 2024
1 parent 9508d74 commit d2cc07b
Showing 1 changed file with 49 additions and 21 deletions.
70 changes: 49 additions & 21 deletions scripts/transformer_paper_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,13 +462,13 @@ def plot_model_load_times_CPU_GPU():
cpu_load_times = pd.read_csv("../plots/transformer_rnn_runs_model_load_time_final_model_CPU.csv")
#cpu_load_times["compute_type"] = ["CPU", "CPU", "CPU", "CPU", "CPU"]

sns.barplot(data=gpu_load_times, x="l_tran", y="tran_load_time", label="", color="green", errorbar="sd", capsize=.2)
sns.barplot(data=gpu_load_times, x="l_rnn", y="rnn_load_time", label="", color="red", errorbar="sd", capsize=.2)
sns.barplot(data=gpu_load_times, x="l_cnn", y="cnn_load_time", label="", color="blue", errorbar="sd", capsize=.2)
sns.barplot(data=gpu_load_times, x="l_dnn", y="dnn_load_time", label="", color="black", errorbar="sd", capsize=.2)
#sns.barplot(data=gpu_load_times, x="l_tran", y="tran_load_time", label="", color="green", errorbar="sd", capsize=.2)
#sns.barplot(data=gpu_load_times, x="l_rnn", y="rnn_load_time", label="", color="red", errorbar="sd", capsize=.2)
#sns.barplot(data=gpu_load_times, x="l_cnn", y="cnn_load_time", label="", color="blue", errorbar="sd", capsize=.2)
#sns.barplot(data=gpu_load_times, x="l_dnn", y="dnn_load_time", label="", color="black", errorbar="sd", capsize=.2)

#sns.barplot(data=cpu_load_times, x="l_tran", y="tran_load_time", label="", linestyle="-", color="green")
#sns.barplot(data=cpu_load_times, x="l_rnn", y="rnn_load_time", label="", linestyle="-", color="red")
sns.barplot(data=cpu_load_times, x="l_tran", y="tran_load_time", label="", linestyle="-", color="green")
sns.barplot(data=cpu_load_times, x="l_rnn", y="rnn_load_time", label="", linestyle="-", color="red")
#sns.barplot(data=cpu_load_times, x="l_cnn", y="cnn_load_time", label="", linestyle="-", color="blue")
#sns.barplot(data=cpu_load_times, x="l_dnn", y="dnn_load_time", label="", linestyle="-", color="black")

Expand Down Expand Up @@ -634,30 +634,35 @@ def plot_usage_time_vs_seq_len():
plt.savefig("../plots/transformer_rnn_runs_model_pred_time_seq_length.png", dpi=dpi)


def make_scatter_beyond_training():
def make_bar_beyond_training():

font = {'family': 'serif', 'size': 18}
fig_size = (12, 6)
#font = {'family': 'serif', 'size': 18}
#fig_size = (6, 6)
#fig = plt.figure(figsize=fig_size)
plt.rc('font', **font)
#plt.rc('font', **font)

dpi = 300
analysis = "Single-cell"
input_tool = "anndata_import"
ground_truth = ["scanpy_filter", "anndata_inspect", "anndata_manipulate", "ucsc_cell_browser", "scanpy_inspect", "scanpy_filter_cells"]
#dpi = 300

#ground_truth = ["scanpy_filter", "anndata_inspect", "anndata_manipulate", "ucsc_cell_browser", "scanpy_inspect", "scanpy_filter_cells"]

pred_transformer_gt = ground_truth
pred_rnn_gt = ground_truth
#pred_transformer_gt = ground_truth
#pred_rnn_gt = ground_truth

pred_transformer_b_training = ["scanpy_normalise_data", "scanpy_plot", "anndata_ops", "scanpy_remove_confounders", "scanpy_integrate_harmony", "scanpy_normalize", "scpred_get_feature_space", "scanpy_find_variable_genes", "scpred_predict_labels", "scpred_eigen_decompose"]
#pred_transformer_b_training = ["scanpy_normalise_data", "scanpy_plot", "anndata_ops", "scanpy_remove_confounders", "scanpy_integrate_harmony", "scanpy_normalize", "scpred_get_feature_space", "scanpy_find_variable_genes", "scpred_predict_labels", "scpred_eigen_decompose"]

pred_rnn_b_training = ["scanpy_plot", "scanpy_normalise_data", "scmap_scmap_cluster", "scmap_scmap_cell", "scanpy_filter_genes"]
#pred_rnn_b_training = ["scanpy_plot", "scanpy_normalise_data", "scmap_scmap_cluster", "scmap_scmap_cell", "scanpy_filter_genes"]

xlabels = ["Transformer", "RNN"]

xtypes = ["Transformer", "RNN"]
#analysis = "Proteomics" #"Single-cell"
#input_tool = "Proteomics*" #"anndata_import"

matrix = [len(pred_transformer_b_training), len(pred_rnn_b_training)]
matrix = [3, 1]
# single cell: (anndata_import) [10, 5]
# deep learning: "keras_train_and_eval" [3, 1]
# variant calling: "snpeff_sars_cov_2" [5, 3]
# proteomics: massspectrometryimagingfiltering, cardinalpreprocessing, cardinalsegmentations [12, 0]

df_recommendations = pd.DataFrame(zip(xlabels, matrix, xtypes), columns=["xlabels", "recommendations", "model_types"])

Expand All @@ -671,12 +676,33 @@ def make_scatter_beyond_training():
ax.set_xticks(xlabels)
plt.xlabel("Model types")
plt.ylabel("Number of recommended tools")
plt.title("Anndata_import: Generalisation")
plt.title("Snpeff_sars_cov_2: Generalisation")
plt.yticks([0, 2, 4, 6, 8, 10, 12, 14])
plt.ylim((0, 15))
plt.tight_layout()
plt.savefig("../plots/transformer_rnn_beyond_workflows.pdf", dpi=dpi)
plt.savefig("../plots/transformer_rnn_beyond_workflows.png", dpi=dpi)


def create_test_precision_plot():
df_prec = pd.read_csv("../plots/df_tr_rnn_cnn_dnn_runs_te_prec.csv", sep="\t")
print(df_prec)

sns.lineplot(data=df_prec, x="indices", y="tran_prec", label="Transformer", color="green", linestyle="-")
sns.lineplot(data=df_prec, x="indices", y="rnn_prec", label="RNN", color="red", linestyle="-")
#sns.lineplot(data=df_prec, x="indices", y="cnn_prec", label="CNN", color="blue", linestyle="-")
#sns.lineplot(data=df_prec, x="indices", y="dnn_prec", label="DNN", color="black", linestyle="-")

plt.grid(True)
plt.xlabel("Training iterations")
plt.ylabel("Precision@k")
plt.title("Precision@k of models for test data")

plt.savefig("../plots/rnn_cnn_dnn_runs_te_prec_defense.pdf", dpi=dpi, bbox_inches='tight')
plt.savefig("../plots/rnn_cnn_dnn_runs_te_prec_defense.png", dpi=dpi, bbox_inches='tight')
#plt.show()


############ Call methods ###########################

#collect_loss_prec_data(["transformer", "rnn", "cnn", "dnn"])
Expand All @@ -685,4 +711,6 @@ def make_scatter_beyond_training():
#plot_model_load_times_CPU_GPU()
#plot_usage_time_vs_topk()
#plot_usage_time_vs_seq_len()
make_scatter_beyond_training()
make_bar_beyond_training()
#create_test_precision_plot()
#create_ground_truth_beyond_workflows_recommendations_plot()

0 comments on commit d2cc07b

Please sign in to comment.