Skip to content

Commit

Permalink
make bar plots when run_features is called
Browse files Browse the repository at this point in the history
  • Loading branch information
Ritvik Vasan committed Nov 27, 2024
1 parent 54f52f9 commit e0e1254
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 35 deletions.
25 changes: 18 additions & 7 deletions src/br/analysis/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,14 @@ def check_mig():
def get_mig_ids(gpu_uuid):
try:
# Get the list of GPUs
output = subprocess.check_output(['nvidia-smi','--query-gpu=,index,uuid' ,'--format=csv,noheader']).decode('utf-8').strip().split('\n')
output = (
subprocess.check_output(
["nvidia-smi", "--query-gpu=,index,uuid", "--format=csv,noheader"]
)
.decode("utf-8")
.strip()
.split("\n")
)

# Find the index of the specified GPU UUID
gpu_index = -1
Expand All @@ -71,7 +78,9 @@ def get_mig_ids(gpu_uuid):
# Now we need to get the MIG IDs for this GPU
mig_ids = []
# Run nvidia-smi command to get detailed information including MIG IDs
detailed_output = subprocess.check_output(['nvidia-smi', '-L']).decode('utf-8').strip().split('\n')
detailed_output = (
subprocess.check_output(["nvidia-smi", "-L"]).decode("utf-8").strip().split("\n")
)

# Flag to determine if we are in the right GPU section
in_gpu_section = False
Expand All @@ -82,11 +91,13 @@ def get_mig_ids(gpu_uuid):
break

# print(line)

if in_gpu_section:
# Check for MIG devices
if "MIG" in line:
mig_id = line.split('(')[1].split(')')[0].split(' ')[-1] # Assuming format is '.... MIG (UUID) ...'
mig_id = (
line.split("(")[1].split(")")[0].split(" ")[-1]
) # Assuming format is '.... MIG (UUID) ...'
mig_ids.append(mig_id.strip())

return mig_ids
Expand All @@ -105,14 +116,14 @@ def config_gpu():

for line in lines:
index, uuid, name, mem_used, mem_total = map(str.strip, line.split(","))
utilization = float(mem_used)*100/float(mem_total)
utilization = float(mem_used) * 100 / float(mem_total)

# Check if GPU utilization is under 20% (indicating it's idle)
if utilization < 20:
# print(uuid, utilization)
if is_mig:
mig_ids = get_mig_ids(uuid)

if mig_ids:
selected_gpu_id_or_uuid = mig_ids[0] # Select the first MIG ID
break # Exit the loop after finding the first MIG ID
Expand Down
25 changes: 12 additions & 13 deletions src/br/analysis/run_drugdata_analysis.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import argparse
import os
import sys
from pathlib import Path

from br.chandrasekaran_et_al.utils import _plot, perturbation_detection
from br.models.compute_features import get_embeddings
from br.models.utils import get_all_configs_per_dataset
from br.chandrasekaran_et_al.utils import perturbation_detection, _plot
import sys
import argparse


def _get_featurecols(df):
"""returna list of featuredata columns"""
"""returna list of featuredata columns."""
return [c for c in df.columns if "mu" in c]


def _get_featuredata(df):
"""return dataframe of just featuredata columns"""
"""return dataframe of just featuredata columns."""
return df[_get_featurecols(df)]


Expand All @@ -25,11 +26,9 @@ def main(args):
dataset_name = args.dataset_name
DATASET_INFO = get_all_configs_per_dataset(results_path)
dataset = DATASET_INFO[dataset_name]
run_names = dataset['names']
run_names = dataset["names"]

all_ret, df = get_embeddings(
run_names, args.dataset_name, DATASET_INFO, args.embeddings_path
)
all_ret, df = get_embeddings(run_names, args.dataset_name, DATASET_INFO, args.embeddings_path)
all_ret["well_position"] = "A0" # dummy
all_ret["Assay_Plate_Barcode"] = "Plate0" # dummy

Expand All @@ -41,10 +40,10 @@ def main(args):


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Script for computing perturbation detection metrics")
parser.add_argument(
"--save_path", type=str, required=True, help="Path to save the results."
parser = argparse.ArgumentParser(
description="Script for computing perturbation detection metrics"
)
parser.add_argument("--save_path", type=str, required=True, help="Path to save the results.")
parser.add_argument(
"--embeddings_path", type=str, required=True, help="Path to the saved embeddings."
)
Expand All @@ -63,4 +62,4 @@ def main(args):
cellpack dataset
python src/br/analysis/run_drugdata_analysis.py --save_path "./outputs_npm1_perturb/" --embeddings_path "./morphology_appropriate_representation_learning/model_embeddings/npm1_perturb/" --dataset_name "npm1_perturb"
"""
"""
3 changes: 2 additions & 1 deletion src/br/analysis/run_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import sys
from pathlib import Path

from br.analysis.analysis_utils import setup_evaluation_params, setup_gpu, str2bool
from br.models.load_models import get_data_and_models
from br.models.save_embeddings import save_embeddings
Expand Down Expand Up @@ -31,7 +32,7 @@ def main(args):
skew_scale,
) = setup_evaluation_params(manifest, run_names)

# make save path directory
# make save path directory
Path(args.save_path).mkdir(parents=True, exist_ok=True)

# save embeddings for each model
Expand Down
13 changes: 6 additions & 7 deletions src/br/analysis/run_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import os
import sys
from pathlib import Path

import pandas as pd

from br.analysis.analysis_utils import (
get_feature_params,
setup_evaluation_params,
Expand Down Expand Up @@ -39,7 +41,7 @@ def main(args):
) = get_data_and_models(args.dataset_name, batch_size, config_path + "/results/", args.debug)
max_embed_dim = min(latent_dims)

# make save path directory
# make save path directory
Path(args.save_path).mkdir(parents=True, exist_ok=True)

# Save model sizes to CSV
Expand Down Expand Up @@ -81,9 +83,7 @@ def main(args):
classification_params,
evolve_params,
regression_params,
) = get_feature_params(
config_path + "/results/", args.dataset_name, manifest, keys, run_names
)
) = get_feature_params(config_path + "/results/", args.dataset_name, manifest, keys, run_names)

metric_list = [
"Rotation Invariance Error",
Expand All @@ -95,7 +95,6 @@ def main(args):
if regression_params["target_cols"]:
metric_list.append("Regression")


# Compute multi-metric benchmarking features
compute_features(
dataset=args.dataset_name,
Expand Down Expand Up @@ -126,12 +125,12 @@ def main(args):
csvs = [i.split(".")[0] for i in csvs]
# Remove non metric related csvs
csvs = [i for i in csvs if i not in run_names and i not in keys]
csvs = [i for i in csvs if i not in ['image', 'pcloud']]
csvs = [i for i in csvs if i not in ["image", "pcloud"]]
# classification and regression metrics are unique to each dataset
unique_metrics = [i for i in csvs if "classification" in i or "regression" in i]
# Collect dataframe and make plots
df, df_non_agg = collect_outputs(args.save_path, "std", run_names, csvs)
plot(args.save_path, df, run_names, args.dataset_name, "std", unique_metrics)
plot(args.save_path, df, run_names, args.dataset_name, "std", unique_metrics, df_non_agg)


if __name__ == "__main__":
Expand Down
15 changes: 11 additions & 4 deletions src/br/chandrasekaran_et_al/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
import itertools
import os
from pathlib import Path
import seaborn as sns

import copairs.compute_np as backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pycytominer
import seaborn as sns
from copairs.compute import cosine_indexed
from copairs.map import (
aggregate,
Expand All @@ -18,8 +21,7 @@
from sklearn.metrics import average_precision_score
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import matplotlib.pyplot as plt
import pycytominer

from br.chandrasekaran_et_al import utils


Expand Down Expand Up @@ -172,7 +174,12 @@ def _plot(all_rep, save_path):
.sort_values(by="q_value")["Metadata_broad_sample"]
.values
)
ordered_drugs = all_rep.groupby(['Metadata_broad_sample']).mean().sort_values(by='q_value').reset_index()['Metadata_broad_sample']
ordered_drugs = (
all_rep.groupby(["Metadata_broad_sample"])
.mean()
.sort_values(by="q_value")
.reset_index()["Metadata_broad_sample"]
)
x_order = ordered_drugs

g = sns.catplot(
Expand Down
49 changes: 46 additions & 3 deletions src/br/features/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import mitsuba as mi

Expand Down Expand Up @@ -159,8 +160,8 @@ def plot(
title,
norm="std",
unique_expressivity_metrics=None,
df_non_agg=None,
):
import matplotlib as mpl

mpl.rcParams["pdf.fonttype"] = 42
df = df.dropna()
Expand Down Expand Up @@ -248,8 +249,50 @@ def plot(

fig.write_image(path / f"{title}.png", scale=3)
fig.write_image(path / f"{title}.pdf", scale=3)
# fig.write_image(path / f"{title}.eps", scale=2)
# fig.write_image(path / f"{title}.pdf")

if df_non_agg is not None:
sns.set(font_scale=1.1)
sns.set_theme(style="white")
for var in df_non_agg["variable"].unique():
this_df = df_non_agg.loc[df_non_agg["variable"] == var].reset_index(drop=True)
g = sns.catplot(
data=this_df,
y="model",
x="value",
kind="bar",
aspect=1.1,
height=2.6,
log=False,
errorbar="sd",
hue="model",
legend=False,
dodge=False,
palette=colors,
)

g.map(
sns.stripplot,
"value",
"model",
color="k",
dodge=True,
alpha=0.6,
ec="k",
linewidth=1,
s=1,
)

g.set(
xlim=[
np.quantile(this_df["value"].values, 0.05),
np.quantile(this_df["value"].values, 0.95),
]
)

g.set(yticklabels=[])

g.savefig(path / f"{var}.png", bbox_inches="tight", dpi=300)
g.savefig(path / f"{var}.pdf", bbox_inches="tight", dpi=300)


def plot_pc_saved(
Expand Down

0 comments on commit e0e1254

Please sign in to comment.