Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dynunet model #873

Merged
merged 30 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
198bfbd
create the file dynunet and create the wrapper class
benmalef May 20, 2024
edbcad1
create the forward method
benmalef May 20, 2024
34b44b5
add dynunet's identifier in global_model_dict
benmalef May 20, 2024
a4cb6b2
make some code changes
benmalef May 20, 2024
b0776d9
add docker-compose file
benmalef May 21, 2024
f5d9316
create a dynunet example for demo
benmalef May 21, 2024
5df9810
defining the default parameters in dynunet model
benmalef May 21, 2024
5a25ebc
make some code changes
benmalef May 24, 2024
95b63ff
fix test_full error
benmalef May 24, 2024
bd99429
changes that made by black formatter
benmalef May 24, 2024
fe7f290
Merge branch 'mlcommons:master' into add_dynunet_model
benmalef May 24, 2024
2e64d2a
Merge branch 'new-apis_v0.1.0-dev' into add_dynunet_model
benmalef May 24, 2024
d016ddb
change the model name to dynunet_wrapper
benmalef May 24, 2024
033ee0f
delete the dynunet file
benmalef May 24, 2024
85cbfca
Merge branch 'new-apis_v0.1.0-dev' into add_dynunet_model
sarthakpati May 29, 2024
1591b6d
delete docker compose
benmalef May 30, 2024
abcb439
remove unnecessary comments
benmalef May 31, 2024
ddfeebb
black the __init__ model
benmalef May 31, 2024
5ae6aa3
changed test_full and config_segmentation
benmalef Jun 1, 2024
86eb2c5
changed the dynunet defaults parameters
benmalef Jun 1, 2024
0252b01
changed the config_segmentation and test_full
benmalef Jun 1, 2024
dec2b3b
black test_full
benmalef Jun 1, 2024
86e4188
fix test_full dynunet model error "kernel_size"
benmalef Jun 1, 2024
3af040c
black test_full
benmalef Jun 1, 2024
6d77337
Merge branch 'new-apis_v0.1.0-dev' into add_dynunet_model
benmalef Jun 1, 2024
d77cc8a
remove unnecessary comments
benmalef Jun 4, 2024
1dda4ed
made the proposed changes
benmalef Jun 5, 2024
826b80c
remove unnecessary if statements from dynunet model
benmalef Jun 5, 2024
4f460e1
Merge branch 'new-apis_v0.1.0-dev' into add_dynunet_model
sarthakpati Jun 5, 2024
48abb3e
Merge branch 'new-apis_v0.1.0-dev' into add_dynunet_model
sarthakpati Jun 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions GANDLF/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .vgg import vgg11, vgg13, vgg16, vgg19
from .densenet import densenet121, densenet169, densenet201, densenet264
from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
from .dynunet_wrapper import dynunet_wrapper
from .efficientnet import (
efficientnetB0,
efficientnetB1,
Expand Down Expand Up @@ -101,6 +102,8 @@
"efficientnetb5": efficientnetB5,
"efficientnetb6": efficientnetB6,
"efficientnetb7": efficientnetB7,
# dynunet
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved
"dynunet": dynunet_wrapper,
# Custom models
"msdnet": MSDNet,
"brain_age": brainage,
Expand Down
97 changes: 97 additions & 0 deletions GANDLF/models/dynunet_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from .modelBase import ModelBase
import monai.networks.nets.dynunet as dynunet


class dynunet_wrapper(ModelBase):
"""
More info: https://docs.monai.io/en/stable/networks.html#dynunet

Args:
spatial_dims (int): number of spatial dimensions.
in_channels (int): number of input channels.
out_channels (int): number of output channels.
kernel_size (Sequence[Union[Sequence[int], int]]): convolution kernel size.
strides (Sequence[Union[Sequence[int], int]]): convolution strides for each blocks.
upsample_kernel_size (Sequence[Union[Sequence[int], int]]): convolution kernel size for transposed convolution layers. The values should equal to strides[1:].
filters (Optional[Sequence[int]]): number of output channels for each blocks. Defaults to None.
dropout (Union[Tuple, str, float, None]): dropout ratio. Defaults to no dropout.
norm_name (Union[Tuple, str]): feature normalization type and arguments. Defaults to INSTANCE.
act_name (Union[Tuple, str]): activation layer type and arguments. Defaults to leakyrelu.
deep_supervision (bool): whether to add deep supervision head before output. Defaults to False.
deep_supr_num (int): number of feature maps that will output during deep supervision head. The value should be larger than 0 and less than the number of up sample layers. Defaults to 1.
res_block (bool): whether to use residual connection based convolution blocks during the network. Defaults to False.
trans_bias (bool): whether to set the bias parameter in transposed convolution layers. Defaults to False.
"""

def __init__(self, parameters: dict):
super(dynunet_wrapper, self).__init__(parameters)

# checking for validation
assert (
"kernel_size" in parameters["model"]
) == True, "\033[0;31m`kernel_size` key missing in parameters"
assert (
"strides" in parameters["model"]
) == True, "\033[0;31m`strides` key missing in parameters"

# defining some defaults

if not ("upsample_kernel_size" in parameters["model"]):
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved
parameters["model"]["upsample_kernel_size"] = parameters["model"][
"strides"
][1:]
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved

if not ("filters" in parameters["model"]):
parameters["model"]["filters"] = None
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved

if not ("act_name" in parameters["model"]):
parameters["model"]["act_name"] = (
"leakyrelu",
{"inplace": True, "negative_slope": 0.01},
)

if not ("deep_supervision" in parameters["model"]):
parameters["model"]["deep_supervision"] = False
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved

if not ("deep_supr_num" in parameters["model"]):
parameters["model"]["deep_supr_num"] = 1

if not ("res_block" in parameters["model"]):
parameters["model"]["res_block"] = False
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved

if not ("trans_bias" in parameters["model"]):
parameters["model"]["trans_bias"] = False

# if not ("norm_type" in parameters["model"]):
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved
# self.norm_type = "INSTANCE"

if not ("dropout" in parameters):
parameters["model"]["dropout"] = None

self.model = dynunet.DynUNet(
spatial_dims=self.n_dimensions,
in_channels=self.n_channels,
out_channels=self.base_filters, # ? Is it correct?
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved
kernel_size=parameters["model"]["kernel_size"],
strides=parameters["model"]["strides"],
upsample_kernel_size=parameters["model"][
"upsample_kernel_size"
], # The values should equal to strides[1:]
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved
filters=parameters["model"][
"filters"
], # ? self.base_filter??? , number of output channels for each blocks
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved
dropout=parameters["model"][
"dropout"
], # dropout ratio. Defaults to no dropout
norm_name=self.norm_type, # ? Is it correct??
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved
act_name=parameters["model"]["act_name"],
deep_supervision=parameters["model"]["deep_supervision"],
deep_supr_num=parameters["model"][
"deep_supr_num"
], # number of feature maps that will output during deep supervision head.
res_block=parameters["model"]["res_block"],
trans_bias=parameters["model"]["trans_bias"],
)

def forward(self, x):
return self.model.forward(x)
90 changes: 37 additions & 53 deletions testing/config_segmentation.yaml
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
# Choose the segmentation model here
# Choose the segmentation model here
# options: unet, resunet, fcn
version:
{
minimum: 0.0.14,
maximum: 0.1.0-dev
}
model:
{
version: { minimum: 0.0.14, maximum: 0.1.0-dev }
model: {
dimension: 2, # the dimension of the model and dataset: defines dimensionality of computations
base_filters: 32, # Set base filters: number of filters present in the initial module of the U-Net convolution; for IncU-Net, keep this divisible by 4
architecture: resunet, # options: unet, resunet, fcn, uinc
final_layer: sigmoid, # can be either sigmoid, softmax or none (none == regression)
norm_type: instance, # can be either batch or instance
class_list: [0,255], # Set the list of labels the model should train on and predict
class_list: [0, 255], # Set the list of labels the model should train on and predict
amp: False, # Set if you want to use Automatic Mixed Precision for your operations or not - options: True, False
# n_channels: 3, # set the input channels - useful when reading RGB or images that have vectored pixel types
}
Expand All @@ -21,17 +16,12 @@ metrics:
- precision
- iou
- f1
- recall: {
average: macro,
}
- recall: { average: macro }
verbose: True
inference_mechanism: {
grid_aggregator_overlap: average,
patch_overlap: 0,
}
inference_mechanism: { grid_aggregator_overlap: average, patch_overlap: 0 }
modality: rad
# Patch size during training - 2D patch for breast images since third dimension is not patched
patch_size: [128,128]
# Patch size during training - 2D patch for breast images since third dimension is not patched
patch_size: [128, 128]
# Number of epochs
num_epochs: 1
patience: 1
Expand All @@ -51,58 +41,52 @@ optimizer: adam
# the value of 'k' for cross-validation, this is the percentage of total training data to use as validation;
# randomized split is performed using sklearn's KFold method
# for single fold run, use '-' before the fold number
nested_training:
{
nested_training: {
testing: -5, # this controls the holdout data splits for final model evaluation; use '1' if this is to be disabled
validation: -5 # this controls the validation data splits for model training
validation: -5, # this controls the validation data splits for model training
}
# various data augmentation techniques
# options: affine, elastic, downsample, motion, ghosting, bias, blur, gaussianNoise, swap
# keep/edit as needed
# all transforms: https://torchio.readthedocs.io/transforms/transforms.html?highlight=transforms
data_augmentation:
{
# 'spatial':{
# 'probability': 0.5
# },
# 'kspace':{
# 'probability': 0.5
# },
# 'bias':{
# 'probability': 0.5
# },
# 'blur':{
# 'probability': 0.5
# },
# 'noise':{
# 'probability': 0.5
# },
# 'swap':{
# 'probability': 0.5
# }
}
data_preprocessing:
{
data_augmentation: {}
# 'spatial':{
# 'probability': 0.5
# },
# 'kspace':{
# 'probability': 0.5
# },
# 'bias':{
# 'probability': 0.5
# },
# 'blur':{
# 'probability': 0.5
# },
# 'noise':{
# 'probability': 0.5
# },
# 'swap':{
# 'probability': 0.5
# }
data_preprocessing: {
# 'threshold':{
# 'min': 10,
# 'min': 10,
# 'max': 75
# },
# 'clip':{
# 'min': 10,
# 'min': 10,
# 'max': 75
# },
'normalize',
# },
"normalize",
# 'resample':{
# 'resolution': [1,2,3]
# },
#'resize': [128,128], # this is generally not recommended, as it changes image properties in unexpected ways
}
# data postprocessing node
data_postprocessing:
{
# 'largest_component',
# 'hole_filling'
}
data_postprocessing: {}
# 'largest_component',
# 'hole_filling'
# parallel training on HPC - here goes the command to prepend to send to a high performance computing
# cluster for parallel computing during multi-fold training
# not used for single fold training
Expand Down
72 changes: 71 additions & 1 deletion testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def write_temp_config_path(parameters_to_write):
return temp_config_path


# # these are helper functions to be used in other tests
# these are helper functions to be used in other tests


def test_train_segmentation_rad_2d(device):
Expand Down Expand Up @@ -3145,3 +3145,73 @@ def test_generic_data_split():
sanitize_outputDir()

print("passed")


def test_dynunet_model(device):
# is this test correct ??

# for more info : https://github.com/Project-MONAI/MONAI/blob/96bfda00c6bd290297f5e3514ea227c6be4d08b4/tests/test_dynunet.py
print("51: Starting test for Dynunet implementation")

# Reading the parameters
parameters = parseConfig(
testingDir + "/config_segmentation.yaml", version_check_flag=False
)
training_data, parameters["headers"] = parseTrainingCSV(
inputDir + "/train_3d_rad_segmentation.csv"
)
parameters["model"]["architecture"] = "dynunet"
parameters["model"]["dimension"] = 3
parameters["patch_size"] = patch_size["3D"]
parameters["model"]["kernel_size"] = [3, 3, 3, 1]
parameters["model"]["strides"] = [1, 1, 1, 1]
sanitize_outputDir()
TrainingManager(
dataframe=training_data,
outputDir=outputDir,
parameters=parameters,
device=device,
resume=False,
reset=True,
)

sanitize_outputDir()

print("passed")


def test_dynunet_mandatory_input_values():
"""
It tests the mandatatory inputs for dynunet model exist in the parameters["model"].
Mandotary inputs:
- kernel_size
- strides

It checks If an AssertionError is raised.
"""
# Reading the parameters from segmentation config.

with pytest.raises(AssertionError):
# Reading the parameters from segmentation config.
parameters = parseConfig(
testingDir + "/config_segmentation.yaml", version_check_flag=False
)

parameters["model"]["architecture"] = "dynunet"

training_data, parameters["headers"] = parseTrainingCSV(
inputDir + "/train_3d_rad_segmentation.csv"
)

sanitize_outputDir()
TrainingManager(
dataframe=training_data,
outputDir=outputDir,
parameters=parameters,
device=device,
resume=False,
reset=True,
)

sanitize_outputDir()
print("Passed")
sarthakpati marked this conversation as resolved.
Show resolved Hide resolved
Loading