Skip to content

Commit

Permalink
update plot
Browse files Browse the repository at this point in the history
  • Loading branch information
anuprulez committed Aug 13, 2024
1 parent d2cc07b commit a7d6254
Showing 1 changed file with 100 additions and 3 deletions.
103 changes: 100 additions & 3 deletions scripts/transformer_paper_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,8 @@ def plot_model_load_times_CPU_GPU():
#sns.barplot(data=gpu_load_times, x="l_cnn", y="cnn_load_time", label="", color="blue", errorbar="sd", capsize=.2)
#sns.barplot(data=gpu_load_times, x="l_dnn", y="dnn_load_time", label="", color="black", errorbar="sd", capsize=.2)

sns.barplot(data=cpu_load_times, x="l_tran", y="tran_load_time", label="", linestyle="-", color="green")
sns.barplot(data=cpu_load_times, x="l_rnn", y="rnn_load_time", label="", linestyle="-", color="red")
sns.barplot(data=cpu_load_times, x="l_tran", y="tran_load_time", label="", linestyle="-", color="green", edgecolor = 'black', linewidth = 2)
sns.barplot(data=cpu_load_times, x="l_rnn", y="rnn_load_time", label="", linestyle="-", color="red", edgecolor = 'black', linewidth = 2)
#sns.barplot(data=cpu_load_times, x="l_cnn", y="cnn_load_time", label="", linestyle="-", color="blue")
#sns.barplot(data=cpu_load_times, x="l_dnn", y="dnn_load_time", label="", linestyle="-", color="black")

Expand Down Expand Up @@ -702,6 +702,100 @@ def create_test_precision_plot():
plt.savefig("../plots/rnn_cnn_dnn_runs_te_prec_defense.png", dpi=dpi, bbox_inches='tight')
#plt.show()

def make_grouped_stacked_beyond_training():
font = {'family': 'serif', 'size': 18}
fig_size = (12, 12)
fig = plt.figure(figsize=fig_size)
plt.rc('font', **font)
dpi=300
data = pd.DataFrame({
'Group': ['Deep learning', 'Deep learning', 'Single-cell', \
'Single-cell', 'Variant calling', 'Variant calling', 'Proteomics', 'Proteomics'],
#'Group': ['Keras_train_and_eval', 'Keras_train_and_eval', 'Anndata_import', \
# 'Anndata_import', 'Snpeff_sars_cov_2', 'Snpeff_sars_cov_2', 'Proteomics*', 'Proteomics*'],
'Model types': ['Transformer', 'RNN', 'Transformer', 'RNN', 'Transformer', 'RNN', 'Transformer', 'RNN'],
'Value': [3, 1, 10, 5, 5, 3, 12, 0]
})

# single cell: (anndata_import) [10, 5]
# deep learning: "keras_train_and_eval" [3, 1]
# variant calling: "snpeff_sars_cov_2" [5, 3]
# proteomics: massspectrometryimagingfiltering, cardinalpreprocessing, cardinalsegmentations [12, 0]
# Pivot the data to have categories as columns
pivot_data = data.pivot(index='Group', columns='Model types', values='Value')
palette = {'Transformer': "green", 'RNN': "red"}
# Create the stacked bar plot
pivot_data.plot(kind='barh', color=palette, edgecolor = 'black', linewidth = 2) #, stacked=True

plt.grid(True)
plt.xlabel("Number of tools")
plt.ylabel("Types of analyses")
plt.title("Generalisation")
#plt.yticks([0, 2, 4, 6, 8, 10, 12, 14])
#plt.ylim((0, 40))
#plt.xticks(rotation=45)
#plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
plt.tight_layout()
plt.savefig("../plots/transformer_rnn_stacked_beyond_workflows.pdf", dpi=dpi, bbox_inches='tight')
plt.savefig("../plots/transformer_rnn_stacked_beyond_workflows.png", dpi=dpi, bbox_inches='tight')
plt.show()


def make_age_prediction_scatter_plot():
#font = {'family': 'serif', 'size': 18}
#fig_size = (12, 12)
#fig = plt.figure(figsize=fig_size)
#plt.rc('font', **font)
fig = plt.figure()
dpi=300

true = pd.read_csv("../plots/test_rows_labels.tabular", sep="\t")["Age"]

predicted = pd.read_csv("../plots/predicted_age.tabular", sep="\t")["predicted"]

df_age = pd.DataFrame(zip(true, predicted), columns=["true", "predicted"])

sns.scatterplot(data=df_age, x="true", y="predicted", color="red")

# Get the range of values for the x and y axis
min_val = min(df_age["true"].min(), df_age["predicted"].min())
max_val = max(df_age["true"].max(), df_age["predicted"].max())

# Plot the y=x line
plt.plot([min_val, max_val], [min_val, max_val], color='black', linestyle='--')

plt.grid(True)
plt.xlabel("True age")
plt.ylabel("Predicted age")
plt.title("True vs predicted age (RMSE: 3.76, R2: 0.94)")
#plt.yticks([0, 2, 4, 6, 8, 10, 12, 14])
#plt.ylim((0, 40))
#plt.xticks(rotation=45)
#plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
plt.tight_layout()
plt.savefig("../plots/true_vs_predicted_age.pdf", dpi=dpi, bbox_inches='tight')
plt.savefig("../plots/true_vs_predicted_age.png", dpi=dpi, bbox_inches='tight')
plt.show()


def plot_model_training_times_CPU():
print("")

model_names = ["Transformer", "RNN", "CNN", "DNN"]
model_times = [44, 246, 43, 14]
df_model_train_time = pd.DataFrame(zip(model_names, model_times), columns=["model_names", "model_times"])

palette = {'Transformer': "green", 'RNN': "red", "CNN": "blue", "DNN": "black"}
sns.barplot(data=df_model_train_time, x="model_names", y="model_times", palette=palette, edgecolor = 'black', linewidth = 2)


plt.grid(True)
plt.xlabel("Model types")
plt.ylabel("Time (seconds)")
plt.title("Models vs their training time")
plt.savefig("../plots/transformer_rnn_runs_model_train_time_CPU.pdf", dpi=dpi, bbox_inches='tight')
plt.savefig("../plots/transformer_rnn_runs_model_train_time_CPU.png", dpi=dpi, bbox_inches='tight')
plt.show()

############ Call methods ###########################

Expand All @@ -711,6 +805,9 @@ def create_test_precision_plot():
#plot_model_load_times_CPU_GPU()
#plot_usage_time_vs_topk()
#plot_usage_time_vs_seq_len()
make_bar_beyond_training()
#make_bar_beyond_training()
make_grouped_stacked_beyond_training()
#make_age_prediction_scatter_plot()
#plot_model_training_times_CPU()
#create_test_precision_plot()
#create_ground_truth_beyond_workflows_recommendations_plot()

0 comments on commit a7d6254

Please sign in to comment.