diff --git a/GANDLF/models/dynunet_wrapper.py b/GANDLF/models/dynunet_wrapper.py index 3f209a3d8..176596bee 100644 --- a/GANDLF/models/dynunet_wrapper.py +++ b/GANDLF/models/dynunet_wrapper.py @@ -2,6 +2,41 @@ import monai.networks.nets.dynunet as dynunet +def get_kernels_strides(sizes, spacings): + """ + More info: https://github.com/Project-MONAI/tutorials/blob/main/modules/dynunet_pipeline/create_network.py#L19 + + When refering this method for other tasks, please ensure that the patch size for each spatial dimension should + be divisible by the product of all strides in the corresponding dimension. + In addition, the minimal spatial size should have at least one dimension that has twice the size of + the product of all strides. For patch sizes that cannot find suitable strides, an error will be raised. + + """ + input_size = sizes + strides, kernels = [], [] + while True: + spacing_ratio = [sp / min(spacings) for sp in spacings] + stride = [ + 2 if ratio <= 2 and size >= 8 else 1 + for (ratio, size) in zip(spacing_ratio, sizes) + ] + kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio] + if all(s == 1 for s in stride): + break + for idx, (i, j) in enumerate(zip(sizes, stride)): + assert ( + i % j == 0 + ), f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}." + sizes = [i / j for i, j in zip(sizes, stride)] + spacings = [i * j for i, j in zip(spacings, stride)] + kernels.append(kernel) + strides.append(stride) + + strides.insert(0, len(spacings) * [1]) + kernels.append(len(spacings) * [3]) + return kernels, strides + + class dynunet_wrapper(ModelBase): """ More info: https://docs.monai.io/en/stable/networks.html#dynunet @@ -26,35 +61,33 @@ class dynunet_wrapper(ModelBase): def __init__(self, parameters: dict): super(dynunet_wrapper, self).__init__(parameters) - # checking for validation - assert ( - "kernel_size" in parameters["model"] - ) == True, "\033[0;31m`kernel_size` key missing in parameters" - assert ( - "strides" in parameters["model"] - ) == True, "\033[0;31m`strides` key missing in parameters" - - # defining some defaults - # if not ("upsample_kernel_size" in parameters["model"]): - # parameters["model"]["upsample_kernel_size"] = parameters["model"][ - # "strides" - # ][1:] + patch_size = parameters.get("patch_size", None) + spacing = parameters.get( + "spacing_for_internal_computations", + [1.0 for i in range(parameters["model"]["dimension"])], + ) + parameters["model"]["kernel_size"] = parameters["model"].get( + "kernel_size", None + ) + parameters["model"]["strides"] = parameters["model"].get("strides", None) + if (parameters["model"]["kernel_size"] is None) or ( + parameters["model"]["strides"] is None + ): + kernel_size, strides = get_kernels_strides(patch_size, spacing) + parameters["model"]["kernel_size"] = kernel_size + parameters["model"]["strides"] = strides parameters["model"]["filters"] = parameters["model"].get("filters", None) parameters["model"]["act_name"] = parameters["model"].get( "act_name", ("leakyrelu", {"inplace": True, "negative_slope": 0.01}) ) - parameters["model"]["deep_supervision"] = parameters["model"].get( - "deep_supervision", True + "deep_supervision", False ) - parameters["model"]["deep_supr_num"] = parameters["model"].get( "deep_supr_num", 1 ) - parameters["model"]["res_block"] = parameters["model"].get("res_block", True) - parameters["model"]["trans_bias"] = parameters["model"].get("trans_bias", False) parameters["model"]["dropout"] = parameters["model"].get("dropout", None) diff --git a/testing/test_full.py b/testing/test_full.py index 28e1b9437..f78928ef5 100644 --- a/testing/test_full.py +++ b/testing/test_full.py @@ -276,12 +276,6 @@ def test_train_segmentation_rad_2d(device): ["acs", "soft", "conv3d"] ) - if model == "dynunet": - # More info: https://github.com/Project-MONAI/MONAI/blob/96bfda00c6bd290297f5e3514ea227c6be4d08b4/tests/test_dynunet.py - parameters["model"]["kernel_size"] = (3, 3, 3, 1) - parameters["model"]["strides"] = (1, 1, 1, 1) - parameters["model"]["deep_supervision"] = False - parameters["model"]["architecture"] = model parameters["nested_training"]["testing"] = -5 parameters["nested_training"]["validation"] = -5 @@ -374,12 +368,6 @@ def test_train_segmentation_rad_3d(device): ["acs", "soft", "conv3d"] ) - if model == "dynunet": - # More info: https://github.com/Project-MONAI/MONAI/blob/96bfda00c6bd290297f5e3514ea227c6be4d08b4/tests/test_dynunet.py - parameters["model"]["kernel_size"] = (3, 3, 3, 1) - parameters["model"]["strides"] = (1, 1, 1, 1) - parameters["model"]["deep_supervision"] = False - parameters["model"]["architecture"] = model parameters["nested_training"]["testing"] = -5 parameters["nested_training"]["validation"] = -5