Skip to content

Commit

Permalink
adding ROI atlas options
Browse files Browse the repository at this point in the history
revise probablistic atlas scaling
  • Loading branch information
demidenm committed Sep 14, 2023
1 parent 7b84a05 commit 3d222e9
Showing 1 changed file with 60 additions and 14 deletions.
74 changes: 60 additions & 14 deletions pyrelimri/brain_icc.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import numpy as np
import nibabel as nib
from pandas import DataFrame
from sklearn.preprocessing import minmax_scale
from pyrelimri.icc import sumsq_icc
from nilearn import image
from nilearn.maskers import (NiftiMasker,NiftiMapsMasker, NiftiLabelsMasker)
from nilearn.datasets import (
fetch_atlas_aal,
fetch_atlas_allen_2011,
fetch_atlas_basc_multiscale_2015,
fetch_atlas_destrieux_2009,
fetch_atlas_difumo,
fetch_atlas_harvard_oxford,
fetch_atlas_juelich,
fetch_atlas_msdl,
fetch_atlas_pauli_2017,
fetch_atlas_schaefer_2018,
fetch_atlas_smith_2009,
fetch_atlas_talairach
)


Expand Down Expand Up @@ -106,22 +107,21 @@ def voxelwise_icc(multisession_list: str, mask: str, icc_type='icc_3'):
def setup_atlas(name_atlas: str, **kwargs):
default_params = {
'data_dir': None,
'resume': True,
'verbose': 0
}

# Dictionary mapping atlas names to their corresponding fetch functions
grab_atlas = {
'aal': fetch_atlas_aal,
'allen_2011': fetch_atlas_allen_2011,
'basc_multiscale_2015': fetch_atlas_basc_multiscale_2015,
'destrieux_2009': fetch_atlas_destrieux_2009,
'difumo': fetch_atlas_difumo,
'harvard_oxford': fetch_atlas_harvard_oxford,
'juelich': fetch_atlas_juelich,
'msdl': fetch_atlas_msdl,
'pauli_2017': fetch_atlas_pauli_2017,
'shaefer_2018': fetch_atlas_schaefer_2018,
'smith_2009': fetch_atlas_smith_2009
'smith_2009': fetch_atlas_smith_2009,
'talairach': fetch_atlas_talairach
}
atlas_grabbed = grab_atlas.get(name_atlas)

Expand All @@ -134,6 +134,33 @@ def setup_atlas(name_atlas: str, **kwargs):
f"OPTIONS:\n\t{', '.join(grab_atlas.keys())}")
return None

def prob_atlas_scale(nifti_map, estimate_array):
"""
Rescales a probabilistic 3D Nifti map to the range of estimated values.
:param nifti_map: Nifti1Image (3D)
The input Nifti image to be rescaled.
:param estimate_array: ndarray (1D)
A NumPy array containing the estimates used for scaling.
:return: Nifti1Image
Returns a 3D rescaled image based on the min/max of estimate_array.
"""

temp_img_array = nifti_map.get_fdata().flatten()
non_zero_mask = temp_img_array != 0

# Scale the non-zero values using minmax_scale from sklearn
scaled_values = minmax_scale(
temp_img_array[non_zero_mask],
feature_range=(min(estimate_array), max(estimate_array))
)
# New array w/ zeros & replace the non-zero values with the [new] scaled values
rescaled = np.zeros_like(temp_img_array, dtype=float)
rescaled[non_zero_mask] = scaled_values
new_img_shape = np.reshape(rescaled, nifti_map.shape)

return (image.new_img_like(nifti_map, new_img_shape))


def roi_icc(multisession_list: str, type_atlas: str,
atlas_dir: str, icc_type='icc_3', **kwargs):
Expand Down Expand Up @@ -187,7 +214,7 @@ def roi_icc(multisession_list: str, type_atlas: str,
# Grab atlas and mask images
# Atlases are either deterministic (3D) or probabilistic (4D). Try except to circumvent error
# grab/download atlas
atlas = setup_atlas(name_atlas=type_atlas, data_dir='/tmp/', **kwargs)
atlas = setup_atlas(name_atlas=type_atlas, data_dir=atlas_dir, **kwargs)
try:
atlas_dim = len(atlas.maps.shape)
except AttributeError:
Expand All @@ -197,12 +224,16 @@ def roi_icc(multisession_list: str, type_atlas: str,
masker = NiftiLabelsMasker(
labels_img=atlas.maps,
standardize=False,
resampling_target='data',
verbose=0
).fit()
elif atlas_dim == 4:
masker = NiftiMapsMasker(
maps_img=atlas.maps,
resampling_target='data'
allow_overlap=True,
standardize=False,
resampling_target='data',
verbose=0
).fit()
else:
raise ValueError("Atlas maps isn't 3D or 4D, so incompatible with Nifti[Labels/Maps]Masker() .")
Expand Down Expand Up @@ -243,14 +274,29 @@ def roi_icc(multisession_list: str, type_atlas: str,
'lower_bound': np.array(lowbound),
'upper_bound': np.array(upbound),
'ms_btwn': np.array(msbs),
'ms_wthn': np.array(msws),
'est_3d': masker.inverse_transform(np.array(est)),
'lower_bound_3d': masker.inverse_transform(np.array(lowbound)),
'upper_bound_3d': masker.inverse_transform(np.array(upbound)),
'ms_btwn_3d': masker.inverse_transform(np.array(msbs)),
'ms_wthn_3d': masker.inverse_transform(np.array(msws))
'ms_wthn': np.array(msws)
}

est_string = {"est_3d": est,
"lowbound_3d": lowbound, "upbound_3d": upbound,
"msbs_3d": msbs, "msws_3d": msws
}

if atlas_dim == 4:
for name, var in est_string.items():
est_img = masker.inverse_transform(np.array(var))
resample_img = prob_atlas_scale(est_img, np.array(var))
result_dict[name] = resample_img
else:
update_values = {
'est_3d': masker.inverse_transform(np.array(est)),
'lower_bound_3d': masker.inverse_transform(np.array(lowbound)),
'upper_bound_3d': masker.inverse_transform(np.array(upbound)),
'ms_btwn_3d': masker.inverse_transform(np.array(msbs)),
'ms_wthn_3d': masker.inverse_transform(np.array(msws))
}
result_dict.update(update_values)

return result_dict


0 comments on commit 3d222e9

Please sign in to comment.