[ENH] Improve CAPS reader (#760)
* Extract relevant work from PETVolume PR

* fix rebase mistake

* Refactor - test - unstable

* Add some docs

* Add some basic tests

* Finish testing query

* forgot to add error message in caps_reader...

* Add unit tests for bids_reader and caps_reader functions

* Simplify CAPSDataGrabber classes

* Simplify code

* Apply suggestions from code review

Co-authored-by: Ghislain Vaillant <[email protected]>

* Dict->dict, List[str]->list

* Simplify Query constructor

* Add tests and docs for aggregator

* remove value trait

* fix some design issues

Co-authored-by: Ghislain Vaillant <[email protected]>
NicolasGensollen and ghisvail authored Oct 17, 2022
1 parent ec53b47 commit a124170
Showing 9 changed files with 762 additions and 128 deletions.
39 changes: 24 additions & 15 deletions clinica/pydra/engine.py
@@ -1,6 +1,5 @@
import functools
from os import PathLike
from typing import Callable

from pydra import Workflow
from pydra.engine.core import TaskBase
@@ -41,7 +40,7 @@ def add_input_reading_task(
def add_input_reading_task(
pipeline: Workflow,
core_workflow: Workflow,
query_maker: Callable,
query_type: str,
reader: TaskBase,
) -> Workflow:
"""Configure and add the reading tasks of input workflow.
@@ -51,14 +50,13 @@ def add_input_reading_task(
pipeline : Workflow
The main Workflow to which the readers should be added.
core_inputs : dict
The inputs specified by the core workflow. This defines
what the reader workflow should read and how it should
connect to the core Workflow.
core_workflow : Workflow
The core workflow. It defines, through its inputs,
what the reader workflow should read and how these two
workflows should connect.
query_maker : Callable
Function responsible for parsing the core_inputs into a
proper query.
query_type : {"bids", "caps_file", "caps_group"}
The type of query that should be run.
reader : TaskBase
Task responsible for reading data.
@@ -68,7 +66,9 @@ def add_input_reading_task(
pipeline : Workflow
The main Workflow with readers added to it.
"""
query = query_maker(pu.list_workflow_inputs(core_workflow))
from clinica.pydra.query import query_factory

query = query_factory(pu.list_workflow_inputs(core_workflow), query_type=query_type)
if len(query) == 0:
return pipeline
input_dir = "bids_dir" if "bids" in reader.__name__ else "caps_dir"
@@ -98,14 +98,21 @@ def add_input_reading_task(

add_input_reading_task_bids = functools.partial(
add_input_reading_task,
query_maker=pu.bids_query,
query_type="bids",
reader=bids_reader,
)


add_input_reading_task_caps = functools.partial(
add_input_reading_task_caps_file = functools.partial(
add_input_reading_task,
query_type="caps_file",
reader=caps_reader,
)


add_input_reading_task_caps_group = functools.partial(
add_input_reading_task,
query_maker=pu.caps_query,
query_type="caps_group",
reader=caps_reader,
)

@@ -116,7 +123,8 @@ def build_input_workflow(pipeline: Workflow, core_workflow: Workflow) -> Workflo
For now, the input workflow is responsible for:
- reading BIDS data
- reading CAPS data
- reading CAPS data with clinica_file_reader
- reading CAPS data with clinica_group_reader
Parameters
----------
@@ -132,7 +140,8 @@ def build_input_workflow(pipeline: Workflow, core_workflow: Workflow) -> Workflo
The pipeline with the input workflow.
"""
pipeline = add_input_reading_task_bids(pipeline, core_workflow)
pipeline = add_input_reading_task_caps(pipeline, core_workflow)
pipeline = add_input_reading_task_caps_file(pipeline, core_workflow)
pipeline = add_input_reading_task_caps_group(pipeline, core_workflow)
return pipeline


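The query_factory used above is defined in clinica/pydra/query.py, which is not among the hunks shown on this page. A minimal sketch of what the dispatch presumably looks like, inferred from its call site here and from the query classes imported in clinica/pydra/interfaces.py below; the exact constructor call is an assumption:

from clinica.pydra.query import BIDSQuery, CAPSFileQuery, CAPSGroupQuery


def query_factory(raw_inputs: dict, query_type: str):
    # Sketch only: the real implementation lives in clinica/pydra/query.py.
    # Maps the query_type strings used by the partials in engine.py to the
    # corresponding Query class; building a Query directly from the raw input
    # dict is an assumption.
    dispatch = {
        "bids": BIDSQuery,
        "caps_file": CAPSFileQuery,
        "caps_group": CAPSGroupQuery,
    }
    if query_type not in dispatch:
        raise ValueError(f"Unknown query_type: {query_type}.")
    return dispatch[query_type](raw_inputs)
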
68 changes: 0 additions & 68 deletions clinica/pydra/engine_utils.py
@@ -48,74 +48,6 @@ def list_workflow_inputs(wf: Workflow) -> dict:
return inputs


def bids_query(raw_query: dict) -> dict:
"""Parse a raw BIDS query dictionary and return a properly
formatted query dictionary that is compatible with the
`BIDSDataGrabber` interface.
Parameters
----------
raw_query : dict
The raw BIDS query as a dictionary. This may contain data
that is not supported by the `BIDSDataGrabber`.
Returns
-------
query : dict
The formatted query dictionary compatible with `BIDSDataGrabber`.
"""
query = {}
bids_default_queries = {
"T1w": {"datatype": "anat", "suffix": "T1w", "extension": [".nii.gz"]}
}
for k, q in raw_query.items():
if k in bids_default_queries:
query[k] = {**bids_default_queries[k], **q}
return query
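
For reference, the removed bids_query merged the caller's sub-query into a per-key default and silently dropped keys it had no default for, e.g.:

bids_query({"T1w": {"suffix": "T1w"}, "unsupported_key": {}})
# -> {"T1w": {"datatype": "anat", "suffix": "T1w", "extension": [".nii.gz"]}}
# "unsupported_key" has no default query registered, so it is dropped.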


def caps_query(raw_query: dict) -> dict:
"""Parse a raw CAPS query dictionary and return a properly
formatted query dictionary that is compatible with the
`CAPSDataGrabber` interface.
Parameters
----------
raw_query : dict
The raw CAPS query as a dictionary. This may contain data
that is not supported by the `CAPSDataGrabber`.
Returns
-------
query : dict
The formatted query dictionary compatible with `CAPSDataGrabber`.
"""
from clinica.utils.input_files import (
t1_volume_deformation_to_template,
t1_volume_final_group_template,
t1_volume_native_tpm,
t1_volume_native_tpm_in_mni,
)

query = {}
caps_keys_available_file_reader = {
"mask_tissues": t1_volume_native_tpm_in_mni,
"flow_fields": t1_volume_deformation_to_template,
"pvc_mask_tissues": t1_volume_native_tpm,
}
caps_keys_available_group_reader = {
"dartel_template": t1_volume_final_group_template,
}
for k, v in raw_query.items():
if k in caps_keys_available_file_reader:
query[k] = caps_keys_available_file_reader[k](**v)
query[k]["reader"] = "file"
elif k in caps_keys_available_group_reader:
query[k] = caps_keys_available_group_reader[k](**v)
query[k]["reader"] = "group"
return query


def run(wf: Workflow) -> str:
"""Execute a Pydra workflow
83 changes: 59 additions & 24 deletions clinica/pydra/interfaces.py
@@ -1,18 +1,20 @@
# from pathlib import PurePath
import abc
from os import PathLike
from typing import Dict, List

import nipype.interfaces.io as nio
import pydra
from nipype.interfaces.base import Directory, DynamicTraitedSpec, Str, isdefined, traits
from nipype.interfaces.io import IOBase, add_traits
from pydra.tasks.nipype1.utils import Nipype1Task

from clinica.pydra.query import BIDSQuery, CAPSFileQuery, CAPSGroupQuery, CAPSQuery


class CAPSDataGrabberInputSpec(DynamicTraitedSpec):
base_dir = Directory(exists=True, desc="Path to CAPS Directory.", mandatory=True)
output_query = traits.Dict(
key_trait=Str, value_trait=traits.Dict, desc="Queries for outfield outputs"
)
output_query = traits.Dict(key_trait=Str, desc="Queries for outfield outputs")
raise_on_empty = traits.Bool(
True,
usedefault=True,
@@ -32,35 +34,59 @@ def __init__(self, **kwargs):
# used for mandatory inputs check
undefined_traits = {}
self.inputs.trait_set(trait_change_notify=False, **undefined_traits)
self.sessions = None
self.subjects = None

def _list_outputs(self):
from clinica.utils.inputs import clinica_file_reader, clinica_group_reader
from clinica.utils.participant import get_subject_session_list

sessions, subjects = get_subject_session_list(
self.inputs.base_dir,
is_bids_dir=False,
)
query = {}
for k, q in self.inputs.output_query.items():
reader = q.pop("reader")
if reader == "file":
query[k] = clinica_file_reader(
subjects,
sessions,
self.inputs.base_dir,
q,
)
elif reader == "group":
query[k] = clinica_group_reader(self.inputs.base_dir, q)
self.sessions = sessions
self.subjects = subjects

output_query = {}
for k, query in self.inputs.output_query.items():
if isinstance(query, list):
temp = [self._execute_single_query(q) for q in query]
if len(temp) != len(self.subjects) and len(temp[0]) == len(
self.subjects
):
temp = [list(i) for i in zip(*temp)] # Transpose
output_query[k] = temp
else:
raise ValueError(f"Unknown reader {reader}.")
return query
output_query[k] = self._execute_single_query(query)
return output_query

@abc.abstractmethod
def _execute_single_query(self, query: dict) -> list:
pass

def _add_output_traits(self, base):
return add_traits(base, list(self.inputs.output_query.keys()))


class CAPSFileDataGrabber(CAPSDataGrabber):
def _execute_single_query(self, query: Dict) -> List[str]:
from clinica.utils.inputs import clinica_file_reader

return clinica_file_reader(
self.subjects,
self.sessions,
self.inputs.base_dir,
query,
)[0]


class CAPSGroupDataGrabber(CAPSDataGrabber):
def _execute_single_query(self, query: Dict) -> List[str]:
from clinica.utils.inputs import clinica_group_reader

return clinica_group_reader(self.inputs.base_dir, query)
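
The list handling in CAPSDataGrabber._list_outputs above is easiest to see on a toy case: when a key maps to a list of sub-queries (for instance one query per tissue), each sub-query yields one file per subject, and the transpose regroups the results so that each subject gets the full list of files. A small illustration with placeholder file names:

# Placeholder file names; three sub-queries over two subjects.
per_query = [
    ["sub-01_q1.nii.gz", "sub-02_q1.nii.gz"],  # results of sub-query 1
    ["sub-01_q2.nii.gz", "sub-02_q2.nii.gz"],  # results of sub-query 2
    ["sub-01_q3.nii.gz", "sub-02_q3.nii.gz"],  # results of sub-query 3
]
per_subject = [list(i) for i in zip(*per_query)]  # same transpose as _list_outputs
# per_subject == [
#     ["sub-01_q1.nii.gz", "sub-01_q2.nii.gz", "sub-01_q3.nii.gz"],
#     ["sub-02_q1.nii.gz", "sub-02_q2.nii.gz", "sub-02_q3.nii.gz"],
# ]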


@pydra.mark.task
@pydra.mark.annotate({"return": {"output_file": str}})
def bids_writer(output_dir: PathLike, output_file: PathLike) -> str:
@@ -84,11 +110,11 @@ def bids_writer(output_dir: PathLike, output_file: PathLike) -> str:
return output_file


def bids_reader(query: dict, input_dir: PathLike):
def bids_reader(query: BIDSQuery, input_dir: PathLike):
"""
Parameters
----------
query : dict
query : BIDSQuery
Input to BIDSDataGrabber (c.f https://nipype.readthedocs.io/en/latest/api/generated/nipype.interfaces.io.html#bidsdatagrabber)
input_dir : PathLike
The BIDS input directory.
@@ -98,7 +124,7 @@ def bids_reader(query: dict, input_dir: PathLike):
Nipype1Task
The task used for reading files from BIDS.
"""
bids_data_grabber = nio.BIDSDataGrabber(output_query=query)
bids_data_grabber = nio.BIDSDataGrabber(output_query=query.query)
bids_reader_task = Nipype1Task(
name="bids_reader_task",
interface=bids_data_grabber,
Expand All @@ -107,11 +133,11 @@ def bids_reader(query: dict, input_dir: PathLike):
return bids_reader_task


def caps_reader(query: dict, input_dir: PathLike):
def caps_reader(query: CAPSQuery, input_dir: PathLike):
"""
Parameters
----------
query : dict
query : CAPSQuery
Input to CAPSDataGrabber.
input_dir : PathLike
The CAPS input directory.
@@ -121,7 +147,16 @@ def caps_reader(query: dict, input_dir: PathLike):
Nipype1Task
The task used for reading files from CAPS.
"""
caps_data_grabber = CAPSDataGrabber(output_query=query)
if isinstance(query, CAPSFileQuery):
grabber = CAPSFileDataGrabber
elif isinstance(query, CAPSGroupQuery):
grabber = CAPSGroupDataGrabber
else:
raise TypeError(
f"caps_reader received an unexpected query type {type(query)}. "
"Supported types are: CAPSFileQuery and CAPSGroupQuery."
)
caps_data_grabber = grabber(output_query=query.query)
caps_reader_task = Nipype1Task(
name="caps_reader_task",
interface=caps_data_grabber,

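Putting the pieces together, a minimal sketch of how a caller builds a CAPS file reader task with the new API. The wiring mirrors add_input_reading_task in clinica/pydra/engine.py, whose body is partly collapsed in this view; the alias pu is assumed to point at clinica.pydra.engine_utils, as in engine.py:

import clinica.pydra.engine_utils as pu
from clinica.pydra.interfaces import caps_reader
from clinica.pydra.query import query_factory


def build_caps_file_reader(core_workflow, caps_dir):
    # Build a CAPSFileQuery from the core workflow's inputs, then hand it to
    # caps_reader, which dispatches CAPSFileQuery -> CAPSFileDataGrabber and
    # CAPSGroupQuery -> CAPSGroupDataGrabber (anything else raises TypeError).
    query = query_factory(pu.list_workflow_inputs(core_workflow), query_type="caps_file")
    if len(query) == 0:
        return None  # nothing to read for this workflow
    return caps_reader(query, caps_dir)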