Skip to content

Commit

Permalink
nnunet model back to its own rule for workflow control
Browse files Browse the repository at this point in the history
  • Loading branch information
Jordan DeKraker - B. Bernhardt Lab committed Nov 2, 2023
1 parent 7d81198 commit d836069
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 26 deletions.
26 changes: 0 additions & 26 deletions hippunfold/workflow/rules/downloads.smk
Original file line number Diff line number Diff line change
@@ -1,29 +1,3 @@
def get_model_tar():
if config["force_nnunet_model"]:
model_name = config["force_nnunet_model"]
else:
model_name = config["modality"]

local_tar = config["nnunet_model"].get(model_name, None)
if local_tar == None:
print(f"ERROR: {model_name} does not exist in nnunet_model in the config file")

return os.path.abspath(os.path.join(download_dir, local_tar.split("/")[-1]))


rule download_model:
params:
url=config["nnunet_model"][config["force_nnunet_model"]]
if config["force_nnunet_model"]
else config["nnunet_model"][config["modality"]],
output:
model_tar=get_model_tar(),
container:
config["singularity"]["autotop"]
shell:
"wget https://{params.url} -O {output.model_tar}"


rule download_atlas:
params:
url=config["atlas_links_osf"][config["atlas"]],
Expand Down
26 changes: 26 additions & 0 deletions hippunfold/workflow/rules/nnunet.smk
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,32 @@ import re
from appdirs import AppDirs


def get_model_tar():
if config["force_nnunet_model"]:
model_name = config["force_nnunet_model"]
else:
model_name = config["modality"]

local_tar = config["nnunet_model"].get(model_name, None)
if local_tar == None:
print(f"ERROR: {model_name} does not exist in nnunet_model in the config file")

return os.path.abspath(os.path.join(download_dir, local_tar.split("/")[-1]))


rule download_model:
params:
url=config["nnunet_model"][config["force_nnunet_model"]]
if config["force_nnunet_model"]
else config["nnunet_model"][config["modality"]],
output:
model_tar=get_model_tar(),
container:
config["singularity"]["autotop"]
shell:
"wget https://{params.url} -O {output.model_tar}"


def get_nnunet_input(wildcards):
if config["modality"] == "T2w":
nii = (
Expand Down

0 comments on commit d836069

Please sign in to comment.