Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 12, 2024
1 parent dfcb3ae commit 8bc152d
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 195 deletions.
68 changes: 40 additions & 28 deletions missense_kinase_toolkit/ml/src/esm2/analysis.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,37 @@
#!/usr/bin/env python

import os
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# os.chdir("/data1/tanseyw/projects/whitej/esm_km_atp/src")
from utils import (
save_csv2csv,
load_csv2dataset,
parse_stats_dataframes,
invert_zscore,
)
from utils import invert_zscore, load_csv2dataset, parse_stats_dataframes, save_csv2csv

# Root directory for every input and output of this script (hard-coded absolute path).
path = "/data1/tanseyw/projects/whitej/esm_km_atp/"
# Fold CSVs located exactly one directory level below `path`.
list_files = glob.glob(os.path.join(path, "*/fold-*.csv"))
# Source table; its "ATP Conc.(uM)" column provides the raw label values.
df = pd.read_csv(os.path.join(path, "assets/pkis2_km_atp.csv"))
# log10 transform before z-scoring
labels = df['ATP Conc.(uM)'].apply(np.log10)
labels = df["ATP Conc.(uM)"].apply(np.log10)

# Run directory names under `path`; the order here is reused later when
# zipping against display labels.
list_runs = [
"5CV-KinCore-esm2_t6_8M_UR50D",
"5CV-KLIFS_MIN-esm2_t6_8M_UR50D",
"5CV-KLIFS_FULL-esm2_t6_8M_UR50D",
]

# For each run, collect the per-fold log CSVs matching <run>/*/logs/fold-*.csv.
# NOTE(review): glob.glob returns matches in arbitrary (filesystem) order, and
# the loading loop below passes the enumerate() index of each file to
# parse_stats_dataframes. If that index is used as the fold id, wrap the glob
# in sorted() so the file-to-fold assignment is deterministic -- confirm.
list_logs = [glob.glob(os.path.join(path, run, "*/logs/fold-*.csv")) for run in list_runs]
list_logs = [
glob.glob(os.path.join(path, run, "*/logs/fold-*.csv")) for run in list_runs
]

# Map run name -> list of that run's log files.
dict_logs = dict(zip(list_runs, list_logs))

### Load and process training and evaluation loss ###

df_train, df_eval, df_final = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
df_train, df_eval, df_final = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
for exp, list_file in dict_logs.items():
for idx, file in enumerate(list_file):
df_train_temp, df_eval_temp, df_final_temp = parse_stats_dataframes(file, idx)
Expand All @@ -45,7 +44,7 @@

### Load validation and training datasets ###

#TODO: Look at all files
# TODO: Look at all files
ds_val, ds_train = load_csv2dataset(path, 5, "KLIFS_FULL_data.csv")
df_val = pd.DataFrame()
for idx, ds in enumerate(ds_val):
Expand All @@ -58,11 +57,13 @@
# Plot training loss per step: one facet column per fold, one row per experiment.
list_fold = df_train["fold"].unique().tolist()
# Facet titles "Fold: <i>\n(n = <count>)", where <count> is the number of
# df_val rows whose fold is NOT i (presumably the training-set size for that
# fold -- confirm).
# NOTE: the replacement field reuses double quotes inside a double-quoted
# f-string; this only parses on Python 3.12+ (PEP 701). On older interpreters
# it is a SyntaxError.
list_replace = [f"Fold: {i}\n(n = {sum(df_val["fold"] != i)})" for i in list_fold]
df_train["fold_label"] = df_train["fold"].map(dict(zip(list_fold, list_replace)))
g = sns.FacetGrid(df_train, col="fold_label", row="exp", hue="exp", sharey=True, sharex=True)
g = sns.FacetGrid(
df_train, col="fold_label", row="exp", hue="exp", sharey=True, sharex=True
)
g.map(sns.lineplot, "step", "loss")
g.add_legend()
g.set_axis_labels("Steps", "Training Loss")
# Show only the column (fold) label as each facet title.
g.set_titles('{col_name}')
g.set_titles("{col_name}")
plt.savefig(os.path.join(path, "images/train_loss_2024.10.30.png"))

# list_fold = df_train["fold"].unique().tolist()
Expand All @@ -80,7 +81,7 @@
# Facet titles for the evaluation plots: here n counts the df_val rows that
# belong to fold i.
# NOTE: nested double quotes inside the f-string require Python 3.12+ (PEP 701).
list_replace = [f"Fold: {i}\n(n = {sum(df_val["fold"] == i)})" for i in list_fold]
df_eval["fold_label"] = df_eval["fold"].map(dict(zip(list_fold, list_replace)))
# NOTE(review): invert_zscore lives in utils and is not visible here. If it
# computes x * std + mean, applying it to an RMSE (an error magnitude, not a
# label) also adds the label mean, which would overstate the error -- an RMSE
# normally only needs rescaling by std. Confirm against utils.invert_zscore.
df_eval["log_rmse"] = invert_zscore(df_eval["eval_rmse"], labels)
# NOTE(review): 10 ** (error in log10 units) is a multiplicative fold-error,
# not an additive RMSE in uM -- confirm this is the intended quantity.
df_eval["orig_rmse"] = df_eval["log_rmse"].apply(lambda x: 10 ** x)
df_eval["orig_rmse"] = df_eval["log_rmse"].apply(lambda x: 10**x)

# list_fold = df_eval["fold"].unique().tolist()
# list_replace = [f"Fold: {i}\n(n = {sum(df_val["fold"] == i)})" for i in list_fold]
Expand All @@ -91,14 +92,16 @@
# Leave RMSE in units of z-score log10(Km, ATP)
# NOTE(review): the column plotted below is "log_rmse" (the output of
# invert_zscore), not the raw z-score "eval_rmse" -- the heading above and the
# "unconverted" file name may be stale; confirm which column is intended.

sns.set(font_scale=1.5)
# Human-readable legend labels; pairing relies on the order of list_runs
# (KinCore, KLIFS_MIN, KLIFS_FULL).
df_eval["exp_label"] = df_eval["exp"].map(dict(zip(list_runs, ["KinCore KD", "KLIFS Pocket", "KLIFS Full Region"])))
df_eval["exp_label"] = df_eval["exp"].map(
dict(zip(list_runs, ["KinCore KD", "KLIFS Pocket", "KLIFS Full Region"]))
)
# One facet per fold, one coloured line per input-sequence type.
g = sns.FacetGrid(df_eval, col="fold_label", hue="exp_label", sharey=True, sharex=True)
# for ax in g.axes.flat:
# ax.axvline(500, color='r', linestyle='dashed', linewidth=1)
# NOTE(review): confirm that FacetGrid exposes a `grid` method in the installed
# seaborn version; if not, this raises AttributeError and the per-axes
# equivalent is `for ax in g.axes.flat: ax.grid(False)`.
g.grid(False)
g.map(sns.lineplot, "step", "log_rmse")
g.set_axis_labels("Steps", "Held-Out RMSE\n" + r"$(log_{10} K_{M, ATP})$")
g.set_titles('{col_name}')
g.set_titles("{col_name}")
g.add_legend(title="Input sequence")
# g.figsize(8, 6)
plt.savefig(os.path.join(path, "images/eval_rmse_unconverted_2024.10.30.png"))
Expand All @@ -115,10 +118,10 @@

# Plot the back-transformed RMSE ("orig_rmse") per fold; axes are independent
# per facet (sharex/sharey disabled).
g = sns.FacetGrid(df_eval, col="fold_label", hue="fold", sharey=False, sharex=False)
for ax in g.axes.flat:
# Dashed red marker at step 500 on every facet (hard-coded step value).
ax.axvline(500, color='r', linestyle='dashed', linewidth=1)
ax.axvline(500, color="r", linestyle="dashed", linewidth=1)
g.map(sns.lineplot, "step", "orig_rmse")
g.set_axis_labels("Steps", "RMSE, Eval (Converted)")
g.set_titles('{col_name}')
g.set_titles("{col_name}")
plt.savefig(os.path.join(path, "images/eval_rmse_converted.png"))

### Plot histogram of labels for validation set ###
Expand All @@ -127,7 +130,7 @@
# Facet titles for the validation histograms: n is the number of df_val rows in fold i.
# NOTE: nested double quotes inside the f-string require Python 3.12+ (PEP 701).
list_replace = [f"Fold: {i}\n(n = {sum(df_val["fold"] == i)})" for i in list_fold]
df_val["fold_label"] = df_val["fold"].map(dict(zip(list_fold, list_replace)))
# Two-step back-transform of the stored labels: first undo the z-scoring
# (result is in log10 units), then undo the log10. "orig_label" is overwritten
# in place by the second step.
df_val["orig_label"] = invert_zscore(df_val["label"], labels)
df_val["orig_label"] = df_val["orig_label"].apply(lambda x: 10 ** x)
df_val["orig_label"] = df_val["orig_label"].apply(lambda x: 10**x)

# Leave labels in units of z-score log10(Km, ATP)

Expand All @@ -137,9 +140,14 @@
# Annotate each fold's facet with the mean of its z-scored labels.
# plt.hist returns (counts, bin_edges, patches); only the pooled counts `y` and
# edges `x` are used, to place the text at a common offset in every facet.
# NOTE(review): plt.hist also DRAWS the pooled histogram on the current axes as
# a side effect; np.histogram(df_val["label"]) returns the same counts/edges
# without drawing -- confirm the extra bars are not wanted.
y, x, _ = plt.hist(df_val["label"])
for idx, ax in enumerate(g.axes.flat):
# Assumes fold ids are 1..k and that facet order matches fold order -- TODO confirm.
loc = df_val.loc[df_val["fold"] == idx + 1, "label"].mean()
ax.axvline(loc, color='r', linestyle='dashed', linewidth=1)
ax.text(loc + (x.max() - x.min()) * 0.1, y.max() * 0.9, "Mean: " + str(round(loc, 2)), color='r')
g.set_titles('{col_name}')
ax.axvline(loc, color="r", linestyle="dashed", linewidth=1)
# Text sits 10% of the pooled label range to the right of the mean line, at
# 90% of the tallest pooled bin.
ax.text(
loc + (x.max() - x.min()) * 0.1,
y.max() * 0.9,
"Mean: " + str(round(loc, 2)),
color="r",
)
g.set_titles("{col_name}")
plt.savefig(os.path.join(path, "images/val_label_hist_zscore.png"), bbox_inches="tight")

# Convert labels to original scale
Expand All @@ -150,16 +158,20 @@
# Annotate each fold's facet with the mean of its labels on the ORIGINAL scale.
# plt.hist returns (counts, bin_edges, patches); the pooled counts `y` and edges
# `x` are only used to position the annotation text consistently across facets.
y, x, _ = plt.hist(df_val["orig_label"])
for idx, ax in enumerate(g.axes.flat):
    # Fix: take the mean from "orig_label", the same column the histogram and
    # the text offsets are built from. The previous code averaged "label"
    # (z-score units), so the dashed line and the "Mean:" text were drawn at a
    # z-score position on an axis that is in original units.
    # Assumes fold ids are 1..k in facet order -- TODO confirm.
    loc = df_val.loc[df_val["fold"] == idx + 1, "orig_label"].mean()
    ax.axvline(loc, color="r", linestyle="dashed", linewidth=1)
    # Text sits 10% of the pooled label range to the right of the mean line,
    # at 90% of the tallest pooled bin.
    ax.text(
        loc + (x.max() - x.min()) * 0.1,
        y.max() * 0.9,
        "Mean: " + str(round(loc, 2)),
        color="r",
    )
g.set_titles("{col_name}")
plt.savefig(os.path.join(path, "images/val_label_hist_orig.png"), bbox_inches="tight")



# import numpy as np
# from utils import calc_zscore
# df = pd.read_csv(os.path.join(path, "assets/pkis2_km_atp.csv"))
# calc_zscore(df["ATP Conc.(uM)"].apply(np.log10))
# df.head()
# df["kd"].apply(len).max()
# df["kd"].apply(len).max()
2 changes: 1 addition & 1 deletion missense_kinase_toolkit/ml/src/esm2/batch_jobs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# Submit one SLURM job per row of batch_jobs.csv (columns: model,col_seq,run_name).
# The `|| [ -n "${model}" ]` guard keeps the loop body running for a final row
# that lacks a trailing newline; `read` returns non-zero on such a row even
# though it has filled the variables, so the row would otherwise be skipped.
while IFS=, read -r model col_seq run_name || [ -n "${model}" ]
do
    # Quote every expansion so values containing spaces or glob characters are
    # passed to sbatch/run.sh as single, literal arguments.
    sbatch -J "${run_name}" run.sh "${model}" "${col_seq}" "${run_name}"
done < batch_jobs.csv
45 changes: 30 additions & 15 deletions missense_kinase_toolkit/ml/src/esm2/inference.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from transformers import (
AutoTokenizer,
EsmForSequenceClassification
)
from transformers import AutoTokenizer, EsmForSequenceClassification

path_model = "/data1/tanseyw/projects/whitej/esm_km_atp/5CV-KLIFS_MIN-esm2_t6_8M_UR50D/full/results/checkpoint-12500"

Expand All @@ -15,7 +12,9 @@
# Load the fine-tuned regression/classification head from the local checkpoint
# at `path_model` and move it to `device` (defined above this hunk).
model = EsmForSequenceClassification.from_pretrained(path_model).to(device)
# The tokenizer is loaded by hub name rather than from the checkpoint directory;
# it must match the base model the checkpoint was fine-tuned from
# (the checkpoint path contains "esm2_t6_8M_UR50D").
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Table of sequences to score; later code reads its "hgnc_name",
# "alphamissense_score" and "alphamissense_class" columns.
df_klifs_zehir_muts_alphamissense = pd.read_csv("/data1/tanseyw/projects/whitej/esm_km_atp/assets/klifs_zehir_muts_alphamissense.csv")
df_klifs_zehir_muts_alphamissense = pd.read_csv(
"/data1/tanseyw/projects/whitej/esm_km_atp/assets/klifs_zehir_muts_alphamissense.csv"
)

list_outputs = []
for _, row in df_klifs_zehir_muts_alphamissense.iterrows():
Expand All @@ -25,26 +24,42 @@

# Map each row's "hgnc_name" to its model output. Duplicate names would
# silently keep only the last output.
dict_outputs = dict(zip(df_klifs_zehir_muts_alphamissense["hgnc_name"], list_outputs))

# Keys containing "_" are treated as mutants ("<WT name>_<suffix>"); everything
# else is treated as wild-type.
dict_muts = {i : None for i in dict_outputs.keys() if "_" in i}
dict_muts = {i: None for i in dict_outputs.keys() if "_" in i}
# NOTE(review): `value` is never used; the body re-indexes dict_outputs by key.
for key, value in dict_outputs.items():
if "_" in key:
# Wild-type name is the text before the first underscore; raises KeyError
# if that wild-type entry is absent from dict_outputs.
wt = key.split("_")[0]
# Relative change of the mutant output versus its wild-type output.
# NOTE(review): the outputs appear to be z-scores (see the column name below),
# which can be zero or negative, so this ratio can blow up or flip sign --
# confirm this is the intended effect-size measure.
dict_muts[key] = (dict_outputs[key] - dict_outputs[wt]) / dict_outputs[wt]

# Percent change per row; wild-type rows (not in dict_muts) get None.
df_klifs_zehir_muts_alphamissense["zscore_percent_change"] = df_klifs_zehir_muts_alphamissense["hgnc_name"].apply(lambda x: dict_muts[x] * 100 if x in dict_muts.keys() else None)
# Signed log10 of the absolute percent change. Relies on numpy being imported
# as `np` at module top. A percent change of exactly 0 gives sign(0) * -inf = nan.
df_klifs_zehir_muts_alphamissense["zscore_percent_change_log"] = df_klifs_zehir_muts_alphamissense["zscore_percent_change"].apply(lambda x: np.sign(x) * np.log10(np.abs(x)))

df_klifs_zehir_muts_alphamissense["zscore_percent_change"] = (
df_klifs_zehir_muts_alphamissense["hgnc_name"].apply(
lambda x: dict_muts[x] * 100 if x in dict_muts.keys() else None
)
)
df_klifs_zehir_muts_alphamissense["zscore_percent_change_log"] = (
df_klifs_zehir_muts_alphamissense["zscore_percent_change"].apply(
lambda x: np.sign(x) * np.log10(np.abs(x))
)
)


# Scatter plot of the signed-log percent change against the AlphaMissense
# score, coloured by AlphaMissense class, saved as a PNG.
sns.set(font_scale=2)
sns.set_style(style="white")
plt.figure(figsize=(20, 7))
# ax = sns.scatterplot(data = df_klifs_zehir_muts_alphamissense, x = "alphamissense_score", y = "zscore_percent_change", hue = "alphamissense_class")
ax = sns.scatterplot(data = df_klifs_zehir_muts_alphamissense, x = "alphamissense_score", y = "zscore_percent_change_log", hue = "alphamissense_class")
ax = sns.scatterplot(
data=df_klifs_zehir_muts_alphamissense,
x="alphamissense_score",
y="zscore_percent_change_log",
hue="alphamissense_class",
)
# plt.axhline(y=0, color='red', linestyle='--')
# plt.yscale('log')
plt.legend(title = "Alphamissense Class")
plt.legend(title="Alphamissense Class")
plt.xlabel("Alphamissense Score")
# plt.ylabel(" Predicted Z-score\n% Change vs. Wild-Type")
# The plotted value is sign(x) * log10(|x|) of the percent change, i.e. a
# signed log of the magnitude rather than a plain log10.
plt.ylabel(r"$log_{10}$" + " Predicted Z-score\n% Change vs. Wild-Type")
plt.savefig("/data1/tanseyw/projects/whitej/esm_km_atp/images/zscore_percent_change_vs_alphamissense_score_log.png", bbox_inches = "tight")
# plt.savefig("/data1/tanseyw/projects/whitej/esm_km_atp/images/zscore_percent_change_vs_alphamissense_score.png", bbox_inches = "tight")
plt.savefig(
"/data1/tanseyw/projects/whitej/esm_km_atp/images/zscore_percent_change_vs_alphamissense_score_log.png",
bbox_inches="tight",
)
# plt.savefig("/data1/tanseyw/projects/whitej/esm_km_atp/images/zscore_percent_change_vs_alphamissense_score.png", bbox_inches = "tight")
Loading

0 comments on commit 8bc152d

Please sign in to comment.