Skip to content

Commit

Permalink
ENH: Add 3-D version of sysu architecture.
Browse files Browse the repository at this point in the history
  • Loading branch information
ntustison committed Aug 29, 2023
1 parent b035d09 commit 0998a08
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 4 deletions.
2 changes: 1 addition & 1 deletion antspynet/architectures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .create_custom_unet_model import create_nobrainer_unet_model_3d
from .create_custom_unet_model import create_hippmapp3r_unet_model_3d
from .create_custom_unet_model import create_hypermapp3r_unet_model_3d
from .create_custom_unet_model import create_sysu_media_unet_model_2d
from .create_custom_unet_model import create_sysu_media_unet_model_2d, create_sysu_media_unet_model_3d
from .create_custom_unet_model import create_hypothalamus_unet_model_3d
from .create_partial_convolution_unet_model import create_partial_convolution_unet_model_2d, create_partial_convolution_unet_model_3d
from .create_diffusion_probabilistic_unet_model import create_diffusion_probabilistic_unet_model_2d, create_diffusion_probabilistic_unet_model_3d
Expand Down
125 changes: 123 additions & 2 deletions antspynet/architectures/create_custom_unet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Add, Activation, BatchNormalization, Concatenate, ReLU, LeakyReLU,
Conv3D, Conv3DTranspose, Input, Lambda, MaxPooling3D,
ReLU, SpatialDropout3D, UpSampling3D,
Conv3D, Cropping3D, Conv3DTranspose, Input, Lambda, MaxPooling3D,
ReLU, SpatialDropout3D, UpSampling3D, ZeroPadding3D,
Cropping2D, Conv2D, Conv2DTranspose, MaxPooling2D, UpSampling2D, ZeroPadding2D)

from ..utilities import InstanceNormalization
Expand Down Expand Up @@ -569,6 +569,127 @@ def get_crop_shape(target_layer, reference_layer):

return(unet_model)

def create_sysu_media_unet_model_3d(input_image_size,
number_of_filters=None,
anatomy="wmh"):
"""
3-D variation of the sysu_media U-net architecture
Creates a keras model implementation of the u-net architecture
in the 2017 MICCAI WMH challenge by the sysu_medial team described
here:
https://pubmed.ncbi.nlm.nih.gov/30125711/
with the original implementation available here:
https://github.com/hongweilibran/wmh_ibbmTum
Arguments
---------
input_image_size : tuple
Tuple of length 4, (width, height, depth, channels).
anatomy : string
"wmh" or "claustrum"
Returns
-------
Keras model
A 3-D keras model defining the U-net network.
Example
-------
>>> image_size = (200, 200, 200)
>>> model = antspynet.create_sysu_media_unet_model_3d((*image_size, 1))
"""

def get_crop_shape(target_layer, reference_layer):

delta = K.int_shape(target_layer)[1] - K.int_shape(reference_layer)[1]
if delta % 2 != 0:
cropShape0 = (int(delta/2), int(delta/2) + 1)
else:
cropShape0 = (int(delta/2), int(delta/2))

delta = K.int_shape(target_layer)[2] - K.int_shape(reference_layer)[2]
if delta % 2 != 0:
cropShape1 = (int(delta/2), int(delta/2) + 1)
else:
cropShape1 = (int(delta/2), int(delta/2))

delta = K.int_shape(target_layer)[3] - K.int_shape(reference_layer)[3]
if delta % 2 != 0:
cropShape2 = (int(delta/2), int(delta/2) + 1)
else:
cropShape2 = (int(delta/2), int(delta/2))

return((cropShape0, cropShape1, cropShape2))

inputs = Input(shape=input_image_size)

if number_of_filters is None:
if anatomy == "wmh":
number_of_filters = (64, 96, 128, 256, 512)
elif anatomy == "claustrum":
number_of_filters = (32, 64, 96, 128, 256)

# encoding layers

encoding_layers = list()

outputs = inputs
for i in range(len(number_of_filters)):

kernel1 = 3
kernel2 = 3
if i == 0 and anatomy == "wmh":
kernel1 = 5
kernel2 = 5
elif i == 3:
kernel1 = 3
kernel2 = 4

outputs = Conv3D(filters=number_of_filters[i],
kernel_size=kernel1,
padding='same')(outputs)
outputs = Activation('relu')(outputs)
outputs = Conv3D(filters=number_of_filters[i],
kernel_size=kernel2,
padding='same')(outputs)
outputs = Activation('relu')(outputs)
encoding_layers.append(outputs)
if i < 4:
outputs = MaxPooling3D(pool_size=(2, 2, 2))(outputs)

# decoding layers

for i in range(len(encoding_layers)-2, -1, -1):
upsample_layer = UpSampling3D(size=(2, 2, 2))(outputs)
crop_shape = get_crop_shape(encoding_layers[i], upsample_layer)
cropped_layer = Cropping3D(cropping=crop_shape)(encoding_layers[i])
outputs = Concatenate(axis=-1)([upsample_layer, cropped_layer])
outputs = Conv3D(filters=number_of_filters[i],
kernel_size=3,
padding='same')(outputs)
outputs = Activation('relu')(outputs)
outputs = Conv3D(filters=number_of_filters[i],
kernel_size=3,
padding='same')(outputs)
outputs = Activation('relu')(outputs)

# final

crop_shape = get_crop_shape(inputs, outputs)
outputs = ZeroPadding3D(padding=crop_shape)(outputs)
outputs = Conv3D(filters=1,
kernel_size=1,
activation='sigmoid',
padding='same')(outputs)

unet_model = Model(inputs=inputs, outputs=outputs)

return(unet_model)

def create_hypothalamus_unet_model_3d(input_image_size):

Expand Down
2 changes: 1 addition & 1 deletion antspynet/architectures/create_unet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def create_unet_model_2d(input_image_size,
specific configuration add-ons/tweaks:
* "attentionGating" -- attention-unet variant in https://pubmed.ncbi.nlm.nih.gov/33288961/
* "nnUnetActivationStyle" -- U-net activation explained in https://pubmed.ncbi.nlm.nih.gov/33288961/
* "initialConvolutionalKernelSize[X]" -- Set the first two convolutional layer kernel sizes to X.
* "initialConvolutionKernelSize[X]" -- Set the first two convolutional layer kernel sizes to X.
Returns
-------
Expand Down

0 comments on commit 0998a08

Please sign in to comment.