Merge pull request #738 from JulioAPeraza/parallel-decode
Add parallelization option to `CorrelationDecoder` and `CorrelationDistributionDecoder`
JulioAPeraza authored Aug 8, 2022
2 parents 6388514 + 79a8c95 commit a4a7d16
Showing 3 changed files with 86 additions and 50 deletions.
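
In practice, the new option surfaces as an `n_cores` argument on both decoders, and results are re-sorted by feature index after the parallel map so output order does not depend on worker scheduling. A minimal usage sketch, assuming hypothetical file paths (`n_cores=2` is arbitrary; values <= 0 use all available cores):

from nimare.dataset import Dataset
from nimare.decode.continuous import CorrelationDecoder

dset = Dataset.load("neurosynth_dataset.pkl.gz")  # hypothetical Dataset file
decoder = CorrelationDecoder(
    frequency_threshold=0.001,
    target_image="z_desc-specificity",
    n_cores=2,  # new parameter added in this PR
)
decoder.fit(dset)  # feature-wise meta-analytic maps are now fit in parallel
decoded = decoder.transform("my_unthresholded_z_map.nii.gz")  # hypothetical image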
2 changes: 1 addition & 1 deletion examples/01_datasets/02_download_neurosynth.py
@@ -86,7 +86,7 @@
version="1",
overwrite=False,
source="combined",
vocab="neuroquery7547",
vocab="neuroquery6308",
type="tfidf",
)
# Note that the files are saved to a new folder within "out_dir" named "neuroquery".
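
For context, the corrected vocabulary name sits inside the example script's download call. A sketch of the surrounding call follows; only the keyword values shown in the diff are confirmed, and the `data_dir`/`out_dir` wiring is an assumption about the rest of the script:

import os

from nimare.extract import fetch_neuroquery

out_dir = os.path.abspath("example_data")  # hypothetical output directory
files = fetch_neuroquery(
    data_dir=out_dir,  # assumed argument name, per other NiMARE examples
    version="1",
    overwrite=False,
    source="combined",
    vocab="neuroquery6308",  # fixed: was "neuroquery7547"
    type="tfidf",
)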
2 changes: 1 addition & 1 deletion nimare/decode/base.py
@@ -77,7 +77,7 @@ def _preprocess_input(self, dataset):
         if not len(features):
             raise Exception("No features identified in Dataset!")
         elif len(features) < n_features_orig:
-            LGR.info(f"Retaining {len(features)}/({n_features_orig} features.")
+            LGR.info(f"Retaining {len(features)}/{n_features_orig} features.")
 
         self.features_ = features

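
The base.py change is a one-character fix: a stray "(" inside the f-string leaked into the log message. A quick check of the two format strings:

n_kept, n_orig = 25, 100
# Before the fix: prints "Retaining 25/(100 features."
print(f"Retaining {n_kept}/({n_orig} features.")
# After the fix: prints "Retaining 25/100 features."
print(f"Retaining {n_kept}/{n_orig} features.")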
132 changes: 84 additions & 48 deletions nimare/decode/continuous.py
@@ -4,6 +4,7 @@

 import numpy as np
 import pandas as pd
+from joblib import Parallel, delayed
 from nilearn._utils import load_niimg
 from nilearn.masking import apply_mask
 from tqdm.auto import tqdm
@@ -13,7 +14,7 @@
 from nimare.meta.cbma.base import CBMAEstimator
 from nimare.meta.cbma.mkda import MKDAChi2
 from nimare.stats import pearson
-from nimare.utils import _check_type, _safe_transform
+from nimare.utils import _check_ncores, _check_type, _safe_transform, tqdm_joblib
 
 LGR = logging.getLogger(__name__)

@@ -110,6 +111,10 @@ def gclda_decode_map(model, image, topic_priors=None, prior_weight=1):
 class CorrelationDecoder(Decoder):
     """Decode an unthresholded image by correlating the image with meta-analytic maps.
 
+    .. versionchanged:: 0.0.13
+
+        * New parameter: `n_cores`. Number of cores to use for parallelization.
+
     .. versionchanged:: 0.0.12
 
         * Remove low-memory option in favor of sparse arrays.
@@ -126,6 +131,10 @@ class CorrelationDecoder(Decoder):
         Meta-analysis estimator. Default is :class:`~nimare.meta.mkda.MKDAChi2`.
     target_image : :obj:`str`
         Name of meta-analysis results image to use for decoding.
+    n_cores : :obj:`int`, optional
+        Number of cores to use for parallelization.
+        If <=0, defaults to using all available cores.
+        Default is 1.
 
     Warnings
     --------
@@ -146,6 +155,7 @@ def __init__(
         frequency_threshold=0.001,
         meta_estimator=None,
         target_image="z_desc-specificity",
+        n_cores=1,
     ):
 
         if meta_estimator is None:
@@ -158,6 +168,7 @@
         self.frequency_threshold = frequency_threshold
         self.meta_estimator = meta_estimator
         self.target_image = target_image
+        self.n_cores = _check_ncores(n_cores)
 
     def _fit(self, dataset):
         """Generate feature-specific meta-analytic maps for dataset.
@@ -179,36 +190,43 @@ def _fit(self, dataset):
         self.masker = dataset.masker
 
         n_features = len(self.features_)
-        for i_feature, feature in enumerate(tqdm(self.features_, total=n_features)):
-            feature_ids = dataset.get_studies_by_label(
-                labels=[feature],
-                label_threshold=self.frequency_threshold,
-            )
-            # Limit selected studies to studies with valid data
-            feature_ids = sorted(list(set(feature_ids).intersection(self.inputs_["id"])))
-
-            # Create the reduced Dataset
-            feature_dset = dataset.slice(feature_ids)
-
-            # Check if the meta method is a pairwise estimator
-            # This seems like a somewhat inelegant solution
-            if "dataset2" in inspect.getfullargspec(self.meta_estimator.fit).args:
-                nonfeature_ids = sorted(list(set(self.inputs_["id"]) - set(feature_ids)))
-                nonfeature_dset = dataset.slice(nonfeature_ids)
-                meta_results = self.meta_estimator.fit(feature_dset, nonfeature_dset)
-            else:
-                meta_results = self.meta_estimator.fit(feature_dset)
-
-            feature_data = meta_results.get_map(
-                self.target_image,
-                return_type="array",
-            )
-            if i_feature == 0:
-                images_ = np.zeros((len(self.features_), len(feature_data)), feature_data.dtype)
-
-            images_[i_feature, :] = feature_data
-
-        self.images_ = images_
+        with tqdm_joblib(tqdm(total=n_features)):
+            images_, feature_idx = zip(
+                *Parallel(n_jobs=self.n_cores)(
+                    delayed(self._run_fit)(i_feature, feature, dataset)
+                    for i_feature, feature in enumerate(self.features_)
+                )
+            )
+        # Convert to an array and sort the images_ array based on the feature index.
+        images_ = np.array(images_)[np.array(feature_idx)]
+        self.images_ = images_
+
+    def _run_fit(self, i_feature, feature, dataset):
+        feature_ids = dataset.get_studies_by_label(
+            labels=[feature],
+            label_threshold=self.frequency_threshold,
+        )
+        # Limit selected studies to studies with valid data
+        feature_ids = sorted(list(set(feature_ids).intersection(self.inputs_["id"])))
+
+        # Create the reduced Dataset
+        feature_dset = dataset.slice(feature_ids)
+
+        # Check if the meta method is a pairwise estimator
+        # This seems like a somewhat inelegant solution
+        if "dataset2" in inspect.getfullargspec(self.meta_estimator.fit).args:
+            nonfeature_ids = sorted(list(set(self.inputs_["id"]) - set(feature_ids)))
+            nonfeature_dset = dataset.slice(nonfeature_ids)
+            meta_results = self.meta_estimator.fit(feature_dset, nonfeature_dset)
+        else:
+            meta_results = self.meta_estimator.fit(feature_dset)
+
+        feature_data = meta_results.get_map(
+            self.target_image,
+            return_type="array",
+        )
+
+        return feature_data, i_feature
 
     def transform(self, img):
         """Correlate target image with each feature-specific meta-analytic map.
@@ -233,6 +251,10 @@ def transform(self, img):
 class CorrelationDistributionDecoder(Decoder):
     """Decode an unthresholded image by correlating the image with study-wise images.
 
+    .. versionchanged:: 0.0.13
+
+        * New parameter: `n_cores`. Number of cores to use for parallelization.
+
     Parameters
     ----------
     feature_group : :obj:`str`, optional
@@ -243,6 +265,10 @@
         Frequency threshold. Default is 0.001.
     target_image : {'z', 'con'}, optional
         Name of meta-analysis results image to use for decoding. Default is 'z'.
+    n_cores : :obj:`int`, optional
+        Number of cores to use for parallelization.
+        If <=0, defaults to using all available cores.
+        Default is 1.
 
     Warnings
     --------
@@ -261,11 +287,13 @@ def __init__(
         features=None,
         frequency_threshold=0.001,
         target_image="z",
+        n_cores=1,
     ):
         self.feature_group = feature_group
         self.features = features
         self.frequency_threshold = frequency_threshold
         self._required_inputs["images"] = ("image", target_image)
+        self.n_cores = _check_ncores(n_cores)
 
     def _fit(self, dataset):
         """Collect sets of maps from the Dataset corresponding to each requested feature.
@@ -286,31 +314,39 @@
"""
self.masker = dataset.masker

images_ = {}
for feature in self.features_:
feature_ids = dataset.get_studies_by_label(
labels=[feature], label_threshold=self.frequency_threshold
)
selected_ids = sorted(list(set(feature_ids).intersection(self.inputs_["id"])))
selected_id_idx = [
i_id for i_id, id_ in enumerate(self.inputs_["id"]) if id_ in selected_ids
]
test_imgs = [
img for i_img, img in enumerate(self.inputs_["images"]) if i_img in selected_id_idx
]
if len(test_imgs):
feature_arr = _safe_transform(
test_imgs,
self.masker,
memfile=None,
n_features = len(self.features_)
with tqdm_joblib(tqdm(total=n_features)):
images_ = dict(
Parallel(n_jobs=self.n_cores)(
delayed(self._run_fit)(feature, dataset) for feature in self.features_
)
images_[feature] = feature_arr
else:
LGR.info(f"Skipping feature '{feature}'. No images found.")
)

# reduce features again
self.features_ = [f for f in self.features_ if f in images_.keys()]
self.images_ = images_

def _run_fit(self, feature, dataset):
feature_ids = dataset.get_studies_by_label(
labels=[feature], label_threshold=self.frequency_threshold
)
selected_ids = sorted(list(set(feature_ids).intersection(self.inputs_["id"])))
selected_id_idx = [
i_id for i_id, id_ in enumerate(self.inputs_["id"]) if id_ in selected_ids
]
test_imgs = [
img for i_img, img in enumerate(self.inputs_["images"]) if i_img in selected_id_idx
]
if len(test_imgs):
feature_arr = _safe_transform(
test_imgs,
self.masker,
memfile=None,
)
return feature, feature_arr
else:
LGR.info(f"Skipping feature '{feature}'. No images found.")

def transform(self, img):
"""Correlate target image with each map associated with each feature.
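
Both `_fit` methods lean on two helpers imported from `nimare.utils`, `tqdm_joblib` and `_check_ncores`, whose bodies are not part of this diff. A minimal sketch of the standard patterns they follow (not the exact NiMARE implementations):

import contextlib
import multiprocessing

import joblib
from tqdm.auto import tqdm


@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Patch joblib so each completed batch advances a tqdm progress bar."""

    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        def __call__(self, *args, **kwargs):
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    old_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        # Restore the original callback and close the bar, even on error.
        joblib.parallel.BatchCompletionCallBack = old_callback
        tqdm_object.close()


def _check_ncores(n_cores):
    """Return a usable core count; <= 0 means use all available cores."""
    if n_cores <= 0:
        n_cores = multiprocessing.cpu_count()
    return n_cores

Patching `joblib.parallel.BatchCompletionCallBack` is the widely used recipe for reporting joblib progress: each finished batch advances the bar by its batch size, which is why both `_fit` methods wrap their `Parallel` calls in `with tqdm_joblib(tqdm(total=n_features)):`.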
