Skip to content

Commit

Permalink
add auto configure kernel and strides
Browse files Browse the repository at this point in the history
  • Loading branch information
benmalef committed Sep 30, 2024
1 parent 2953a5e commit 140ff52
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 12 deletions.
34 changes: 33 additions & 1 deletion GANDLF/models/dynunet_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,35 @@
from testing.test_full import patch_size
from .modelBase import ModelBase
import monai.networks.nets.dynunet as dynunet


def get_kernels_strides(sizes,spacings):
"""
https://github.com/Project-MONAI/tutorials/blob/main/modules/dynunet_pipeline/create_network.py#L19
"""
input_size = sizes
strides, kernels = [], []
while True:
spacing_ratio = [sp / min(spacings) for sp in spacings]
stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
if all(s == 1 for s in stride):
break
for idx, (i, j) in enumerate(zip(sizes, stride)):
if i % j != 0:
raise ValueError(
f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}."
)
sizes = [i / j for i, j in zip(sizes, stride)]
spacings = [i * j for i, j in zip(spacings, stride)]
kernels.append(kernel)
strides.append(stride)

strides.insert(0, len(spacings) * [1])
kernels.append(len(spacings) * [3])
return kernels, strides

class dynunet_wrapper(ModelBase):
"""
More info: https://docs.monai.io/en/stable/networks.html#dynunet
Expand Down Expand Up @@ -57,7 +85,11 @@ def __init__(self, parameters: dict):

parameters["model"]["trans_bias"] = parameters["model"].get("trans_bias", False)
parameters["model"]["dropout"] = parameters["model"].get("dropout", None)

parameters["model"]["auto_calculation_kernel_stripes"] = parameters["model"].get("auto_calculation_kernel_stripes", True)
if parameters["model"]["auto_calculation_kernel_stripes"] == True:
patch_size = parameters.get("patch_size", None)
spacing = parameters.get("spacing", None)
parameters["model"]["kernel_size"], parameters["model"]["strides"] = get_kernels_strides(patch_size, spacing)
if not ("norm_type" in parameters["model"]):
self.norm_type = "INSTANCE"

Expand Down
24 changes: 13 additions & 11 deletions testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@
## global defines
# pre-defined segmentation model types for testing
all_models_segmentation = [
"lightunet",
"lightunet_multilayer",
"unet",
"unet_multilayer",
"deep_resunet",
"fcn",
"uinc",
"msdnet",
"imagenet_unet",
# "lightunet",
# "lightunet_multilayer",
# "unet",
# "unet_multilayer",
# "deep_resunet",
# "fcn",
# "uinc",
# "msdnet",
# "imagenet_unet",
"dynunet",
]
# pre-defined regression/classification model types for testing
Expand Down Expand Up @@ -278,8 +278,10 @@ def test_train_segmentation_rad_2d(device):

if model == "dynunet":
# More info: https://github.com/Project-MONAI/MONAI/blob/96bfda00c6bd290297f5e3514ea227c6be4d08b4/tests/test_dynunet.py
parameters["model"]["kernel_size"] = (3, 3, 3, 1)
parameters["model"]["strides"] = (1, 1, 1, 1)
# parameters["model"]["kernel_size"] = (3, 3, 3, 1)
["model"]["auto_calculation_kernel_stripes"] = True
parameters["spacing"] = [[1.0, 1.0, 1.0]]
# parameters["model"]["strides"] = (1, 1, 1, 1)
parameters["model"]["deep_supervision"] = False

parameters["model"]["architecture"] = model
Expand Down

0 comments on commit 140ff52

Please sign in to comment.