diff --git a/GANDLF/models/dynunet_wrapper.py b/GANDLF/models/dynunet_wrapper.py index 3f209a3d8..b91e61dbd 100644 --- a/GANDLF/models/dynunet_wrapper.py +++ b/GANDLF/models/dynunet_wrapper.py @@ -1,7 +1,35 @@ +from testing.test_full import patch_size from .modelBase import ModelBase import monai.networks.nets.dynunet as dynunet +def get_kernels_strides(sizes,spacings): + """ + https://github.com/Project-MONAI/tutorials/blob/main/modules/dynunet_pipeline/create_network.py#L19 + + """ + input_size = sizes + strides, kernels = [], [] + while True: + spacing_ratio = [sp / min(spacings) for sp in spacings] + stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)] + kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio] + if all(s == 1 for s in stride): + break + for idx, (i, j) in enumerate(zip(sizes, stride)): + if i % j != 0: + raise ValueError( + f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}." + ) + sizes = [i / j for i, j in zip(sizes, stride)] + spacings = [i * j for i, j in zip(spacings, stride)] + kernels.append(kernel) + strides.append(stride) + + strides.insert(0, len(spacings) * [1]) + kernels.append(len(spacings) * [3]) + return kernels, strides + class dynunet_wrapper(ModelBase): """ More info: https://docs.monai.io/en/stable/networks.html#dynunet @@ -57,7 +85,11 @@ def __init__(self, parameters: dict): parameters["model"]["trans_bias"] = parameters["model"].get("trans_bias", False) parameters["model"]["dropout"] = parameters["model"].get("dropout", None) - + parameters["model"]["auto_calculation_kernel_stripes"] = parameters["model"].get("auto_calculation_kernel_stripes", True) + if parameters["model"]["auto_calculation_kernel_stripes"] == True: + patch_size = parameters.get("patch_size", None) + spacing = parameters.get("spacing", None) + parameters["model"]["kernel_size"], parameters["model"]["strides"] = get_kernels_strides(patch_size, spacing) if not ("norm_type" in parameters["model"]): self.norm_type = "INSTANCE" diff --git a/testing/test_full.py b/testing/test_full.py index b36a8ab64..78523eab4 100644 --- a/testing/test_full.py +++ b/testing/test_full.py @@ -52,15 +52,15 @@ ## global defines # pre-defined segmentation model types for testing all_models_segmentation = [ - "lightunet", - "lightunet_multilayer", - "unet", - "unet_multilayer", - "deep_resunet", - "fcn", - "uinc", - "msdnet", - "imagenet_unet", + # "lightunet", + # "lightunet_multilayer", + # "unet", + # "unet_multilayer", + # "deep_resunet", + # "fcn", + # "uinc", + # "msdnet", + # "imagenet_unet", "dynunet", ] # pre-defined regression/classification model types for testing @@ -278,8 +278,10 @@ def test_train_segmentation_rad_2d(device): if model == "dynunet": # More info: https://github.com/Project-MONAI/MONAI/blob/96bfda00c6bd290297f5e3514ea227c6be4d08b4/tests/test_dynunet.py - parameters["model"]["kernel_size"] = (3, 3, 3, 1) - parameters["model"]["strides"] = (1, 1, 1, 1) + # parameters["model"]["kernel_size"] = (3, 3, 3, 1) + ["model"]["auto_calculation_kernel_stripes"] = True + parameters["spacing"] = [[1.0, 1.0, 1.0]] + # parameters["model"]["strides"] = (1, 1, 1, 1) parameters["model"]["deep_supervision"] = False parameters["model"]["architecture"] = model