Skip to content

Commit

Permalink
add auto configure kernel and strides
Browse files Browse the repository at this point in the history
  • Loading branch information
benmalef committed Oct 4, 2024
1 parent 140ff52 commit b04b8a3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 49 deletions.
51 changes: 23 additions & 28 deletions GANDLF/models/dynunet_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
from testing.test_full import patch_size
from .modelBase import ModelBase
import monai.networks.nets.dynunet as dynunet


def get_kernels_strides(sizes,spacings):
def get_kernels_strides(sizes, spacings):
"""
https://github.com/Project-MONAI/tutorials/blob/main/modules/dynunet_pipeline/create_network.py#L19
More info: https://github.com/Project-MONAI/tutorials/blob/main/modules/dynunet_pipeline/create_network.py#L19
When refering this method for other tasks, please ensure that the patch size for each spatial dimension should
be divisible by the product of all strides in the corresponding dimension.
In addition, the minimal spatial size should have at least one dimension that has twice the size of
the product of all strides. For patch sizes that cannot find suitable strides, an error will be raised.
"""
input_size = sizes
strides, kernels = [], []
while True:
spacing_ratio = [sp / min(spacings) for sp in spacings]
stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
stride = [
2 if ratio <= 2 and size >= 8 else 1
for (ratio, size) in zip(spacing_ratio, sizes)
]
kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
if all(s == 1 for s in stride):
break
Expand All @@ -30,6 +37,7 @@ def get_kernels_strides(sizes,spacings):
kernels.append(len(spacings) * [3])
return kernels, strides


class dynunet_wrapper(ModelBase):
"""
More info: https://docs.monai.io/en/stable/networks.html#dynunet
Expand All @@ -54,42 +62,29 @@ class dynunet_wrapper(ModelBase):
def __init__(self, parameters: dict):
super(dynunet_wrapper, self).__init__(parameters)

# checking for validation
assert (
"kernel_size" in parameters["model"]
) == True, "\033[0;31m`kernel_size` key missing in parameters"
assert (
"strides" in parameters["model"]
) == True, "\033[0;31m`strides` key missing in parameters"

# defining some defaults
# if not ("upsample_kernel_size" in parameters["model"]):
# parameters["model"]["upsample_kernel_size"] = parameters["model"][
# "strides"
# ][1:]

patch_size = parameters.get("patch_size", None)
spacing = parameters.get(
"spacing", [1.0 for i in range(parameters["model"]["dimension"])]
)
kernel_size, strides = get_kernels_strides(patch_size, spacing)
parameters["model"]["kernel_size"] = parameters["model"].get(
"kernel_size", kernel_size
)
parameters["model"]["strides"] = parameters["model"].get("strides", strides)
parameters["model"]["filters"] = parameters["model"].get("filters", None)
parameters["model"]["act_name"] = parameters["model"].get(
"act_name", ("leakyrelu", {"inplace": True, "negative_slope": 0.01})
)

parameters["model"]["deep_supervision"] = parameters["model"].get(
"deep_supervision", True
"deep_supervision", False
)

parameters["model"]["deep_supr_num"] = parameters["model"].get(
"deep_supr_num", 1
)

parameters["model"]["res_block"] = parameters["model"].get("res_block", True)

parameters["model"]["trans_bias"] = parameters["model"].get("trans_bias", False)
parameters["model"]["dropout"] = parameters["model"].get("dropout", None)
parameters["model"]["auto_calculation_kernel_stripes"] = parameters["model"].get("auto_calculation_kernel_stripes", True)
if parameters["model"]["auto_calculation_kernel_stripes"] == True:
patch_size = parameters.get("patch_size", None)
spacing = parameters.get("spacing", None)
parameters["model"]["kernel_size"], parameters["model"]["strides"] = get_kernels_strides(patch_size, spacing)

if not ("norm_type" in parameters["model"]):
self.norm_type = "INSTANCE"

Expand Down
31 changes: 10 additions & 21 deletions testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@
## global defines
# pre-defined segmentation model types for testing
all_models_segmentation = [
# "lightunet",
# "lightunet_multilayer",
# "unet",
# "unet_multilayer",
# "deep_resunet",
# "fcn",
# "uinc",
# "msdnet",
# "imagenet_unet",
"lightunet",
"lightunet_multilayer",
"unet",
"unet_multilayer",
"deep_resunet",
"fcn",
"uinc",
"msdnet",
"imagenet_unet",
"dynunet",
]
# pre-defined regression/classification model types for testing
Expand Down Expand Up @@ -276,13 +276,6 @@ def test_train_segmentation_rad_2d(device):
["acs", "soft", "conv3d"]
)

if model == "dynunet":
# More info: https://github.com/Project-MONAI/MONAI/blob/96bfda00c6bd290297f5e3514ea227c6be4d08b4/tests/test_dynunet.py
# parameters["model"]["kernel_size"] = (3, 3, 3, 1)
["model"]["auto_calculation_kernel_stripes"] = True
parameters["spacing"] = [[1.0, 1.0, 1.0]]
# parameters["model"]["strides"] = (1, 1, 1, 1)
parameters["model"]["deep_supervision"] = False

parameters["model"]["architecture"] = model
parameters["nested_training"]["testing"] = -5
Expand All @@ -296,6 +289,7 @@ def test_train_segmentation_rad_2d(device):
resume=False,
reset=True,
)
print(parameters["model"]["strides"])

sanitize_outputDir()

Expand Down Expand Up @@ -376,11 +370,6 @@ def test_train_segmentation_rad_3d(device):
["acs", "soft", "conv3d"]
)

if model == "dynunet":
# More info: https://github.com/Project-MONAI/MONAI/blob/96bfda00c6bd290297f5e3514ea227c6be4d08b4/tests/test_dynunet.py
parameters["model"]["kernel_size"] = (3, 3, 3, 1)
parameters["model"]["strides"] = (1, 1, 1, 1)
parameters["model"]["deep_supervision"] = False

parameters["model"]["architecture"] = model
parameters["nested_training"]["testing"] = -5
Expand Down

0 comments on commit b04b8a3

Please sign in to comment.