From 8fd1d6cfab85354bad8a19c762123acde54d6b5d Mon Sep 17 00:00:00 2001 From: Nick Tustison Date: Tue, 21 May 2024 09:00:56 -0700 Subject: [PATCH] WIP: Deep flash 2. --- antspynet/utilities/__init__.py | 1 - antspynet/utilities/deep_flash.py | 1503 +++++++++-------- antspynet/utilities/get_antsxnet_data.py | 4 + antspynet/utilities/get_pretrained_network.py | 4 + 4 files changed, 772 insertions(+), 740 deletions(-) diff --git a/antspynet/utilities/__init__.py b/antspynet/utilities/__init__.py index 8732398..bdb231e 100644 --- a/antspynet/utilities/__init__.py +++ b/antspynet/utilities/__init__.py @@ -98,7 +98,6 @@ from .hypothalamus_segmentation import hypothalamus_segmentation from .hippmapp3r_segmentation import hippmapp3r_segmentation from .deep_flash import deep_flash -from .deep_flash import deep_flash_deprecated from .deep_atropos import deep_atropos from .desikan_killiany_tourville_labeling import desikan_killiany_tourville_labeling from .cerebellum_morphology import cerebellum_morphology diff --git a/antspynet/utilities/deep_flash.py b/antspynet/utilities/deep_flash.py index a772b9d..608ae3a 100644 --- a/antspynet/utilities/deep_flash.py +++ b/antspynet/utilities/deep_flash.py @@ -6,12 +6,13 @@ from tensorflow.keras import regularizers def deep_flash(t1, - t2=None, - do_preprocessing=True, - use_rank_intensity=True, - antsxnet_cache_directory=None, - verbose=False - ): + t2=None, + which_parcellation="yassa", + do_preprocessing=True, + use_rank_intensity=True, + antsxnet_cache_directory=None, + verbose=False + ): """ Hippocampal/Enthorhinal segmentation using "Deep Flash" @@ -37,7 +38,7 @@ def deep_flash(t1, Label 16: right CA1 Label 17: left subiculum Label 18: right subiculum - + Preprocessing on the training data consisted of: * n4 bias correction, * affine registration to the "deep flash" template. @@ -49,15 +50,19 @@ def deep_flash(t1, raw or preprocessed 3-D T1-weighted brain image. t2 : ANTsImage - Optional 3-D T2-weighted brain image. If specified, it is assumed to be - pre-aligned to the t1. + Optional 3-D T2-weighted brain image for yassa parcellation. If + specified, it is assumed to be pre-aligned to the t1. + + which_parcellation : string --- "yassa" + See above label descriptions. do_preprocessing : boolean See description above. use_rank_intensity : boolean If false, use histogram matching with cropped template ROI. Otherwise, - use a rank intensity transform on the cropped ROI. + use a rank intensity transform on the cropped ROI. Only for "yassa" + parcellation. antsxnet_cache_directory : string Destination directory for storing the downloaded template and model weights. @@ -86,715 +91,671 @@ def deep_flash(t1, if t1.dimension != 3: raise ValueError("Image dimension must be 3.") - ################################ - # - # Options temporarily taken from the user - # - ################################ - - # use_hierarchical_parcellation : boolean - # If True, use u-net model with additional outputs of the medial temporal lobe - # region, hippocampal, and entorhinal/perirhinal/parahippocampal regions. Otherwise - # the only additional output is the medial temporal lobe. - # - # use_contralaterality : boolean - # Use both hemispherical models to also predict the corresponding contralateral - # segmentation and use both sets of priors to produce the results. - - use_hierarchical_parcellation = True - use_contralaterality = True - - ################################ - # - # Preprocess images - # - ################################ - - t1_preprocessed = t1 - t1_mask = None - t1_preprocessed_flipped = None - t1_template = ants.image_read(get_antsxnet_data("deepFlashTemplateT1SkullStripped")) - template_transforms = None - if do_preprocessing: + if which_parcellation == "yassa": - if verbose: - print("Preprocessing T1.") - - # Brain extraction - probability_mask = brain_extraction(t1_preprocessed, modality="t1", - antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose) - t1_mask = ants.threshold_image(probability_mask, 0.5, 1, 1, 0) - t1_preprocessed = t1_preprocessed * t1_mask - - # Do bias correction - t1_preprocessed = ants.n4_bias_field_correction(t1_preprocessed, t1_mask, shrink_factor=4, verbose=verbose) - - # Warp to template - registration = ants.registration(fixed=t1_template, moving=t1_preprocessed, - type_of_transform="antsRegistrationSyNQuickRepro[a]", verbose=verbose) - template_transforms = dict(fwdtransforms=registration['fwdtransforms'], - invtransforms=registration['invtransforms']) - t1_preprocessed = registration['warpedmovout'] - - if use_contralaterality: - t1_preprocessed_array = t1_preprocessed.numpy() - t1_preprocessed_array_flipped = np.flip(t1_preprocessed_array, axis=0) - t1_preprocessed_flipped = ants.from_numpy(t1_preprocessed_array_flipped, - origin=t1_preprocessed.origin, - spacing=t1_preprocessed.spacing, - direction=t1_preprocessed.direction) - - t2_preprocessed = t2 - t2_preprocessed_flipped = None - t2_template = None - if t2 is not None: - t2_template = ants.image_read(get_antsxnet_data("deepFlashTemplateT2SkullStripped")) - t2_template = ants.copy_image_info(t1_template, t2_template) + ################################ + # + # Options temporarily taken from the user + # + ################################ + + # use_hierarchical_parcellation : boolean + # If True, use u-net model with additional outputs of the medial temporal lobe + # region, hippocampal, and entorhinal/perirhinal/parahippocampal regions. Otherwise + # the only additional output is the medial temporal lobe. + # + # use_contralaterality : boolean + # Use both hemispherical models to also predict the corresponding contralateral + # segmentation and use both sets of priors to produce the results. + + use_hierarchical_parcellation = True + use_contralaterality = True + + ################################ + # + # Preprocess images + # + ################################ + + t1_preprocessed = t1 + t1_mask = None + t1_preprocessed_flipped = None + t1_template = ants.image_read(get_antsxnet_data("deepFlashTemplateT1SkullStripped")) + template_transforms = None if do_preprocessing: if verbose: - print("Preprocessing T2.") + print("Preprocessing T1.") # Brain extraction - t2_preprocessed = t2_preprocessed * t1_mask + probability_mask = brain_extraction(t1_preprocessed, modality="t1", + antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose) + t1_mask = ants.threshold_image(probability_mask, 0.5, 1, 1, 0) + t1_preprocessed = t1_preprocessed * t1_mask # Do bias correction - t2_preprocessed = ants.n4_bias_field_correction(t2_preprocessed, t1_mask, shrink_factor=4, verbose=verbose) + t1_preprocessed = ants.n4_bias_field_correction(t1_preprocessed, t1_mask, shrink_factor=4, verbose=verbose) # Warp to template - t2_preprocessed = ants.apply_transforms(fixed=t1_template, - moving=t2_preprocessed, transformlist=template_transforms['fwdtransforms'], - verbose=verbose) + registration = ants.registration(fixed=t1_template, moving=t1_preprocessed, + type_of_transform="antsRegistrationSyNQuickRepro[a]", verbose=verbose) + template_transforms = dict(fwdtransforms=registration['fwdtransforms'], + invtransforms=registration['invtransforms']) + t1_preprocessed = registration['warpedmovout'] if use_contralaterality: - t2_preprocessed_array = t2_preprocessed.numpy() - t2_preprocessed_array_flipped = np.flip(t2_preprocessed_array, axis=0) - t2_preprocessed_flipped = ants.from_numpy(t2_preprocessed_array_flipped, - origin=t2_preprocessed.origin, - spacing=t2_preprocessed.spacing, - direction=t2_preprocessed.direction) - - probability_images = list() - labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) - image_size = (64, 64, 96) - - ################################ - # - # Process left/right in split networks - # - ################################ - - ################################ - # - # Download spatial priors - # - ################################ - - spatial_priors_file_name_path = get_antsxnet_data("deepFlashPriors", - antsxnet_cache_directory=antsxnet_cache_directory) - spatial_priors = ants.image_read(spatial_priors_file_name_path) - priors_image_list = ants.ndimage_to_list(spatial_priors) - for i in range(len(priors_image_list)): - priors_image_list[i] = ants.copy_image_info(t1_preprocessed, priors_image_list[i]) - - labels_left = labels[1::2] - priors_image_left_list = priors_image_list[1::2] - probability_images_left = list() - foreground_probability_images_left = list() - lower_bound_left = (76, 74, 56) - upper_bound_left = (140, 138, 152) - tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left, upper_bound_left) - origin_left = tmp_cropped.origin - - spacing = tmp_cropped.spacing - direction = tmp_cropped.direction - - t1_template_roi_left = ants.crop_indices(t1_template, lower_bound_left, upper_bound_left) - t1_template_roi_left = (t1_template_roi_left - t1_template_roi_left.min()) / (t1_template_roi_left.max() - t1_template_roi_left.min()) * 2.0 - 1.0 - t2_template_roi_left = None - if t2_template is not None: - t2_template_roi_left = ants.crop_indices(t2_template, lower_bound_left, upper_bound_left) - t2_template_roi_left = (t2_template_roi_left - t2_template_roi_left.min()) / (t2_template_roi_left.max() - t2_template_roi_left.min()) * 2.0 - 1.0 - - labels_right = labels[2::2] - priors_image_right_list = priors_image_list[2::2] - probability_images_right = list() - foreground_probability_images_right = list() - lower_bound_right = (20, 74, 56) - upper_bound_right = (84, 138, 152) - tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right, upper_bound_right) - origin_right = tmp_cropped.origin - - t1_template_roi_right = ants.crop_indices(t1_template, lower_bound_right, upper_bound_right) - t1_template_roi_right = (t1_template_roi_right - t1_template_roi_right.min()) / (t1_template_roi_right.max() - t1_template_roi_right.min()) * 2.0 - 1.0 - t2_template_roi_right = None - if t2_template is not None: - t2_template_roi_right = ants.crop_indices(t2_template, lower_bound_right, upper_bound_right) - t2_template_roi_right = (t2_template_roi_right - t2_template_roi_right.min()) / (t2_template_roi_right.max() - t2_template_roi_right.min()) * 2.0 - 1.0 - - - ################################ - # - # Create model - # - ################################ - - channel_size = 1 + len(labels_left) - if t2 is not None: - channel_size += 1 - - number_of_classification_labels = 1 + len(labels_left) - - unet_model = create_unet_model_3d((*image_size, channel_size), - number_of_outputs=number_of_classification_labels, mode="classification", - number_of_filters=(32, 64, 96, 128, 256), - convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2), - dropout_rate=0.0, weight_decay=0) - - penultimate_layer = unet_model.layers[-2].output - - # medial temporal lobe - output1 = Conv3D(filters=1, - kernel_size=(1, 1, 1), - activation='sigmoid', - kernel_regularizer=regularizers.l2(0.0))(penultimate_layer) - - if use_hierarchical_parcellation: - - # EC, perirhinal, and parahippo. - output2 = Conv3D(filters=1, - kernel_size=(1, 1, 1), - activation='sigmoid', - kernel_regularizer=regularizers.l2(0.0))(penultimate_layer) + t1_preprocessed_array = t1_preprocessed.numpy() + t1_preprocessed_array_flipped = np.flip(t1_preprocessed_array, axis=0) + t1_preprocessed_flipped = ants.from_numpy(t1_preprocessed_array_flipped, + origin=t1_preprocessed.origin, + spacing=t1_preprocessed.spacing, + direction=t1_preprocessed.direction) + + t2_preprocessed = t2 + t2_preprocessed_flipped = None + t2_template = None + if t2 is not None: + t2_template = ants.image_read(get_antsxnet_data("deepFlashTemplateT2SkullStripped")) + t2_template = ants.copy_image_info(t1_template, t2_template) + if do_preprocessing: - # Hippocampus - output3 = Conv3D(filters=1, + if verbose: + print("Preprocessing T2.") + + # Brain extraction + t2_preprocessed = t2_preprocessed * t1_mask + + # Do bias correction + t2_preprocessed = ants.n4_bias_field_correction(t2_preprocessed, t1_mask, shrink_factor=4, verbose=verbose) + + # Warp to template + t2_preprocessed = ants.apply_transforms(fixed=t1_template, + moving=t2_preprocessed, transformlist=template_transforms['fwdtransforms'], + verbose=verbose) + + if use_contralaterality: + t2_preprocessed_array = t2_preprocessed.numpy() + t2_preprocessed_array_flipped = np.flip(t2_preprocessed_array, axis=0) + t2_preprocessed_flipped = ants.from_numpy(t2_preprocessed_array_flipped, + origin=t2_preprocessed.origin, + spacing=t2_preprocessed.spacing, + direction=t2_preprocessed.direction) + + probability_images = list() + labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) + image_size = (64, 64, 96) + + ################################ + # + # Process left/right in split networks + # + ################################ + + ################################ + # + # Download spatial priors + # + ################################ + + spatial_priors_file_name_path = get_antsxnet_data("deepFlashPriors", + antsxnet_cache_directory=antsxnet_cache_directory) + spatial_priors = ants.image_read(spatial_priors_file_name_path) + priors_image_list = ants.ndimage_to_list(spatial_priors) + for i in range(len(priors_image_list)): + priors_image_list[i] = ants.copy_image_info(t1_preprocessed, priors_image_list[i]) + + labels_left = labels[1::2] + priors_image_left_list = priors_image_list[1::2] + probability_images_left = list() + foreground_probability_images_left = list() + lower_bound_left = (76, 74, 56) + upper_bound_left = (140, 138, 152) + tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left, upper_bound_left) + origin_left = tmp_cropped.origin + + spacing = tmp_cropped.spacing + direction = tmp_cropped.direction + + t1_template_roi_left = ants.crop_indices(t1_template, lower_bound_left, upper_bound_left) + t1_template_roi_left = (t1_template_roi_left - t1_template_roi_left.min()) / (t1_template_roi_left.max() - t1_template_roi_left.min()) * 2.0 - 1.0 + t2_template_roi_left = None + if t2_template is not None: + t2_template_roi_left = ants.crop_indices(t2_template, lower_bound_left, upper_bound_left) + t2_template_roi_left = (t2_template_roi_left - t2_template_roi_left.min()) / (t2_template_roi_left.max() - t2_template_roi_left.min()) * 2.0 - 1.0 + + labels_right = labels[2::2] + priors_image_right_list = priors_image_list[2::2] + probability_images_right = list() + foreground_probability_images_right = list() + lower_bound_right = (20, 74, 56) + upper_bound_right = (84, 138, 152) + tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right, upper_bound_right) + origin_right = tmp_cropped.origin + + t1_template_roi_right = ants.crop_indices(t1_template, lower_bound_right, upper_bound_right) + t1_template_roi_right = (t1_template_roi_right - t1_template_roi_right.min()) / (t1_template_roi_right.max() - t1_template_roi_right.min()) * 2.0 - 1.0 + t2_template_roi_right = None + if t2_template is not None: + t2_template_roi_right = ants.crop_indices(t2_template, lower_bound_right, upper_bound_right) + t2_template_roi_right = (t2_template_roi_right - t2_template_roi_right.min()) / (t2_template_roi_right.max() - t2_template_roi_right.min()) * 2.0 - 1.0 + + + ################################ + # + # Create model + # + ################################ + + channel_size = 1 + len(labels_left) + if t2 is not None: + channel_size += 1 + + number_of_classification_labels = 1 + len(labels_left) + + unet_model = create_unet_model_3d((*image_size, channel_size), + number_of_outputs=number_of_classification_labels, mode="classification", + number_of_filters=(32, 64, 96, 128, 256), + convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2), + dropout_rate=0.0, weight_decay=0) + + penultimate_layer = unet_model.layers[-2].output + + # medial temporal lobe + output1 = Conv3D(filters=1, kernel_size=(1, 1, 1), activation='sigmoid', kernel_regularizer=regularizers.l2(0.0))(penultimate_layer) - unet_model = Model(inputs=unet_model.input, outputs=[unet_model.output, output1, output2, output3]) - else: - unet_model = Model(inputs=unet_model.input, outputs=[unet_model.output, output1]) - - ################################ - # - # Left: build model and load weights - # - ################################ - - network_name = 'deepFlashLeftT1' - if t2 is not None: - network_name = 'deepFlashLeftBoth' - - if use_hierarchical_parcellation: - network_name += "Hierarchical" - - if use_rank_intensity: - network_name += "_ri" - - if verbose: - print("DeepFlash: retrieving model weights (left).") - weights_file_name = get_pretrained_network(network_name, antsxnet_cache_directory=antsxnet_cache_directory) - unet_model.load_weights(weights_file_name) - - ################################ - # - # Left: do prediction and normalize to native space - # - ################################ - - if verbose: - print("Prediction (left).") - - batchX = None - if use_contralaterality: - batchX = np.zeros((2, *image_size, channel_size)) - else: - batchX = np.zeros((1, *image_size, channel_size)) - - t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left, upper_bound_left) - if use_rank_intensity: - t1_cropped = ants.rank_intensity(t1_cropped) - else: - t1_cropped = ants.histogram_match_image(t1_cropped, t1_template_roi_left, 255, 64, False) - batchX[0,:,:,:,0] = t1_cropped.numpy() - if use_contralaterality: - t1_cropped = ants.crop_indices(t1_preprocessed_flipped, lower_bound_left, upper_bound_left) - if use_rank_intensity: - t1_cropped = ants.rank_intensity(t1_cropped) - else: - t1_cropped = ants.histogram_match_image(t1_cropped, t1_template_roi_left, 255, 64, False) - batchX[1,:,:,:,0] = t1_cropped.numpy() - if t2 is not None: - t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_left, upper_bound_left) - if use_rank_intensity: - t2_cropped = ants.rank_intensity(t2_cropped) - else: - t2_cropped = ants.histogram_match_image(t2_cropped, t2_template_roi_left, 255, 64, False) - batchX[0,:,:,:,1] = t2_cropped.numpy() - if use_contralaterality: - t2_cropped = ants.crop_indices(t2_preprocessed_flipped, lower_bound_left, upper_bound_left) - if use_rank_intensity: - t2_cropped = ants.rank_intensity(t2_cropped) - else: - t2_cropped = ants.histogram_match_image(t2_cropped, t2_template_roi_left, 255, 64, False) - batchX[1,:,:,:,1] = t2_cropped.numpy() + if use_hierarchical_parcellation: - for i in range(len(priors_image_left_list)): - cropped_prior = ants.crop_indices(priors_image_left_list[i], lower_bound_left, upper_bound_left) - for j in range(batchX.shape[0]): - batchX[j,:,:,:,i + (channel_size - len(labels_left))] = cropped_prior.numpy() + # EC, perirhinal, and parahippo. + output2 = Conv3D(filters=1, + kernel_size=(1, 1, 1), + activation='sigmoid', + kernel_regularizer=regularizers.l2(0.0))(penultimate_layer) - predicted_data = unet_model.predict(batchX, verbose=verbose) + # Hippocampus + output3 = Conv3D(filters=1, + kernel_size=(1, 1, 1), + activation='sigmoid', + kernel_regularizer=regularizers.l2(0.0))(penultimate_layer) - for i in range(1 + len(labels_left)): - for j in range(predicted_data[0].shape[0]): - probability_image = \ - ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]), - origin=origin_left, spacing=spacing, direction=direction) - if i > 0: - probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0) - else: - probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0 + 1) + unet_model = Model(inputs=unet_model.input, outputs=[unet_model.output, output1, output2, output3]) + else: + unet_model = Model(inputs=unet_model.input, outputs=[unet_model.output, output1]) - if j == 1: # flipped - probability_array_flipped = np.flip(probability_image.numpy(), axis=0) - probability_image = ants.from_numpy(probability_array_flipped, - origin=probability_image.origin, spacing=probability_image.spacing, - direction=probability_image.direction) + ################################ + # + # Left: build model and load weights + # + ################################ - if do_preprocessing: - probability_image = ants.apply_transforms(fixed=t1, - moving=probability_image, - transformlist=template_transforms['invtransforms'], - whichtoinvert=[True], interpolator="linear", verbose=verbose) - - if j == 0: # not flipped - probability_images_left.append(probability_image) - else: # flipped - probability_images_right.append(probability_image) - - - ################################ - # - # Left: do prediction of mtl, hippocampal, and ec regions and normalize to native space - # - ################################ - - for i in range(1, len(predicted_data)): - for j in range(predicted_data[i].shape[0]): - probability_image = \ - ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]), - origin=origin_left, spacing=spacing, direction=direction) - probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0) - - if j == 1: # flipped - probability_array_flipped = np.flip(probability_image.numpy(), axis=0) - probability_image = ants.from_numpy(probability_array_flipped, - origin=probability_image.origin, spacing=probability_image.spacing, - direction=probability_image.direction) + network_name = 'deepFlashLeftT1' + if t2 is not None: + network_name = 'deepFlashLeftBoth' - if do_preprocessing: - probability_image = ants.apply_transforms(fixed=t1, - moving=probability_image, - transformlist=template_transforms['invtransforms'], - whichtoinvert=[True], interpolator="linear", verbose=verbose) + if use_hierarchical_parcellation: + network_name += "Hierarchical" - if j == 0: # not flipped - foreground_probability_images_left.append(probability_image) - else: - foreground_probability_images_right.append(probability_image) - - ################################ - # - # Right: build model and load weights - # - ################################ - - network_name = 'deepFlashRightT1' - if t2 is not None: - network_name = 'deepFlashRightBoth' - - if use_hierarchical_parcellation: - network_name += "Hierarchical" - - if use_rank_intensity: - network_name += "_ri" - - if verbose: - print("DeepFlash: retrieving model weights (right).") - weights_file_name = get_pretrained_network(network_name, antsxnet_cache_directory=antsxnet_cache_directory) - unet_model.load_weights(weights_file_name) - - ################################ - # - # Right: do prediction and normalize to native space - # - ################################ - - if verbose: - print("Prediction (right).") - - batchX = None - if use_contralaterality: - batchX = np.zeros((2, *image_size, channel_size)) - else: - batchX = np.zeros((1, *image_size, channel_size)) - - t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right, upper_bound_right) - if use_rank_intensity: - t1_cropped = ants.rank_intensity(t1_cropped) - else: - t1_cropped = ants.histogram_match_image(t1_cropped, t1_template_roi_right, 255, 64, False) - batchX[0,:,:,:,0] = t1_cropped.numpy() - if use_contralaterality: - t1_cropped = ants.crop_indices(t1_preprocessed_flipped, lower_bound_right, upper_bound_right) if use_rank_intensity: - t1_cropped = ants.rank_intensity(t1_cropped) + network_name += "_ri" + + if verbose: + print("DeepFlash: retrieving model weights (left).") + weights_file_name = get_pretrained_network(network_name, antsxnet_cache_directory=antsxnet_cache_directory) + unet_model.load_weights(weights_file_name) + + ################################ + # + # Left: do prediction and normalize to native space + # + ################################ + + if verbose: + print("Prediction (left).") + + batchX = None + if use_contralaterality: + batchX = np.zeros((2, *image_size, channel_size)) else: - t1_cropped = ants.histogram_match_image(t1_cropped, t1_template_roi_right, 255, 64, False) - batchX[1,:,:,:,0] = t1_cropped.numpy() - if t2 is not None: - t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_right, upper_bound_right) + batchX = np.zeros((1, *image_size, channel_size)) + + t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left, upper_bound_left) if use_rank_intensity: - t2_cropped = ants.rank_intensity(t2_cropped) + t1_cropped = ants.rank_intensity(t1_cropped) else: - t2_cropped = ants.histogram_match_image(t2_cropped, t2_template_roi_right, 255, 64, False) - batchX[0,:,:,:,1] = t2_cropped.numpy() + t1_cropped = ants.histogram_match_image(t1_cropped, t1_template_roi_left, 255, 64, False) + batchX[0,:,:,:,0] = t1_cropped.numpy() if use_contralaterality: - t2_cropped = ants.crop_indices(t2_preprocessed_flipped, lower_bound_right, upper_bound_right) + t1_cropped = ants.crop_indices(t1_preprocessed_flipped, lower_bound_left, upper_bound_left) + if use_rank_intensity: + t1_cropped = ants.rank_intensity(t1_cropped) + else: + t1_cropped = ants.histogram_match_image(t1_cropped, t1_template_roi_left, 255, 64, False) + batchX[1,:,:,:,0] = t1_cropped.numpy() + if t2 is not None: + t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_left, upper_bound_left) if use_rank_intensity: t2_cropped = ants.rank_intensity(t2_cropped) else: - t2_cropped = ants.histogram_match_image(t2_cropped, t2_template_roi_right, 255, 64, False) - batchX[1,:,:,:,1] = t2_cropped.numpy() - - for i in range(len(priors_image_right_list)): - cropped_prior = ants.crop_indices(priors_image_right_list[i], lower_bound_right, upper_bound_right) - for j in range(batchX.shape[0]): - batchX[j,:,:,:,i + (channel_size - len(labels_right))] = cropped_prior.numpy() - - predicted_data = unet_model.predict(batchX, verbose=verbose) + t2_cropped = ants.histogram_match_image(t2_cropped, t2_template_roi_left, 255, 64, False) + batchX[0,:,:,:,1] = t2_cropped.numpy() + if use_contralaterality: + t2_cropped = ants.crop_indices(t2_preprocessed_flipped, lower_bound_left, upper_bound_left) + if use_rank_intensity: + t2_cropped = ants.rank_intensity(t2_cropped) + else: + t2_cropped = ants.histogram_match_image(t2_cropped, t2_template_roi_left, 255, 64, False) + batchX[1,:,:,:,1] = t2_cropped.numpy() - for i in range(1 + len(labels_right)): - for j in range(predicted_data[0].shape[0]): - probability_image = \ - ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]), - origin=origin_right, spacing=spacing, direction=direction) - if i > 0: - probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0) - else: - probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0 + 1) + for i in range(len(priors_image_left_list)): + cropped_prior = ants.crop_indices(priors_image_left_list[i], lower_bound_left, upper_bound_left) + for j in range(batchX.shape[0]): + batchX[j,:,:,:,i + (channel_size - len(labels_left))] = cropped_prior.numpy() - if j == 1: # flipped - probability_array_flipped = np.flip(probability_image.numpy(), axis=0) - probability_image = ants.from_numpy(probability_array_flipped, - origin=probability_image.origin, spacing=probability_image.spacing, - direction=probability_image.direction) + predicted_data = unet_model.predict(batchX, verbose=verbose) - if do_preprocessing: - probability_image = ants.apply_transforms(fixed=t1, - moving=probability_image, - transformlist=template_transforms['invtransforms'], - whichtoinvert=[True], interpolator="linear", verbose=verbose) - - if j == 0: # not flipped - if use_contralaterality: - probability_images_right[i] = (probability_images_right[i] + probability_image) / 2 + for i in range(1 + len(labels_left)): + for j in range(predicted_data[0].shape[0]): + probability_image = \ + ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]), + origin=origin_left, spacing=spacing, direction=direction) + if i > 0: + probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0) else: + probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0 + 1) + + if j == 1: # flipped + probability_array_flipped = np.flip(probability_image.numpy(), axis=0) + probability_image = ants.from_numpy(probability_array_flipped, + origin=probability_image.origin, spacing=probability_image.spacing, + direction=probability_image.direction) + + if do_preprocessing: + probability_image = ants.apply_transforms(fixed=t1, + moving=probability_image, + transformlist=template_transforms['invtransforms'], + whichtoinvert=[True], interpolator="linear", verbose=verbose) + + if j == 0: # not flipped + probability_images_left.append(probability_image) + else: # flipped probability_images_right.append(probability_image) - else: # flipped - probability_images_left[i] = (probability_images_left[i] + probability_image) / 2 - ################################ - # - # Right: do prediction of mtl, hippocampal, and ec regions and normalize to native space - # - ################################ + ################################ + # + # Left: do prediction of mtl, hippocampal, and ec regions and normalize to native space + # + ################################ - for i in range(1, len(predicted_data)): - for j in range(predicted_data[i].shape[0]): - probability_image = \ - ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]), - origin=origin_right, spacing=spacing, direction=direction) - probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0) + for i in range(1, len(predicted_data)): + for j in range(predicted_data[i].shape[0]): + probability_image = \ + ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]), + origin=origin_left, spacing=spacing, direction=direction) + probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0) - if j == 1: # flipped - probability_array_flipped = np.flip(probability_image.numpy(), axis=0) - probability_image = ants.from_numpy(probability_array_flipped, - origin=probability_image.origin, spacing=probability_image.spacing, - direction=probability_image.direction) + if j == 1: # flipped + probability_array_flipped = np.flip(probability_image.numpy(), axis=0) + probability_image = ants.from_numpy(probability_array_flipped, + origin=probability_image.origin, spacing=probability_image.spacing, + direction=probability_image.direction) - if do_preprocessing: - probability_image = ants.apply_transforms(fixed=t1, - moving=probability_image, - transformlist=template_transforms['invtransforms'], - whichtoinvert=[True], interpolator="linear", verbose=verbose) - - if j == 0: # not flipped - if use_contralaterality: - foreground_probability_images_right[i-1] = (foreground_probability_images_right[i-1] + probability_image) / 2 + if do_preprocessing: + probability_image = ants.apply_transforms(fixed=t1, + moving=probability_image, + transformlist=template_transforms['invtransforms'], + whichtoinvert=[True], interpolator="linear", verbose=verbose) + + if j == 0: # not flipped + foreground_probability_images_left.append(probability_image) else: foreground_probability_images_right.append(probability_image) - else: - foreground_probability_images_left[i-1] = (foreground_probability_images_left[i-1] + probability_image) / 2 - - ################################ - # - # Combine priors - # - ################################ - - probability_background_image = ants.image_clone(t1) * 0 - for i in range(1, len(probability_images_left)): - probability_background_image += probability_images_left[i] - for i in range(1, len(probability_images_right)): - probability_background_image += probability_images_right[i] - - probability_images.append(probability_background_image * -1 + 1) - for i in range(1, len(probability_images_left)): - probability_images.append(probability_images_left[i]) - probability_images.append(probability_images_right[i]) - - ################################ - # - # Convert probability images to segmentation - # - ################################ - - # image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1) - # segmentation_matrix = np.argmax(image_matrix, axis=0) - # segmentation_image = ants.matrix_to_images( - # np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0] - - image_matrix = ants.image_list_to_matrix(probability_images[1:(len(probability_images))], t1 * 0 + 1) - background_foreground_matrix = np.stack([ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1), - np.expand_dims(np.sum(image_matrix, axis=0), axis=0)]) - foreground_matrix = np.argmax(background_foreground_matrix, axis=0) - segmentation_matrix = (np.argmax(image_matrix, axis=0) + 1) * foreground_matrix - segmentation_image = ants.matrix_to_images( - np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0] - - relabeled_image = ants.image_clone(segmentation_image) - for i in range(len(labels)): - relabeled_image[segmentation_image==i] = labels[i] - - foreground_probability_images = list() - for i in range(len(foreground_probability_images_left)): - foreground_probability_images.append(foreground_probability_images_left[i] + foreground_probability_images_right[i]) - - return_dict = None - if use_hierarchical_parcellation: - return_dict = {'segmentation_image' : relabeled_image, - 'probability_images' : probability_images, - 'medial_temporal_lobe_probability_image' : foreground_probability_images[0], - 'other_region_probability_image' : foreground_probability_images[1], - 'hippocampal_probability_image' : foreground_probability_images[2] - } - else: - return_dict = {'segmentation_image' : relabeled_image, - 'probability_images' : probability_images, - 'medial_temporal_lobe_probability_image' : foreground_probability_images[0] - } - return(return_dict) + ################################ + # + # Right: build model and load weights + # + ################################ + network_name = 'deepFlashRightT1' + if t2 is not None: + network_name = 'deepFlashRightBoth' -def deep_flash_deprecated(t1, - do_preprocessing=True, - do_per_hemisphere=True, - which_hemisphere_models="new", - antsxnet_cache_directory=None, - verbose=False - ): + if use_hierarchical_parcellation: + network_name += "Hierarchical" - """ - Hippocampal/Enthorhinal segmentation using "Deep Flash" + if use_rank_intensity: + network_name += "_ri" - Perform hippocampal/entorhinal segmentation in T1 images using - labels from Mike Yassa's lab + if verbose: + print("DeepFlash: retrieving model weights (right).") + weights_file_name = get_pretrained_network(network_name, antsxnet_cache_directory=antsxnet_cache_directory) + unet_model.load_weights(weights_file_name) - https://faculty.sites.uci.edu/myassa/ + ################################ + # + # Right: do prediction and normalize to native space + # + ################################ - The labeling is as follows: - Label 0 : background - Label 5 : left aLEC - Label 6 : right aLEC - Label 7 : left pMEC - Label 8 : right pMEC - Label 9 : left perirhinal - Label 10: right perirhinal - Label 11: left parahippocampal - Label 12: right parahippocampal - Label 13: left DG/CA3 - Label 14: right DG/CA3 - Label 15: left CA1 - Label 16: right CA1 - Label 17: left subiculum - Label 18: right subiculum + if verbose: + print("Prediction (right).") - Preprocessing on the training data consisted of: - * n4 bias correction, - * denoising, - * brain extraction, and - * affine registration to MNI. - The input T1 should undergo the same steps. If the input T1 is the raw - T1, these steps can be performed by the internal preprocessing, i.e. set - do_preprocessing = True + batchX = None + if use_contralaterality: + batchX = np.zeros((2, *image_size, channel_size)) + else: + batchX = np.zeros((1, *image_size, channel_size)) - Arguments - --------- - t1 : ANTsImage - raw or preprocessed 3-D T1-weighted brain image. + t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right, upper_bound_right) + if use_rank_intensity: + t1_cropped = ants.rank_intensity(t1_cropped) + else: + t1_cropped = ants.histogram_match_image(t1_cropped, t1_template_roi_right, 255, 64, False) + batchX[0,:,:,:,0] = t1_cropped.numpy() + if use_contralaterality: + t1_cropped = ants.crop_indices(t1_preprocessed_flipped, lower_bound_right, upper_bound_right) + if use_rank_intensity: + t1_cropped = ants.rank_intensity(t1_cropped) + else: + t1_cropped = ants.histogram_match_image(t1_cropped, t1_template_roi_right, 255, 64, False) + batchX[1,:,:,:,0] = t1_cropped.numpy() + if t2 is not None: + t2_cropped = ants.crop_indices(t2_preprocessed, lower_bound_right, upper_bound_right) + if use_rank_intensity: + t2_cropped = ants.rank_intensity(t2_cropped) + else: + t2_cropped = ants.histogram_match_image(t2_cropped, t2_template_roi_right, 255, 64, False) + batchX[0,:,:,:,1] = t2_cropped.numpy() + if use_contralaterality: + t2_cropped = ants.crop_indices(t2_preprocessed_flipped, lower_bound_right, upper_bound_right) + if use_rank_intensity: + t2_cropped = ants.rank_intensity(t2_cropped) + else: + t2_cropped = ants.histogram_match_image(t2_cropped, t2_template_roi_right, 255, 64, False) + batchX[1,:,:,:,1] = t2_cropped.numpy() - do_preprocessing : boolean - See description above. + for i in range(len(priors_image_right_list)): + cropped_prior = ants.crop_indices(priors_image_right_list[i], lower_bound_right, upper_bound_right) + for j in range(batchX.shape[0]): + batchX[j,:,:,:,i + (channel_size - len(labels_right))] = cropped_prior.numpy() - do_per_hemisphere : boolean - If True, do prediction based on separate networks per hemisphere. Otherwise, - use the single network trained for both hemispheres. + predicted_data = unet_model.predict(batchX, verbose=verbose) - antsxnet_cache_directory : string - Destination directory for storing the downloaded template and model weights. - Since these can be reused, if is None, these data will be downloaded to a - ~/.keras/ANTsXNet/. + for i in range(1 + len(labels_right)): + for j in range(predicted_data[0].shape[0]): + probability_image = \ + ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]), + origin=origin_right, spacing=spacing, direction=direction) + if i > 0: + probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0) + else: + probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0 + 1) - verbose : boolean - Print progress to the screen. + if j == 1: # flipped + probability_array_flipped = np.flip(probability_image.numpy(), axis=0) + probability_image = ants.from_numpy(probability_array_flipped, + origin=probability_image.origin, spacing=probability_image.spacing, + direction=probability_image.direction) - Returns - ------- - List consisting of the segmentation image and probability images for - each label. + if do_preprocessing: + probability_image = ants.apply_transforms(fixed=t1, + moving=probability_image, + transformlist=template_transforms['invtransforms'], + whichtoinvert=[True], interpolator="linear", verbose=verbose) - Example - ------- - >>> image = ants.image_read("t1.nii.gz") - >>> flash = deep_flash(image) - """ + if j == 0: # not flipped + if use_contralaterality: + probability_images_right[i] = (probability_images_right[i] + probability_image) / 2 + else: + probability_images_right.append(probability_image) + else: # flipped + probability_images_left[i] = (probability_images_left[i] + probability_image) / 2 - from ..architectures import create_unet_model_3d - from ..utilities import get_pretrained_network - from ..utilities import get_antsxnet_data - from ..utilities import preprocess_brain_image - from ..utilities import pad_or_crop_image_to_size - print("This function is deprecated. Please update to deep_flash().") + ################################ + # + # Right: do prediction of mtl, hippocampal, and ec regions and normalize to native space + # + ################################ + + for i in range(1, len(predicted_data)): + for j in range(predicted_data[i].shape[0]): + probability_image = \ + ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]), + origin=origin_right, spacing=spacing, direction=direction) + probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0) - if t1.dimension != 3: - raise ValueError("Image dimension must be 3.") + if j == 1: # flipped + probability_array_flipped = np.flip(probability_image.numpy(), axis=0) + probability_image = ants.from_numpy(probability_array_flipped, + origin=probability_image.origin, spacing=probability_image.spacing, + direction=probability_image.direction) + + if do_preprocessing: + probability_image = ants.apply_transforms(fixed=t1, + moving=probability_image, + transformlist=template_transforms['invtransforms'], + whichtoinvert=[True], interpolator="linear", verbose=verbose) + + if j == 0: # not flipped + if use_contralaterality: + foreground_probability_images_right[i-1] = (foreground_probability_images_right[i-1] + probability_image) / 2 + else: + foreground_probability_images_right.append(probability_image) + else: + foreground_probability_images_left[i-1] = (foreground_probability_images_left[i-1] + probability_image) / 2 - if antsxnet_cache_directory == None: - antsxnet_cache_directory = "ANTsXNet" + ################################ + # + # Combine priors + # + ################################ + + probability_background_image = ants.image_clone(t1) * 0 + for i in range(1, len(probability_images_left)): + probability_background_image += probability_images_left[i] + for i in range(1, len(probability_images_right)): + probability_background_image += probability_images_right[i] - ################################ - # - # Preprocess images - # - ################################ + probability_images.append(probability_background_image * -1 + 1) + for i in range(1, len(probability_images_left)): + probability_images.append(probability_images_left[i]) + probability_images.append(probability_images_right[i]) - t1_preprocessed = t1 - if do_preprocessing: - t1_preprocessing = preprocess_brain_image(t1, - truncate_intensity=(0.01, 0.99), - brain_extraction_modality="t1", - template="croppedMni152", - template_transform_type="antsRegistrationSyNQuickRepro[a]", - do_bias_correction=True, - do_denoising=True, - antsxnet_cache_directory=antsxnet_cache_directory, - verbose=verbose) - t1_preprocessed = t1_preprocessing["preprocessed_image"] * t1_preprocessing['brain_mask'] + ################################ + # + # Convert probability images to segmentation + # + ################################ + + # image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1) + # segmentation_matrix = np.argmax(image_matrix, axis=0) + # segmentation_image = ants.matrix_to_images( + # np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0] - probability_images = list() - labels = (0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) + image_matrix = ants.image_list_to_matrix(probability_images[1:(len(probability_images))], t1 * 0 + 1) + background_foreground_matrix = np.stack([ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1), + np.expand_dims(np.sum(image_matrix, axis=0), axis=0)]) + foreground_matrix = np.argmax(background_foreground_matrix, axis=0) + segmentation_matrix = (np.argmax(image_matrix, axis=0) + 1) * foreground_matrix + segmentation_image = ants.matrix_to_images( + np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0] - ################################ - # - # Process left/right in same network - # - ################################ + relabeled_image = ants.image_clone(segmentation_image) + for i in range(len(labels)): + relabeled_image[segmentation_image==i] = labels[i] + + foreground_probability_images = list() + for i in range(len(foreground_probability_images_left)): + foreground_probability_images.append(foreground_probability_images_left[i] + foreground_probability_images_right[i]) + + return_dict = None + if use_hierarchical_parcellation: + return_dict = {'segmentation_image' : relabeled_image, + 'probability_images' : probability_images, + 'medial_temporal_lobe_probability_image' : foreground_probability_images[0], + 'other_region_probability_image' : foreground_probability_images[1], + 'hippocampal_probability_image' : foreground_probability_images[2] + } + else: + return_dict = {'segmentation_image' : relabeled_image, + 'probability_images' : probability_images, + 'medial_temporal_lobe_probability_image' : foreground_probability_images[0] + } - if do_per_hemisphere == False: + return(return_dict) + + elif which_parcellation == "wip": + + use_contralaterality = True ################################ # - # Build model and load weights + # Preprocess images # ################################ - template_size = (160, 192, 160) + t1_preprocessed = t1 + t1_mask = None + t1_preprocessed_flipped = None + t1_template = ants.image_read(get_antsxnet_data("deepFlashTemplate2T1SkullStripped")) + template_transforms = None + if do_preprocessing: - unet_model = create_unet_model_3d((*template_size, 1), - number_of_outputs=len(labels), - number_of_layers=4, number_of_filters_at_base_layer=8, dropout_rate=0.0, - convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2), - weight_decay=1e-5, additional_options=("attentionGating",)) + if verbose: + print("Preprocessing T1.") - if verbose: - print("DeepFlash: retrieving model weights.") + # Brain extraction + probability_mask = brain_extraction(t1_preprocessed, modality="t1", + antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose) + t1_mask = ants.threshold_image(probability_mask, 0.5, 1, 1, 0) + t1_preprocessed = t1_preprocessed * t1_mask - weights_file_name = get_pretrained_network("deepFlash", antsxnet_cache_directory=antsxnet_cache_directory) - unet_model.load_weights(weights_file_name) + # Do bias correction + t1_preprocessed = ants.n4_bias_field_correction(t1_preprocessed, t1_mask, shrink_factor=4, verbose=verbose) + + # Warp to template + registration = ants.registration(fixed=t1_template, moving=t1_preprocessed, + type_of_transform="antsRegistrationSyNQuickRepro[a]", verbose=verbose) + template_transforms = dict(fwdtransforms=registration['fwdtransforms'], + invtransforms=registration['invtransforms']) + t1_preprocessed = registration['warpedmovout'] + + if use_contralaterality: + t1_preprocessed_array = t1_preprocessed.numpy() + t1_preprocessed_array_flipped = np.flip(t1_preprocessed_array, axis=0) + t1_preprocessed_flipped = ants.from_numpy(t1_preprocessed_array_flipped, + origin=t1_preprocessed.origin, + spacing=t1_preprocessed.spacing, + direction=t1_preprocessed.direction) + + + probability_images = list() + labels_left = list((104, 105, 106, 108, 109, 110, 114, 115, 126, 6001, 6003, 6008, 6009, 6010)) + labels_right = list((204, 205, 206, 208, 209, 210, 214, 215, 226, 7001, 7003, 7008, 7009, 7010)) + + # labels_left = list((103, 104, 105, 106, 108, 109, 110, 111, 112, 114, 115, 126, + # 6001, 6003, 6005, 6006, 6007, 6008, 6009, 6010, 6015)) + # labels_right = list((203, 204, 205, 206, 208, 209, 210, 211, 212, 214, 215, 226, + # 7001, 7003, 7005, 7006, 7007, 7008, 7009, 7010, 7015)) + labels = np.array(np.repeat(0, 1 + len(labels_left) + len(labels_right))) + labels[1::2] = labels_left + labels[2::2] = labels_right + image_size = (64, 64, 128) ################################ # - # Do prediction and normalize to native space + # Process left/right in split networks # ################################ - if verbose: - print("Prediction.") + ################################ + # + # Download spatial priors + # + ################################ - cropped_image = pad_or_crop_image_to_size(t1_preprocessed, template_size) + prior_labels_file_name_path = get_antsxnet_data("deepFlashTemplate2Labels", + antsxnet_cache_directory=antsxnet_cache_directory) + prior_labels = ants.image_read(prior_labels_file_name_path) - batchX = np.expand_dims(cropped_image.numpy(), axis=0) - batchX = np.expand_dims(batchX, axis=-1) - batchX = (batchX - batchX.mean()) / batchX.std() + priors_image_left_list = list() + for i in range(len(labels_left)): + prior_image = ants.threshold_image(prior_labels, labels_left[i], labels_left[i], 1, 0) + prior_image = ants.copy_image_info(t1_preprocessed, ants.smooth_image(prior_image, 1.0)) + priors_image_left_list.append(ants.smooth_image(prior_image, 1.0)) - predicted_data = unet_model.predict(batchX, verbose=verbose) + priors_image_right_list = list() + for i in range(len(labels_right)): + prior_image = ants.threshold_image(prior_labels, labels_right[i], labels_right[i], 1, 0) + prior_image = ants.copy_image_info(t1_preprocessed, ants.smooth_image(prior_image, 1.0)) + priors_image_right_list.append(ants.smooth_image(prior_image, 1.0)) - origin = cropped_image.origin - spacing = cropped_image.spacing - direction = cropped_image.direction + probability_images_left = list() + foreground_probability_images_left = list() + lower_bound_left = (114, 108, 82) + upper_bound_left = (178, 172, 210) + tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left, upper_bound_left) + origin_left = tmp_cropped.origin - for i in range(len(labels)): - probability_image = \ - ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]), - origin=origin, spacing=spacing, direction=direction) - if i > 0: - decropped_image = ants.decrop_image(probability_image, t1_preprocessed * 0) - else: - decropped_image = ants.decrop_image(probability_image, t1_preprocessed * 0 + 1) + spacing = tmp_cropped.spacing + direction = tmp_cropped.direction - if do_preprocessing: - probability_images.append(ants.apply_transforms(fixed=t1, - moving=decropped_image, - transformlist=t1_preprocessing['template_transforms']['invtransforms'], - whichtoinvert=[True], interpolator="linear", verbose=verbose)) - else: - probability_images.append(decropped_image) + t1_template_roi_left = ants.crop_indices(t1_template, lower_bound_left, upper_bound_left) + t1_template_roi_left = ((t1_template_roi_left - t1_template_roi_left.min()) / + (t1_template_roi_left.max() - t1_template_roi_left.min()) * 2.0 - 1.0) - ################################ - # - # Process left/right in split networks - # - ################################ + probability_images_right = list() + foreground_probability_images_right = list() + lower_bound_right = (50, 108, 82) + upper_bound_right = (114, 172, 210) + tmp_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right, upper_bound_right) + origin_right = tmp_cropped.origin - else: + t1_template_roi_right = ants.crop_indices(t1_template, lower_bound_right, upper_bound_right) + t1_template_roi_right = ((t1_template_roi_right - t1_template_roi_right.min()) / + (t1_template_roi_right.max() - t1_template_roi_right.min()) * 2.0 - 1.0) ################################ # - # Left: download spatial priors + # Create model # ################################ - spatial_priors_left_file_name_path = get_antsxnet_data("priorDeepFlashLeftLabels", - antsxnet_cache_directory=antsxnet_cache_directory) - spatial_priors_left = ants.image_read(spatial_priors_left_file_name_path) - priors_image_left_list = ants.ndimage_to_list(spatial_priors_left) + channel_size = 1 + len(labels_left) + + number_of_classification_labels = 1 + len(labels_left) + + unet_model = create_unet_model_3d((*image_size, channel_size), + number_of_outputs=number_of_classification_labels, mode="classification", + number_of_filters=(32, 64, 96, 128, 256), + convolution_kernel_size=(3, 3, 3), deconvolution_kernel_size=(2, 2, 2), + dropout_rate=0.0, weight_decay=0) + + penultimate_layer = unet_model.layers[-2].output + + # whole complex + output1 = Conv3D(filters=1, + kernel_size=(1, 1, 1), + activation='sigmoid', + kernel_regularizer=regularizers.l2(0.0))(penultimate_layer) + + # hippocampus + output2 = Conv3D(filters=1, + kernel_size=(1, 1, 1), + activation='sigmoid', + kernel_regularizer=regularizers.l2(0.0))(penultimate_layer) + + # amygdala + output3 = Conv3D(filters=1, + kernel_size=(1, 1, 1), + activation='sigmoid', + kernel_regularizer=regularizers.l2(0.0))(penultimate_layer) + + unet_model = Model(inputs=unet_model.input, outputs=[unet_model.output, output1, output2, output3]) ################################ # @@ -802,24 +763,7 @@ def deep_flash_deprecated(t1, # ################################ - template_size = (64, 96, 96) - labels_left = (0, 5, 7, 9, 11, 13, 15, 17) - channel_size = 1 + len(labels_left) - - number_of_filters = 16 - network_name = '' - if which_hemisphere_models == "old": - network_name = "deepFlashLeft16" - elif which_hemisphere_models == "new": - network_name = "deepFlashLeft16new" - else: - raise ValueError("network_name must be \"old\" or \"new\".") - - unet_model = create_unet_model_3d((*template_size, channel_size), - number_of_outputs = len(labels_left), - number_of_layers = 4, number_of_filters_at_base_layer = number_of_filters, dropout_rate = 0.0, - convolution_kernel_size = (3, 3, 3), deconvolution_kernel_size = (2, 2, 2), - weight_decay = 1e-5, additional_options=("attentionGating",)) + network_name = 'deepFlash2LeftT1Hierarchical' if verbose: print("DeepFlash: retrieving model weights (left).") @@ -835,51 +779,85 @@ def deep_flash_deprecated(t1, if verbose: print("Prediction (left).") - cropped_image = ants.crop_indices(t1_preprocessed, (30, 51, 0), (94, 147, 96)) - image_array = cropped_image.numpy() - image_array = (image_array - image_array.mean()) / image_array.std() + batchX = None + if use_contralaterality: + batchX = np.zeros((2, *image_size, channel_size)) + else: + batchX = np.zeros((1, *image_size, channel_size)) - batchX = np.zeros((1, *template_size, channel_size)) - batchX[0,:,:,:,0] = image_array + t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_left, upper_bound_left) + t1_cropped = ants.histogram_match_image(t1_cropped, t1_template_roi_left, 255, 64, False) + + batchX[0,:,:,:,0] = t1_cropped.numpy() + if use_contralaterality: + t1_cropped = ants.crop_indices(t1_preprocessed_flipped, lower_bound_left, upper_bound_left) + t1_cropped = ants.histogram_match_image(t1_cropped, t1_template_roi_left, 255, 64, False) + batchX[1,:,:,:,0] = t1_cropped.numpy() for i in range(len(priors_image_left_list)): - cropped_prior = ants.crop_indices(priors_image_left_list[i], (30, 51, 0), (94, 147, 96)) - batchX[0,:,:,:,i+1] = cropped_prior.numpy() + cropped_prior = ants.crop_indices(priors_image_left_list[i], lower_bound_left, upper_bound_left) + for j in range(batchX.shape[0]): + batchX[j,:,:,:,i + (channel_size - len(labels_left))] = cropped_prior.numpy() predicted_data = unet_model.predict(batchX, verbose=verbose) - origin = cropped_image.origin - spacing = cropped_image.spacing - direction = cropped_image.direction - - probability_images_left = list() - for i in range(len(labels_left)): - probability_image = \ - ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]), - origin=origin, spacing=spacing, direction=direction) - if i > 0: - decropped_image = ants.decrop_image(probability_image, t1_preprocessed * 0) - else: - decropped_image = ants.decrop_image(probability_image, t1_preprocessed * 0 + 1) + for i in range(1 + len(labels_left)): + for j in range(predicted_data[0].shape[0]): + probability_image = \ + ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]), + origin=origin_left, spacing=spacing, direction=direction) + if i > 0: + probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0) + else: + probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0 + 1) + + if j == 1: # flipped + probability_array_flipped = np.flip(probability_image.numpy(), axis=0) + probability_image = ants.from_numpy(probability_array_flipped, + origin=probability_image.origin, spacing=probability_image.spacing, + direction=probability_image.direction) + + if do_preprocessing: + probability_image = ants.apply_transforms(fixed=t1, + moving=probability_image, + transformlist=template_transforms['invtransforms'], + whichtoinvert=[True], interpolator="linear", verbose=verbose) + + if j == 0: # not flipped + probability_images_left.append(probability_image) + else: # flipped + probability_images_right.append(probability_image) - if do_preprocessing: - probability_images_left.append(ants.apply_transforms(fixed=t1, - moving=decropped_image, - transformlist=t1_preprocessing['template_transforms']['invtransforms'], - whichtoinvert=[True], interpolator="linear", verbose=verbose)) - else: - probability_images_left.append(decropped_image) ################################ # - # Right: download spatial priors + # Left: do prediction of mtl, hippocampal, and ec regions and normalize to native space # ################################ - spatial_priors_right_file_name_path = get_antsxnet_data("priorDeepFlashRightLabels", - antsxnet_cache_directory=antsxnet_cache_directory) - spatial_priors_right = ants.image_read(spatial_priors_right_file_name_path) - priors_image_right_list = ants.ndimage_to_list(spatial_priors_right) + for i in range(1, len(predicted_data)): + for j in range(predicted_data[i].shape[0]): + probability_image = \ + ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]), + origin=origin_left, spacing=spacing, direction=direction) + probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0) + + if j == 1: # flipped + probability_array_flipped = np.flip(probability_image.numpy(), axis=0) + probability_image = ants.from_numpy(probability_array_flipped, + origin=probability_image.origin, spacing=probability_image.spacing, + direction=probability_image.direction) + + if do_preprocessing: + probability_image = ants.apply_transforms(fixed=t1, + moving=probability_image, + transformlist=template_transforms['invtransforms'], + whichtoinvert=[True], interpolator="linear", verbose=verbose) + + if j == 0: # not flipped + foreground_probability_images_left.append(probability_image) + else: + foreground_probability_images_right.append(probability_image) ################################ # @@ -887,25 +865,10 @@ def deep_flash_deprecated(t1, # ################################ - template_size = (64, 96, 96) - labels_right = (0, 6, 8, 10, 12, 14, 16, 18) - channel_size = 1 + len(labels_right) - - number_of_filters = 16 - network_name = '' - if which_hemisphere_models == "old": - network_name = "deepFlashRight16" - elif which_hemisphere_models == "new": - network_name = "deepFlashRight16new" - else: - raise ValueError("network_name must be \"old\" or \"new\".") - - unet_model = create_unet_model_3d((*template_size, channel_size), - number_of_outputs = len(labels_right), - number_of_layers = 4, number_of_filters_at_base_layer = number_of_filters, dropout_rate = 0.0, - convolution_kernel_size = (3, 3, 3), deconvolution_kernel_size = (2, 2, 2), - weight_decay = 1e-5, additional_options=("attentionGating",)) + network_name = 'deepFlash2RightT1Hierarchical' + if verbose: + print("DeepFlash: retrieving model weights (right).") weights_file_name = get_pretrained_network(network_name, antsxnet_cache_directory=antsxnet_cache_directory) unet_model.load_weights(weights_file_name) @@ -918,40 +881,93 @@ def deep_flash_deprecated(t1, if verbose: print("Prediction (right).") - cropped_image = ants.crop_indices(t1_preprocessed, (88, 51, 0), (152, 147, 96)) - image_array = cropped_image.numpy() - image_array = (image_array - image_array.mean()) / image_array.std() + batchX = None + if use_contralaterality: + batchX = np.zeros((2, *image_size, channel_size)) + else: + batchX = np.zeros((1, *image_size, channel_size)) - batchX = np.zeros((1, *template_size, channel_size)) - batchX[0,:,:,:,0] = image_array + t1_cropped = ants.crop_indices(t1_preprocessed, lower_bound_right, upper_bound_right) + t1_cropped = ants.histogram_match_image(t1_cropped, t1_template_roi_right, 255, 64, False) + batchX[0,:,:,:,0] = t1_cropped.numpy() + if use_contralaterality: + t1_cropped = ants.crop_indices(t1_preprocessed_flipped, lower_bound_right, upper_bound_right) + if use_rank_intensity: + t1_cropped = ants.rank_intensity(t1_cropped) + else: + t1_cropped = ants.histogram_match_image(t1_cropped, t1_template_roi_right, 255, 64, False) + batchX[1,:,:,:,0] = t1_cropped.numpy() for i in range(len(priors_image_right_list)): - cropped_prior = ants.crop_indices(priors_image_right_list[i], (88, 51, 0), (152, 147, 96)) - batchX[0,:,:,:,i+1] = cropped_prior.numpy() + cropped_prior = ants.crop_indices(priors_image_right_list[i], lower_bound_right, upper_bound_right) + for j in range(batchX.shape[0]): + batchX[j,:,:,:,i + (channel_size - len(labels_right))] = cropped_prior.numpy() predicted_data = unet_model.predict(batchX, verbose=verbose) - origin = cropped_image.origin - spacing = cropped_image.spacing - direction = cropped_image.direction + for i in range(1 + len(labels_right)): + for j in range(predicted_data[0].shape[0]): + probability_image = \ + ants.from_numpy(np.squeeze(predicted_data[0][j, :, :, :, i]), + origin=origin_right, spacing=spacing, direction=direction) + if i > 0: + probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0) + else: + probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0 + 1) - probability_images_right = list() - for i in range(len(labels_right)): - probability_image = \ - ants.from_numpy(np.squeeze(predicted_data[0, :, :, :, i]), - origin=origin, spacing=spacing, direction=direction) - if i > 0: - decropped_image = ants.decrop_image(probability_image, t1_preprocessed * 0) - else: - decropped_image = ants.decrop_image(probability_image, t1_preprocessed * 0 + 1) + if j == 1: # flipped + probability_array_flipped = np.flip(probability_image.numpy(), axis=0) + probability_image = ants.from_numpy(probability_array_flipped, + origin=probability_image.origin, spacing=probability_image.spacing, + direction=probability_image.direction) + + if do_preprocessing: + probability_image = ants.apply_transforms(fixed=t1, + moving=probability_image, + transformlist=template_transforms['invtransforms'], + whichtoinvert=[True], interpolator="linear", verbose=verbose) + + if j == 0: # not flipped + if use_contralaterality: + probability_images_right[i] = (probability_images_right[i] + probability_image) / 2 + else: + probability_images_right.append(probability_image) + else: # flipped + probability_images_left[i] = (probability_images_left[i] + probability_image) / 2 - if do_preprocessing: - probability_images_right.append(ants.apply_transforms(fixed=t1, - moving=decropped_image, - transformlist=t1_preprocessing['template_transforms']['invtransforms'], - whichtoinvert=[True], interpolator="linear", verbose=verbose)) - else: - probability_images_right.append(decropped_image) + + ################################ + # + # Right: do prediction of mtl, hippocampal, and ec regions and normalize to native space + # + ################################ + + for i in range(1, len(predicted_data)): + for j in range(predicted_data[i].shape[0]): + probability_image = \ + ants.from_numpy(np.squeeze(predicted_data[i][j, :, :, :, 0]), + origin=origin_right, spacing=spacing, direction=direction) + probability_image = ants.decrop_image(probability_image, t1_preprocessed * 0) + + if j == 1: # flipped + probability_array_flipped = np.flip(probability_image.numpy(), axis=0) + probability_image = ants.from_numpy(probability_array_flipped, + origin=probability_image.origin, spacing=probability_image.spacing, + direction=probability_image.direction) + + if do_preprocessing: + probability_image = ants.apply_transforms(fixed=t1, + moving=probability_image, + transformlist=template_transforms['invtransforms'], + whichtoinvert=[True], interpolator="linear", verbose=verbose) + + if j == 0: # not flipped + if use_contralaterality: + foreground_probability_images_right[i-1] = (foreground_probability_images_right[i-1] + probability_image) / 2 + else: + foreground_probability_images_right.append(probability_image) + else: + foreground_probability_images_left[i-1] = (foreground_probability_images_left[i-1] + probability_image) / 2 ################################ # @@ -970,31 +986,40 @@ def deep_flash_deprecated(t1, probability_images.append(probability_images_left[i]) probability_images.append(probability_images_right[i]) - ################################ - # - # Convert probability images to segmentation - # - ################################ - - # image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1) - # segmentation_matrix = np.argmax(image_matrix, axis=0) - # segmentation_image = ants.matrix_to_images( - # np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0] + ################################ + # + # Convert probability images to segmentation + # + ################################ - image_matrix = ants.image_list_to_matrix(probability_images[1:(len(probability_images))], t1 * 0 + 1) - background_foreground_matrix = np.stack([ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1), - np.expand_dims(np.sum(image_matrix, axis=0), axis=0)]) - foreground_matrix = np.argmax(background_foreground_matrix, axis=0) - segmentation_matrix = (np.argmax(image_matrix, axis=0) + 1) * foreground_matrix - segmentation_image = ants.matrix_to_images( - np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0] + # image_matrix = ants.image_list_to_matrix(probability_images, t1 * 0 + 1) + # segmentation_matrix = np.argmax(image_matrix, axis=0) + # segmentation_image = ants.matrix_to_images( + # np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0] - relabeled_image = ants.image_clone(segmentation_image) - for i in range(len(labels)): - relabeled_image[segmentation_image==i] = labels[i] + image_matrix = ants.image_list_to_matrix(probability_images[1:(len(probability_images))], t1 * 0 + 1) + background_foreground_matrix = np.stack([ants.image_list_to_matrix([probability_images[0]], t1 * 0 + 1), + np.expand_dims(np.sum(image_matrix, axis=0), axis=0)]) + foreground_matrix = np.argmax(background_foreground_matrix, axis=0) + segmentation_matrix = (np.argmax(image_matrix, axis=0) + 1) * foreground_matrix + segmentation_image = ants.matrix_to_images( + np.expand_dims(segmentation_matrix, axis=0), t1 * 0 + 1)[0] - return_dict = {'segmentation_image' : relabeled_image, - 'probability_images' : probability_images} - return(return_dict) + relabeled_image = ants.image_clone(segmentation_image) + for i in range(len(labels)): + relabeled_image[segmentation_image==i] = labels[i] + foreground_probability_images = list() + for i in range(len(foreground_probability_images_left)): + foreground_probability_images.append(foreground_probability_images_left[i] + foreground_probability_images_right[i]) + return_dict = {'segmentation_image' : relabeled_image, + 'probability_images' : probability_images, + 'whole_probability_image' : foreground_probability_images[0], + 'hippocampal_probability_image' : foreground_probability_images[1], + 'amygdala_probability_image' : foreground_probability_images[2] + } + return(return_dict) + + else: + raise ValueError("Unrecognized parcellation.") diff --git a/antspynet/utilities/get_antsxnet_data.py b/antspynet/utilities/get_antsxnet_data.py index 2950124..254d425 100644 --- a/antspynet/utilities/get_antsxnet_data.py +++ b/antspynet/utilities/get_antsxnet_data.py @@ -43,6 +43,8 @@ def switch_data(argument): "deepFlashTemplateT1SkullStripped": "https://figshare.com/ndownloader/files/31339867", "deepFlashTemplateT2": "https://figshare.com/ndownloader/files/31207798", "deepFlashTemplateT2SkullStripped": "https://figshare.com/ndownloader/files/31339870", + "deepFlashTemplate2T1SkullStripped": "https://figshare.com/ndownloader/files/46461451", + "deepFlashTemplate2Labels": "https://figshare.com/ndownloader/files/46461415", "mprage_hippmapp3r": "https://ndownloader.figshare.com/files/24984689", "protonLobePriors": "https://figshare.com/ndownloader/files/30678452", "protonLungTemplate": "https://ndownloader.figshare.com/files/22707338", @@ -87,6 +89,8 @@ def switch_data(argument): "deepFlashTemplateT1SkullStripped", "deepFlashTemplateT2", "deepFlashTemplateT2SkullStripped", + "deepFlashTemplate2T1SkullStripped", + "deepFlashTemplate2Labels", "luna16LungPriors", "mprage_hippmapp3r", "priorDktLabels", diff --git a/antspynet/utilities/get_pretrained_network.py b/antspynet/utilities/get_pretrained_network.py index 60dcb5d..cb53c1e 100644 --- a/antspynet/utilities/get_pretrained_network.py +++ b/antspynet/utilities/get_pretrained_network.py @@ -83,6 +83,8 @@ def switch_networks(argument): "deepFlashRightT1Hierarchical_ri": "https://figshare.com/ndownloader/files/33198800", "deepFlashLeftBothHierarchical_ri": "https://figshare.com/ndownloader/files/33198803", "deepFlashRightBothHierarchical_ri": "https://figshare.com/ndownloader/files/33198809", + "deepFlash2LeftT1Hierarchical": "https://figshare.com/ndownloader/files/46461418", + "deepFlash2RightT1Hierarchical": "https://figshare.com/ndownloader/files/46461421", "deepFlash": "https://ndownloader.figshare.com/files/22933757", "deepFlashLeft8": "https://ndownloader.figshare.com/files/25441007", "deepFlashRight8": "https://ndownloader.figshare.com/files/25441004", @@ -205,6 +207,8 @@ def switch_networks(argument): "deepFlashRightT1Hierarchical_ri", "deepFlashLeftBothHierarchical_ri", "deepFlashRightBothHierarchical_ri", + "deepFlash2LeftT1Hierarchical", + "deepFlash2RightT1Hierarchical", "deepFlashLeft8", "deepFlashRight8", "deepFlashLeft16",