Skip to content

Commit

Permalink
Merge pull request #9 from zugmana/main
Browse files Browse the repository at this point in the history
Edits to seedmap and fmriprep notebooks
  • Loading branch information
Shotgunosine authored Oct 2, 2024
2 parents 4981a32 + e4fc76b commit 6f92ede
Show file tree
Hide file tree
Showing 4 changed files with 5,076 additions and 1,949 deletions.
114 changes: 99 additions & 15 deletions contarg/cli/run_seedmap.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@ def seedmap():
default=None,
help="Run identifier to construct output file names. Only pass if you are not concatenating runs.",
)
@click.option(
'--confounds-strategy',
type=str,
multiple=True,
help='Confounds strategy (can be used multiple times) - for documentation see nilearn.interfaces.fmriprep.load_confounds'
)
@click.option(
'--extra-args',
type=str,
multiple=True,
help='Additional keyword arguments for controlling confounds selection in the form key=value.'
)
def subjectmap(
bold_path,
mask_path,
Expand All @@ -110,16 +122,18 @@ def subjectmap(
subject,
session,
run,
confounds_strategy,
extra_args
):
"""
Get the voxel wise connectivity map of a passed bold image with the reference roi.
If multiple bold_paths is passed,they'll be concatenated. Runs global signal regression.
If multiple bold_paths is passed,they'll be concatenated. Runs global signal regression or alternate confounds regression.
Output will be masked by grey matter mask and stimroi.
"""
bold_paths = bold_path
derivatives_dir = Path(derivatives_dir)
roi_dir = Path(resource_filename("contarg", "data/rois"))

kwargs = dict(arg.split('=') for arg in extra_args) # turn this list into a dictionary.
if refroi_name in ["SGCsphere", "bilateralSGCspheres"]:
ref_roi_2mm_path = (
roi_dir / f"{refroi_name}_space-MNI152NLin6Asym_res-02.nii.gz"
Expand Down Expand Up @@ -158,10 +172,12 @@ def subjectmap(
ref_vox_img = get_ref_vox_con(
bold_paths,
mask_path,
ref_vox_con_path,
ref_roi_2mm_path,
t_r,
out_path=ref_vox_con_path,
smoothing_fwhm=smoothing_fwhm,
confounds_strategy = confounds_strategy,
**kwargs
)
# mask ref_vox_img
subj_mask = nl.image.load_img(mask_path)
Expand Down Expand Up @@ -216,9 +232,10 @@ def groupmap(contarg_dir, session, run):
subjmaps = sorted(contarg_dir.rglob(glob_str))

tmp = nl.image.load_img(subjmaps[0])
mapsum = np.zeros_like(tmp, dtype=float)
mapsum = np.zeros_like(tmp.get_fdata().squeeze(), dtype=float)

for subjmap in subjmaps:
print(subjmap)
subjimg = nl.image.load_img(subjmap)
mapsum += subjimg.get_fdata().squeeze()
del subjimg
Expand All @@ -243,7 +260,8 @@ def groupmap(contarg_dir, session, run):
type=click.Path(),
help="Path to pybids database file (expects version 0.15.2), "
"if one does not exist here, it will be created.",
required=True,
required=False,
default=None
)
@click.option(
"--run-name",
Expand Down Expand Up @@ -292,7 +310,7 @@ def groupmap(contarg_dir, session, run):
"--ndummy",
"n_dummy",
type=int,
default=0,
default=None,
help="Number of dummy scans at the beginning of the functional time series",
)
@click.option(
Expand All @@ -304,7 +322,7 @@ def groupmap(contarg_dir, session, run):
)
@click.option(
"--target-method",
type=click.Choice(["classic", "cluster"]),
type=click.Choice(["classic", "cluster","None"]),
default="cluster",
show_default=True,
help="How to pick a target coordinate from the seedmap weighted connectivity.",
Expand All @@ -320,7 +338,8 @@ def groupmap(contarg_dir, session, run):
"--percentile",
type=float,
help="All values more extreme than percentile will be kept for clustering",
required=True,
required=False,
default=10
)
@click.option(
"--subject",
Expand Down Expand Up @@ -350,9 +369,35 @@ def groupmap(contarg_dir, session, run):
show_default=True,
help="Number of jobs to run in parallel to find targets",
)
@click.option(
"--fmriprepdir",
type=str,
default=None,
help="Path to fmriprep direcotry if not in standard ./derivatives/fmriprep/.",
)
@click.option(
'--confounds-strategy',
type=str,
multiple=True,
help='Confounds strategy (can be used multiple times) - for documentation see nilearn.interfaces.fmriprep.load_confounds'
)
@click.option(
'--extra-args',
type=str,
multiple=True,
help='Additional keyword arguments for controlling confounds selection in the form key=value.'
)
@click.option(
'--concat-level',
type=str,
multiple=False,
default=None,
help='level to concatenate. Choose subject or sessions. Default=None.'
)
def run(
bids_dir,
derivatives_dir,
fmriprepdir,
database_file,
run_name,
stimroi_name,
Expand All @@ -370,18 +415,27 @@ def run(
run,
echo,
njobs,
confounds_strategy,
extra_args,
concat_level
):
# TODO: add code for concatenating runs
bids_dir = Path(bids_dir)
derivatives_dir = Path(derivatives_dir)
database_path = Path(database_file)
if not database_file:
database_path = None
else :
database_path = Path(database_file)
seedmap_path = Path(seedmap_path)
roi_dir = Path(resource_filename("contarg", "data/rois"))
if not fmriprepdir :
fmriprepdir = derivatives_dir / "fmriprep"
layout = BIDSLayout(
bids_dir,
database_path=database_path,
derivatives=derivatives_dir / "fmriprep",
derivatives=fmriprepdir,
)
kwargs = dict(arg.split('=') for arg in extra_args)# Turn list into dictionary.
if run_name is not None:
targeting_dir = derivatives_dir / "contarg" / "seedmap" / run_name
else:
Expand Down Expand Up @@ -426,6 +480,24 @@ def run(
rest_paths["bold_path"] = [bb.path for bb in bolds]
if "session" not in rest_paths.columns:
rest_paths["session"] = None
# if concatenate runs (add a handler in click)
if concat_level is None:
#continue
print('concat_level set to None')
else:
if concat_level == 'session':
concat_gb = ['subject', 'session']
elif concat_level == 'subject':
concat_gb = ['subject']
else:
raise NotImplementedError("Only concatenating on subject or subject and session are supported")
new_rest_paths = []
for _, df in rest_paths.groupby(concat_gb):
new_row = df.iloc[0]
new_row['bold_path'] = list(df.bold_path.values)
new_rest_paths.append(new_row)
rest_paths = pd.DataFrame(new_rest_paths)

# add boldref
rest_paths["boldref"] = rest_paths.entities.apply(
lambda ee: layout.get(
Expand Down Expand Up @@ -459,9 +531,9 @@ def run(
space=None,
)
)
assert rest_paths.T1w.apply(lambda x: len(x) == 1).all()
rest_paths["T1w"] = rest_paths.T1w.apply(lambda x: x[0])

rest_paths["T1w"] = rest_paths.T1w.apply(lambda x: [x[0]])
assert rest_paths.T1w.apply(lambda x: isinstance(x, str) or (isinstance(x, list) and len(x) == 1)).all()
# add mnito t1w path
rest_paths["mnitoT1w"] = rest_paths.entities.apply(
lambda ee: layout.get(
Expand All @@ -472,10 +544,13 @@ def run(
suffix="xfm",
to="T1w",
**{"from": "MNI152NLin6Asym"},
**({"session": ee["session"]} if "session" in ee and ee["session"] is not None else {})
)
)

rest_paths["mnitoT1w"] = rest_paths.mnitoT1w.apply(lambda x: [x[0]])
assert rest_paths.mnitoT1w.apply(lambda x: len(x) == 1).all()
rest_paths["mnitoT1w"] = rest_paths.mnitoT1w.apply(lambda x: x[0])


# add confounds path
rest_paths["confounds"] = rest_paths.entities.apply(
Expand All @@ -488,8 +563,10 @@ def run(
extension=".tsv",
suffix="timeseries",
desc="confounds",
)
**({"session": ee["session"]} if "session" in ee and ee["session"] is not None else {})
)
)

assert rest_paths.confounds.apply(lambda x: len(x) == 1).all()
rest_paths["confounds"] = rest_paths.confounds.apply(lambda x: x[0])

Expand Down Expand Up @@ -548,6 +625,8 @@ def run(

if target_method == "cluster":
desc = f"{connectivity}.{target_method}.p{percentile}"
elif target_method == "None" :
desc = "rawCorr"
else:
desc = f"{target_method}"

Expand Down Expand Up @@ -653,6 +732,7 @@ def run(
tr=t_r,
out_path=row[f"{desc}_seedmap_correlation"],
smoothing_fwhm=smoothing_fwhm,
confound_strategy=confounds_strategy
)
if target_method == "cluster":
clust_img = cluster(
Expand All @@ -669,6 +749,10 @@ def run(
target_idx = np.where(ref_vox_dat == ref_vox_dat.min())
target_idx = np.array([list(rr) + [1] for rr in zip(*target_idx)])
target_coords = np.matmul(ref_vox_img.affine, target_idx.T).T[:, :3]
elif target_method == "None":
print(f"No target method selected. Raw Average correlation map saved in {row[f'{desc}_seedmap_correlation']}. Done")
#print(row)
return
else:
raise NotImplementedError(
f"Target method {target_method} is not implemented."
Expand Down
49 changes: 42 additions & 7 deletions contarg/seedmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,41 @@
import numpy as np
import nilearn as nl
from nilearn import image, masking, maskers, plotting, datasets, connectome
from nilearn.interfaces.fmriprep import load_confounds
from pathlib import Path
from .utils import iterable


def get_nlearn_confounds(subj_mask,bold_path,confounds_strategy=None, custom_gsr=True, ndummy=None, **kwargs):
if confounds_strategy :
confounds,sample_mask = load_confounds(bold_path,strategy=confounds_strategy, **kwargs)
if 'global_signal' in confounds_strategy:
return confounds.to_numpy(),sample_mask # This will override custom_gsr and ndummy
elif custom_gsr:
gs_masker = nl.maskers.NiftiMasker(mask_img=subj_mask)
confounds.loc[:,'gs'] = gs_masker.fit_transform(bold_path).mean(1).reshape(-1, 1)
if ndummy :
sample_mask = confounds.index[ndummy:].to_numpy() # This will override the sample mask
return confounds.to_numpy(),sample_mask
else :
return confounds.to_numpy(),sample_mask
else :
gs_masker = nl.maskers.NiftiMasker(mask_img=subj_mask)
confounds = gs_masker.fit_transform(bold_path).mean(1).reshape(-1, 1)
_,sample_mask = load_confounds(bold_path,strategy=["non_steady_state"]) # This is just to get the non-steady in the same format.
if ndummy :
sample_mask = confounds.index[ndummy:].to_numpy() # This will override the sample mask

return confounds, sample_mask

def get_ref_vox_con(
bold_path, mask_path, refroi_path, tr, out_path=None, smoothing_fwhm=4.0
bold_path,
mask_path,
refroi_path,
tr,
out_path=None,
smoothing_fwhm=4.0,
confounds_strategy=None,
**kwargs
):
"""
Get the voxel wise connectivity map of a passed bold image with the reference roi.
Expand All @@ -26,6 +55,7 @@ def get_ref_vox_con(
Path to write connectivity map to
smoothing_fwhm : float default 4.0
FWHM of gaussian smoothing to be applied
confounds :
"""
if not iterable(bold_path):
Expand All @@ -34,9 +64,10 @@ def get_ref_vox_con(
bold_paths = bold_path

subj_mask = nl.image.load_img(mask_path)
print(refroi_path)
ref_mask = nl.image.load_img(refroi_path)
masked_ref_mask = nl.masking.apply_mask(ref_mask, subj_mask)
gs_masker = nl.maskers.NiftiMasker(mask_img=subj_mask)
#gs_masker = nl.maskers.NiftiMasker(mask_img=subj_mask)
subj_masker = nl.maskers.NiftiMasker(
mask_img=subj_mask,
low_pass=0.1,
Expand All @@ -48,8 +79,9 @@ def get_ref_vox_con(
# process each run
clean_tses = []
for bold_path in bold_paths:
gs = gs_masker.fit_transform(bold_path).mean(1).reshape(-1, 1)
cleaned = subj_masker.fit_transform(bold_path, confounds=gs)
#gs = gs_masker.fit_transform(bold_path).mean(1).reshape(-1, 1)
confounds,sample_mask = get_nlearn_confounds(subj_mask,bold_path,confounds_strategy=confounds_strategy, custom_gsr=True, **kwargs)
cleaned = subj_masker.fit_transform(bold_path, confounds=confounds, sample_mask=sample_mask) # Add the confounds from nilearn.interfaces.fmriprep.load_confounds
clean_tses.append(cleaned)
cat_clean_tses = np.vstack(clean_tses)
ref_ts = cat_clean_tses[:, masked_ref_mask.astype(bool)].mean(1).reshape(-1, 1)
Expand All @@ -70,6 +102,8 @@ def get_seedmap_vox_con(
tr,
out_path=None,
smoothing_fwhm=4.0,
confounds_strategy=None,
**kwargs
):
"""
Get the representative time series of a passed bold image based on a seedmap.
Expand Down Expand Up @@ -113,8 +147,9 @@ def get_seedmap_vox_con(
# process each run
clean_tses = []
for bold_path in bold_paths:
gs = gs_masker.fit_transform(bold_path).mean(1).reshape(-1, 1)
cleaned = subj_masker.fit_transform(bold_path, confounds=gs)[n_dummy:]
#gs = gs_masker.fit_transform(bold_path).mean(1).reshape(-1, 1)
confounds,sample_mask = get_nlearn_confounds(subj_mask,bold_path,confounds_strategy=confounds_strategy, ndummy=n_dummy,custom_gsr=True, **kwargs)
cleaned = subj_masker.fit_transform(bold_path, confounds=confounds, sample_mask=sample_mask) #[n_dummy:] # Add the confounds from nilearn.interfaces.fmriprep.load_confounds Should already remove the dummys if sample_mask works.
clean_tses.append(cleaned)
cat_clean_tses = np.vstack(clean_tses)
seedmap_ts = np.average(cat_clean_tses, axis=1, weights=masked_seedmap)
Expand Down
Loading

0 comments on commit 6f92ede

Please sign in to comment.