Skip to content

Commit

Permalink
moved download model rules back to nnunet
Browse files Browse the repository at this point in the history
avoids breaking the seg* modalities
  • Loading branch information
akhanf committed Feb 15, 2024
1 parent 5489c55 commit 628e515
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 29 deletions.
28 changes: 0 additions & 28 deletions hippunfold/workflow/rules/download.smk
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,6 @@
download_dir = get_download_dir()


def get_model_tar():

if config["force_nnunet_model"]:
model_name = config["force_nnunet_model"]
else:
model_name = config["modality"]

local_tar = config["resource_urls"]["nnunet_model"].get(model_name, None)
if local_tar == None:
print(f"ERROR: {model_name} does not exist in nnunet_model in the config file")

return (Path(download_dir) / "model" / Path(local_tar).name).absolute()


rule download_nnunet_model:
params:
url=config["resource_urls"]["nnunet_model"][config["force_nnunet_model"]]
if config["force_nnunet_model"]
else config["resource_urls"]["nnunet_model"][config["modality"]],
model_dir=Path(download_dir) / "model",
output:
model_tar=get_model_tar(),
container:
config["singularity"]["autotop"]
shell:
"mkdir -p {params.model_dir} && wget https://{params.url} -O {output}"


rule download_extract_atlas_or_template:
params:
url=lambda wildcards: config["resource_urls"][wildcards.resource_type][
Expand Down
28 changes: 28 additions & 0 deletions hippunfold/workflow/rules/nnunet.smk
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,34 @@ def get_nnunet_input(wildcards):
return nii


def get_model_tar():

if config["force_nnunet_model"]:
model_name = config["force_nnunet_model"]
else:
model_name = config["modality"]

local_tar = config["resource_urls"]["nnunet_model"].get(model_name, None)
if local_tar == None:
print(f"ERROR: {model_name} does not exist in nnunet_model in the config file")

return (Path(download_dir) / "model" / Path(local_tar).name).absolute()


rule download_nnunet_model:
params:
url=config["resource_urls"]["nnunet_model"][config["force_nnunet_model"]]
if config["force_nnunet_model"]
else config["resource_urls"]["nnunet_model"][config["modality"]],
model_dir=Path(download_dir) / "model",
output:
model_tar=get_model_tar(),
container:
config["singularity"]["autotop"]
shell:
"mkdir -p {params.model_dir} && wget https://{params.url} -O {output}"


def parse_task_from_tar(wildcards, input):
match = re.search("Task[0-9]{3}_[\w]+", input.model_tar)
if match:
Expand Down
2 changes: 1 addition & 1 deletion hippunfold/workflow/rules/preproc_seg.smk
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ rule warp_seg_to_corobl_crop:
**config["subj_wildcards"],
suffix="dseg.nii.gz",
space="corobl",
hemi="{hemi}",
hemi="{hemi,L|R}",
from_="{space}"
),
container:
Expand Down

0 comments on commit 628e515

Please sign in to comment.