Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update LiveCELL inference scripts #315

Merged
merged 4 commits into from
Jan 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions finetuning/livecell/evaluation/evaluate_amg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@

from micro_sam.evaluation.evaluation import run_evaluation
from micro_sam.evaluation.livecell import run_livecell_amg
from util import DATA_ROOT, get_checkpoint, get_experiment_folder, get_pred_and_gt_paths
from util import DATA_ROOT, get_pred_and_gt_paths


def run_amg(name, model_type, checkpoint):
if checkpoint is None:
checkpoint, model_type = get_checkpoint(name)
experiment_folder = get_experiment_folder(name)
def run_amg(model_type, checkpoint, experiment_folder):
input_folder = DATA_ROOT
prediction_folder = run_livecell_amg(
checkpoint,
Expand All @@ -21,28 +18,26 @@ def run_amg(name, model_type, checkpoint):
return prediction_folder


def eval_amg(name, prediction_folder):
def eval_amg(prediction_folder, experiment_folder):
print("Evaluating", prediction_folder)
pred_paths, gt_paths = get_pred_and_gt_paths(prediction_folder)
save_path = os.path.join(get_experiment_folder(name), "results", "amg.csv")
save_path = os.path.join(experiment_folder, "results", "amg.csv")
res = run_evaluation(gt_paths, pred_paths, save_path=save_path)
print(res)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--name", required=True)
parser.add_argument(
"-m", "--model", type=str, # options: "vit_h", "vit_h_generalist", "vit_h_specialist"
"-m", "--model", type=str, required=True,
help="Provide the model type to initialize the predictor"
)
parser.add_argument("-c", "--checkpoint", type=str, default=None)
parser.add_argument("-c", "--checkpoint", type=str, required=True)
parser.add_argument("-e", "--experiment_folder", type=str, required=True)
args = parser.parse_args()

name = args.name

prediction_folder = run_amg(name, args.model, args.checkpoint)
eval_amg(name, prediction_folder)
prediction_folder = run_amg(args.model, args.checkpoint, args.experiment_folder)
eval_amg(prediction_folder, args.experiment_folder)


if __name__ == "__main__":
Expand Down
7 changes: 4 additions & 3 deletions finetuning/livecell/evaluation/evaluate_amg.sbatch
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#SBATCH -G A100:1
#SBATCH -A nim00007

source ~/.bashrc
micromamba activate main
python evaluate_amg.py $@
source activate sam
python evaluate_amg.py -c /scratch/usr/nimanwai/micro-sam/checkpoints/vit_b/livecell_sam/best.pt \
-m vit_b \
-e /scratch/projects/nim00007/sam/experiments/new_models/specialists/livecell/vit_b/
25 changes: 11 additions & 14 deletions finetuning/livecell/evaluation/evaluate_instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@

from micro_sam.evaluation.evaluation import run_evaluation
from micro_sam.evaluation.livecell import run_livecell_instance_segmentation_with_decoder
from util import DATA_ROOT, get_checkpoint, get_experiment_folder, get_pred_and_gt_paths
from util import DATA_ROOT, get_pred_and_gt_paths


def run_instance_segmentation_with_decoder(name, model_type, checkpoint):
if checkpoint is None:
checkpoint, model_type = get_checkpoint(name)
experiment_folder = get_experiment_folder(name)
def run_instance_segmentation_with_decoder(model_type, checkpoint, experiment_folder):
input_folder = DATA_ROOT
prediction_folder = run_livecell_instance_segmentation_with_decoder(
checkpoint,
Expand All @@ -21,28 +18,28 @@ def run_instance_segmentation_with_decoder(name, model_type, checkpoint):
return prediction_folder


def eval_instance_segmentation_with_decoder(name, prediction_folder):
def eval_instance_segmentation_with_decoder(prediction_folder, experiment_folder):
print("Evaluating", prediction_folder)
pred_paths, gt_paths = get_pred_and_gt_paths(prediction_folder)
save_path = os.path.join(get_experiment_folder(name), "results", "instance_segmentation_with_decoder.csv")
save_path = os.path.join(experiment_folder, "results", "instance_segmentation_with_decoder.csv")
res = run_evaluation(gt_paths, pred_paths, save_path=save_path)
print(res)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--name", required=True)
parser.add_argument(
"-m", "--model", type=str, # options: "vit_h", "vit_h_generalist", "vit_h_specialist"
"-m", "--model", type=str, required=True,
help="Provide the model type to initialize the predictor"
)
parser.add_argument("-c", "--checkpoint", type=str, default=None)
parser.add_argument("-c", "--checkpoint", type=str, required=True,)
parser.add_argument("-e", "--experiment_folder", type=str, required=True)
args = parser.parse_args()

name = args.name

prediction_folder = run_instance_segmentation_with_decoder(name, args.model, args.checkpoint)
eval_instance_segmentation_with_decoder(name, prediction_folder)
prediction_folder = run_instance_segmentation_with_decoder(
args.model, args.checkpoint, args.experiment_folder
)
eval_instance_segmentation_with_decoder(prediction_folder, args.experiment_folder)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#SBATCH -G A100:1
#SBATCH -A nim00007

source ~/.bashrc
micromamba activate main
python evaluate_instance_segmentation.py $@
source activate sam
python evaluate_instance_segmentation.py -c /scratch/usr/nimanwai/micro-sam/checkpoints/vit_h/livecell_sam/best.pt \
-m vit_h \
-e /scratch/projects/nim00007/sam/experiments/new_models/specialists/livecell/vit_h/
32 changes: 10 additions & 22 deletions finetuning/livecell/evaluation/iterative_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from micro_sam.evaluation import inference
from micro_sam.evaluation.evaluation import run_evaluation
from util import get_paths, get_experiment_folder, get_model, get_pred_and_gt_paths
from util import get_paths, get_model, get_pred_and_gt_paths


def run_interactive_prompting(exp_folder, predictor, start_with_box_prompt):
Expand All @@ -26,16 +26,9 @@ def run_interactive_prompting(exp_folder, predictor, start_with_box_prompt):
return prediction_root


def evaluate_interactive_prompting(prediction_root, start_with_box_prompt, name):
def evaluate_interactive_prompting(prediction_root, start_with_box_prompt, exp_folder):
assert os.path.exists(prediction_root), prediction_root

csv_save_dir = f"./iterative_prompting_results/{name}"
os.makedirs(csv_save_dir, exist_ok=True)
csv_path = os.path.join(csv_save_dir, "start_with_box.csv" if start_with_box_prompt else "start_with_point.csv")
if os.path.exists(csv_path):
print("The evaluated results for the expected setting already exist here:", csv_path)
return

prediction_folders = sorted(glob(os.path.join(prediction_root, "iteration*")))
list_of_results = []
for pred_folder in prediction_folders:
Expand All @@ -46,10 +39,9 @@ def evaluate_interactive_prompting(prediction_root, start_with_box_prompt, name)
print(res)

df = pd.concat(list_of_results, ignore_index=True)
df.to_csv(csv_path)

# Also save the results in the experiment folder.
result_folder = os.path.join(get_experiment_folder(name), "results")
# Save the results in the experiment folder.
result_folder = os.path.join(exp_folder, "results")
os.makedirs(result_folder, exist_ok=True)
csv_path = os.path.join(
result_folder,
Expand All @@ -60,25 +52,21 @@ def evaluate_interactive_prompting(prediction_root, start_with_box_prompt, name)

def main():
parser = argparse.ArgumentParser()

parser.add_argument("-n", "--name", required=True)
parser.add_argument(
"-m", "--model", type=str, # options: "vit_h", "vit_h_generalist", "vit_h_specialist"
help="Provide the model type to initialize the predictor"
"-m", "--model", type=str, required=True, help="Provide the model type to initialize the predictor"
)
parser.add_argument("-c", "--checkpoint", type=str, default=None)
parser.add_argument("-c", "--checkpoint", type=str, required=True)
parser.add_argument("-e", "--experiment_folder", type=str, required=True)
parser.add_argument("--box", action="store_true", help="If passed, starts with first prompt as box")
args = parser.parse_args()

name = args.name
start_with_box_prompt = args.box # overwrite to start first iters' prompt with box instead of single point

# get the predictor to perform inference
predictor = get_model(name, model_type=args.model, ckpt=args.checkpoint)
predictor = get_model(model_type=args.model, ckpt=args.checkpoint)

exp_folder = get_experiment_folder(name)
prediction_root = run_interactive_prompting(exp_folder, predictor, start_with_box_prompt)
evaluate_interactive_prompting(prediction_root, start_with_box_prompt, name)
prediction_root = run_interactive_prompting(args.experiment_folder, predictor, start_with_box_prompt)
evaluate_interactive_prompting(prediction_root, start_with_box_prompt)


if __name__ == "__main__":
Expand Down
7 changes: 4 additions & 3 deletions finetuning/livecell/evaluation/iterative_prompting.sbatch
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#SBATCH -G A100:1
#SBATCH -A nim00007

source ~/.bashrc
micromamba activate main
python iterative_prompting.py $@
source activate sam
python iterative_prompting.py -c /scratch/usr/nimanwai/micro-sam/checkpoints/vit_h/livecell_sam/best.pt \
-m vit_h \
-e /scratch/projects/nim00007/sam/experiments/new_models/specialists/livecell/vit_h/
85 changes: 85 additions & 0 deletions finetuning/livecell/evaluation/submit_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os
import shutil
import subprocess
from glob import glob
from datetime import datetime


def write_batch_script(env_name, out_path, inference_setup, checkpoint, model_type, experiment_folder):
"""Writing scripts with different fold-trainings for micro-sam evaluation
"""
batch_script = f"""#!/bin/bash
#SBATCH -c 8
#SBATCH --mem 128G
#SBATCH -t 6:00:00
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007
#SBATCH --job-name={inference_setup}

source ~/.bashrc
mamba activate {env_name}
python {inference_setup}.py """

_op = out_path[:-3] + f"_{inference_setup}.sh"

# add the finetuned checkpoint
batch_script += f"-c {checkpoint} "

# name of the model configuration
batch_script += f"-m {model_type} "

# experiment folder
batch_script += f"-e {experiment_folder} "

with open(_op, "w") as f:
f.write(batch_script)


def get_batch_script_names(tmp_folder):
tmp_folder = os.path.expanduser(tmp_folder)
os.makedirs(tmp_folder, exist_ok=True)

script_name = "livecell-inference"

dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
tmp_name = script_name + dt
batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh")

return batch_script


def submit_slurm():
"""Submit python script that needs gpus with given inputs on a slurm node.
"""
tmp_folder = "./gpu_jobs"

# parameters to run the inference scripts
env_name = "sam"
checkpoint = "/scratch/usr/nimanwai/micro-sam/checkpoints/vit_h/livecell_sam/best.pt"
model_type = "vit_h"
experiment_folder = "/scratch/projects/nim00007/sam/experiments/new_models/specialists/livecell/vit_h/"

all_setups = ["evaluate_amg", "evaluate_instance_segmentation", "iterative_prompting"]
for current_setup in all_setups:
write_batch_script(
env_name=env_name,
out_path=get_batch_script_names(tmp_folder),
inference_setup=current_setup,
checkpoint=checkpoint,
model_type=model_type,
experiment_folder=experiment_folder,
)

for my_script in glob(tmp_folder + "/*"):
cmd = ["sbatch", my_script]
subprocess.run(cmd)


if __name__ == "__main__":
try:
shutil.rmtree("./gpu_jobs")
except FileNotFoundError:
pass

submit_slurm()
2 changes: 1 addition & 1 deletion finetuning/livecell/evaluation/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_checkpoint(name):
return ckpt, model_type


def get_model(name, model_type=None, ckpt=None):
def get_model(name=None, model_type=None, ckpt=None):
if ckpt is None:
ckpt, model_type = get_checkpoint(name)
assert (ckpt is not None) and (model_type is not None)
Expand Down