diff --git a/clinica/iotools/converters/adni_to_bids/adni_json.py b/clinica/iotools/converters/adni_to_bids/adni_json.py index b14b51335a..4dd35e5b4f 100644 --- a/clinica/iotools/converters/adni_to_bids/adni_json.py +++ b/clinica/iotools/converters/adni_to_bids/adni_json.py @@ -4,12 +4,12 @@ from pathlib import Path -# Maps names of extracted metadata to their proper BIDS names +# Maps names of extracted metadata to their proper BIDS names METADATA_NAME_MAPPING = { - "acquisition_type": "MRAcquisitionType", - "pulse_sequence": "PulseSequenceType", - "manufacturer": "Manufacturer", - "field_strength": "MagneticFieldStrength", + "acquisition_type": "MRAcquisitionType", + "pulse_sequence": "PulseSequenceType", + "manufacturer": "Manufacturer", + "field_strength": "MagneticFieldStrength", } @@ -22,13 +22,17 @@ def _read_xml_files(subj_ids: list = [], xml_path: str = "") -> list: from os import path if subj_ids: xml_files = [] - xml_regex = [path.join(xml_path, ("ADNI_" + e + "*.xml")) for e in subj_ids] + xml_regex = [ + path.join(xml_path, ("ADNI_" + e + "*.xml")) for e in subj_ids + ] for subj_files in xml_regex: xml_files.extend(glob(subj_files)) else: xml_files = glob("Clinica_processed_metadata/ADNI_*.xml") if len(xml_files) == 0: - raise IndexError("No ADNI xml files were found for reading the metadata.") + raise IndexError( + "No ADNI xml files were found for reading the metadata." + ) return xml_files @@ -38,7 +42,10 @@ def _check_xml_tag(el_tag: str, exp_tag: str): raise ValueError(f"Bad tag: expected {exp_tag}, got {el_tag}") -def _check_xml_nb_children(xml_el: xml.etree.ElementTree.Element, exp_nb_children: int): +def _check_xml_nb_children( + xml_el: xml.etree.ElementTree.Element, + exp_nb_children: int +): """Check that the given XML element has the expected number of children. Raise a `ValueError` otherwise. """ @@ -46,13 +53,15 @@ def _check_xml_nb_children(xml_el: xml.etree.ElementTree.Element, exp_nb_childre if isinstance(exp_nb_children, int): if nb_children != exp_nb_children: raise ValueError( - f"Bad number of children for <{xml_el.tag}>: got {nb_children} != {exp_nb_children}" + f"Bad number of children for <{xml_el.tag}>: " + f"got {nb_children} != {exp_nb_children}" ) else: # container if nb_children not in exp_nb_children: raise ValueError( - f"Bad number of children for <{xml_el.tag}>: got {nb_children}, not in {exp_nb_children}" + f"Bad number of children for <{xml_el.tag}>: " + f"got {nb_children}, not in {exp_nb_children}" ) @@ -70,12 +79,18 @@ def _check_xml( return xml_el -def _parse_project(root: xml.etree.ElementTree.Element) -> xml.etree.ElementTree.Element: - """Check that root has 1 child: having 4 children, extract project from them. +def _parse_project( + root: xml.etree.ElementTree.Element +) -> xml.etree.ElementTree.Element: + """Check that root has 1 child: having 4 children, + extract project from them. Check that the project identifier contains ADNI. """ if len(root) != 1: - raise ValueError(f"XML root should have only one child. {len(root)} children found.") + raise ValueError( + "XML root should have only one child. " + f"{len(root)} children found." + ) project = _check_xml(root[0], "project", 4) _check_xml_project_identifier(project, "ADNI") return project @@ -116,48 +131,65 @@ def _check_derived_image_xml(derived: xml.etree.ElementTree.Element): # 2 cases if derived[-1].tag == "creationDate": assert ( - _check_xml_and_get_text(derived[-1], "creationDate") == "0000-00-00" + _check_xml_and_get_text( + derived[-1], "creationDate" + ) == "0000-00-00" ), "creaDate" else: assert ( - _check_xml_and_get_text(derived[-2], "creationDate") == "0000-00-00" + _check_xml_and_get_text( + derived[-2], "creationDate" + ) == "0000-00-00" ), "creaDate" -def _get_derived_image_metadata(derived: xml.etree.ElementTree.Element) -> dict: - """Return metadata for derived images after performing some sanity checks.""" +def _get_derived_image_metadata( + derived: xml.etree.ElementTree.Element +) -> dict: + """Return metadata for derived images after + performing some sanity checks. + """ _check_derived_image_xml(derived) return { - "image_proc_id": _check_xml_and_get_text(derived[0], "imageUID", cast=int), + "image_proc_id": _check_xml_and_get_text( + derived[0], "imageUID", cast=int + ), "image_proc_desc": _check_xml_and_get_text( derived[1], "processedDataLabel" ).strip(), } -def _check_xml_project_identifier(project: xml.etree.ElementTree.Element, expected: str): +def _check_xml_project_identifier( + project: xml.etree.ElementTree.Element, + expected: str +): """Check the project identifier.""" if _check_xml_and_get_text(project[0], "projectIdentifier") != expected: raise ValueError(f"Not {expected} cohort") -def _get_original_image_metadata(original_image: xml.etree.ElementTree.Element) -> dict: +def _get_original_image_metadata( + original_image: xml.etree.ElementTree.Element +) -> dict: """Get original image metadata.""" if not isinstance(original_image, list): return _get_image_metadata(original_image) # only keep common metadata (remaining will be None) and original_images_metadata = [_get_image_metadata(_) for _ in original_image] - original_images_metadata = {_.pop("image_orig_id"): _ for _ in original_images_metadata} - original_image_metadata = {"image_orig_id": "|".join(map(str, original_images_metadata.keys()))} + original_images_metadata = { + _.pop("image_orig_id"): _ for _ in original_images_metadata + } + original_image_metadata = { + "image_orig_id": "|".join( + map(str, original_images_metadata.keys()) + ) + } for k in next(iter(original_images_metadata.values())).keys(): # loop on original images metadata and only add consistent metadata set_v = set(_[k] for _ in original_images_metadata.values()) if len(set_v) == 1: original_image_metadata[k] = set_v.pop() - # else: - # orig_img_d[k] = '|'.join(map(str, set_v)) # keep all values with '|' - # assert {**get_img_metadata(orig_img), 'IMAGE_ORIG_ID': -1} == \ - # {**get_img_metadata(orig_img_bis), 'IMAGE_ORIG_ID': -1}, 'Not consistent multiple derived scans' return original_image_metadata @@ -170,12 +202,17 @@ def _check_processed_image_rating( consistent with the rating of the original image. """ from clinica.utils.stream import cprint + rating = original_image_metadata.get( + "image_orig_rating", None + ) if processed_image_rating is not None: - if processed_image_rating != original_image_metadata.get("image_orig_rating", None): + if processed_image_rating != rating: cprint( - msg=(f"Image rating for processed image {derived_image_metadata.get('image_proc_id')} not " - "consistent with rating of original image"), - lvl="info", + msg=( + "Image rating for processed image " + f"{derived_image_metadata.get('image_proc_id')} not " + "consistent with rating of original image" + ), lvl="info", ) @@ -191,7 +228,9 @@ def _get_root_from_xml_path(xml_path: str) -> xml.etree.ElementTree.Element: return root -def _parse_series(study: xml.etree.ElementTree.Element) -> xml.etree.ElementTree.Element: +def _parse_series( + study: xml.etree.ElementTree.Element +) -> xml.etree.ElementTree.Element: """Get series xml element from study xml element.""" expected_length = 3 if len(study) == 7 else 4 return _check_xml(study[5], "series", expected_length) @@ -207,10 +246,14 @@ def _parse_derived_images( # we can have 1 or 2 related imgs... series_meta = _check_xml(series[-1], "seriesLevelMeta", {3, 4}) related_image = _check_xml(series_meta[2], "relatedImageDetail", 1) - original_image = _check_xml(related_image[0], "originalRelatedImage", {3, 4}) + original_image = _check_xml( + related_image[0], "originalRelatedImage", {3, 4} + ) if len(series_meta) == 4: # 2 original images - related_image_bis = _check_xml(series_meta[-1], "relatedImageDetail", 1) + related_image_bis = _check_xml( + series_meta[-1], "relatedImageDetail", 1 + ) original_image_bis = _check_xml( related_image_bis[0], "originalRelatedImage", len(original_image) ) @@ -219,7 +262,6 @@ def _parse_derived_images( pipeline_name = _check_xml_and_get_text( _check_xml(series_meta[0], "annotation", 1)[0], "text" ) - # assert pipeline_name == 'Grinder Pipeline', f'Wrong pipeline name: {pipe_name}' deriv = _check_xml(series_meta[1], "derivedProduct") derived_metadata = {"image_proc_pipe": pipeline_name.strip()} derived_metadata.update(_get_derived_image_metadata(deriv)) @@ -233,21 +275,25 @@ def _check_image(img: xml.etree.ElementTree.Element): "imagingProtocol", # for original image metadata "originalRelatedImage", # for processed image }: - raise ValueError(f"Bad image tag <{img.tag}>. " - "Should be either 'imagingProtocol' " - "or 'originalRelatedImage'.") + raise ValueError( + f"Bad image tag <{img.tag}>. " + "Should be either 'imagingProtocol' " + "or 'originalRelatedImage'." + ) if len(img) not in {3, 4}: - raise ValueError(f"Image XML element has {len(img)} children. " - "Expected either 3 or 4.") + raise ValueError( + f"Image XML element has {len(img)} children. " + "Expected either 3 or 4." + ) def _clean_protocol_metadata(protocol_metadata: dict) -> dict: """Replace confusing '|' (reserved for separation btw elements) with space. Only 1 manufacturer with this apparently but... """ - return {k: v.replace("|", " ") - if isinstance(v, str) - else v # there are some None (not str) + return { + k: v.replace("|", " ") + if isinstance(v, str) else v # there are some None (not str) for k, v in protocol_metadata.items() } @@ -256,7 +302,7 @@ def _filter_metadata(metadata: dict) -> dict: """Filter and clean a given metadata dictionary according to the METADATA_NAME_MAPPING dictionary.""" filtered = dict() - for k,v in metadata.items(): + for k, v in metadata.items(): if k in METADATA_NAME_MAPPING: filtered[METADATA_NAME_MAPPING[k]] = v return filtered @@ -266,8 +312,11 @@ def _get_image_metadata(img: xml.etree.ElementTree.Element) -> dict: """Return information on original image as dict of metadata.""" _check_image(img) protocol = _check_xml(img[2], "protocolTerm") # do not check children - img_rating_val = _get_image_rating(img[3]) if len(img) == 4 else None - assert all(p.tag == "protocol" and "term" in p.attrib.keys() for p in protocol) + assert all( + p.tag == "protocol" + and "term" in p.attrib.keys() + for p in protocol + ) protocol_metadata = { "_".join(p.attrib["term"].split()).lower(): p.text for p in protocol } @@ -282,34 +331,51 @@ def _get_image_metadata(img: xml.etree.ElementTree.Element) -> dict: } -def _get_image_rating(image_rating: xml.etree.ElementTree.Element) -> Optional[int]: +def _get_image_rating( + image_rating: xml.etree.ElementTree.Element +) -> Optional[int]: """Get the image rating value as an integer from the xml element.""" if image_rating.tag != "imageRating": return None _check_xml(image_rating, "imageRating", 2) - image_rating_desc = _check_xml_and_get_text(image_rating[0], "ratingDescription") - image_rating_val = _check_xml_and_get_text(image_rating[1], "value", cast=int) + image_rating_desc = _check_xml_and_get_text( + image_rating[0], "ratingDescription" + ) + image_rating_val = _check_xml_and_get_text( + image_rating[1], "value", cast=int + ) assert image_rating_desc == str(image_rating_val) return image_rating_val -def _check_modality(study: xml.etree.ElementTree.Element, expected_modality: str): +def _check_modality( + study: xml.etree.ElementTree.Element, + expected_modality: str +): """Check that the modality of the given study is the expected one.""" series = _parse_series(study) modality = _check_xml_and_get_text(series[1], "modality") if modality != expected_modality: - raise ValueError(f"Unexpected modality {modality}, expected {expected_modality}.") + raise ValueError( + f"Unexpected modality {modality}, expected {expected_modality}." + ) -def _parse_subject(project: xml.etree.ElementTree.Element) -> Tuple[str, xml.etree.ElementTree.Element]: +def _parse_subject( + project: xml.etree.ElementTree.Element +) -> Tuple[str, xml.etree.ElementTree.Element]: """From the project xml element, parse the subject and subject id.""" subject = _check_xml(project[-1], "subject", {5, 7}) # with APOE or not subject_id = _check_xml_and_get_text(subject[0], "subjectIdentifier") return subject_id, subject -def _parse_study(subject: xml.etree.ElementTree.Element) -> xml.etree.ElementTree.Element: - """Parse the study from the subject xml element and check that it is MRI.""" +def _parse_study( + subject: xml.etree.ElementTree.Element +) -> xml.etree.ElementTree.Element: + """Parse the study from the subject xml element + and check that it is MRI. + """ study = _check_xml(subject[-1], "study", {6, 7}) _check_modality(study, "MRI") return study @@ -335,7 +401,8 @@ def _parse_xml_file(xml_path: str) -> dict: subject_id, subject = _parse_subject(project) study = _parse_study(subject) series = _parse_series(study) - original_image, processed_image_rating, derived_image_metadata = _parse_images(study) + original_image, processed_image_rating, derived_image_metadata =\ + _parse_images(study) original_image_metadata = _get_original_image_metadata(original_image) _check_processed_image_rating( processed_image_rating, original_image_metadata, derived_image_metadata @@ -349,8 +416,11 @@ def _parse_xml_file(xml_path: str) -> dict: **original_image_metadata, **derived_image_metadata, } - if "image_proc_id" not in scan_metadata and "image_orig_id" not in scan_metadata: - raise ValueError("Scan metadata for subject {subject_id} has no image ID.") + if ("image_proc_id" not in scan_metadata + and "image_orig_id" not in scan_metadata): + raise ValueError( + f"Scan metadata for subject {subject_id} has no image ID." + ) return scan_metadata @@ -368,11 +438,14 @@ def __call__(self, *args, **kwargs): def _run_parsers(xml_files: list) -> Tuple[list, dict]: """Run the parser `_parse_xml_file` on the list of files `xml_files`. - Returns a tuple consisting if parsed images metadata and captured exceptions. + Returns a tuple consisting if parsed images + metadata and captured exceptions. + + .. note:: + Use multiprocessing / multithreading for parsing the files? + Not sure we will get a huge performance boost by doing that tbh. """ import os - # from concurrent.futures import ProcessPoolExecutor, as_completed #ThreadPoolExecutor - from math import ceil parser = func_with_exception(_parse_xml_file) imgs_with_excep = dict( zip( @@ -418,17 +491,22 @@ def _get_existing_scan_dataframe( subj_path: str, session: str ) -> Union[Tuple[pd.DataFrame, Path], Tuple[None, Path]]: - """Retrieve existing scan dataframe at the given `subj_path`, and the given `session`. + """Retrieve existing scan dataframe at the given + `subj_path`, and the given `session`. If no existing scan file is found, a warning is given. """ import warnings subj_id = Path(subj_path).name - scans_tsv_path = Path(subj_path) / session / f"{subj_id}_{session}_scans.tsv" + scans_tsv_path = ( + Path(subj_path) / session / f"{subj_id}_{session}_scans.tsv" + ) if scans_tsv_path.exists(): df_scans = pd.read_csv(scans_tsv_path, sep="\t") df_scans["scan_id"] = df_scans["scan_id"].astype("Int64") return df_scans, scans_tsv_path - warnings.warn(f"No scan tsv file for subject {subj_id} and session {session}") + warnings.warn( + f"No scan tsv file for subject {subj_id} and session {session}" + ) return None, scans_tsv_path @@ -437,7 +515,9 @@ def _merge_scan_and_metadata( df_meta: pd.DataFrame, strategy: dict ) -> pd.DataFrame: - """Perform a merge between the two provided dataframe according to the strategy.""" + """Perform a merge between the two provided dataframe + according to the strategy. + """ return pd.merge(df_scans, df_meta, **strategy) @@ -467,41 +547,64 @@ def _add_json_scan_metadata( if isinstance(metadata, str): metadata = json.loads(metadata) filtered_metadata = { - k:v for k,v in metadata.items() + k: v for k, v in metadata.items() if k in METADATA_NAME_MAPPING.values() } updated_metadata = {**existing_metadata, **filtered_metadata} if not keep_none: - updated_metadata = {k:v for k,v in updated_metadata.items() if v is not None} + updated_metadata = { + k: v for k, v in updated_metadata.items() if v is not None + } if len(updated_metadata) > 0: with open(json_path, "w") as fp: json.dump(updated_metadata, fp, indent=indent) -def _add_metadata_to_scans(df_meta: pd.DataFrame, bids_subjs_paths: list) -> None: +def _add_metadata_to_scans( + df_meta: pd.DataFrame, + bids_subjs_paths: list +) -> None: """Add the metadata to the appropriate tsv and json files.""" from clinica.iotools.bids_utils import get_bids_sess_list - MERGE_STRATEGY = {"how": "left", "left_on": "scan_id", "right_on": "T1w_scan_id"} + MERGE_STRATEGY = { + "how": "left", + "left_on": "scan_id", + "right_on": "T1w_scan_id" + } for subj_path in bids_subjs_paths: sess_list = get_bids_sess_list(subj_path) if sess_list: for sess in sess_list: - df_scans, scans_tsv_path = _get_existing_scan_dataframe(subj_path, sess) + df_scans, scans_tsv_path = _get_existing_scan_dataframe( + subj_path, sess + ) if df_scans is not None: columns_to_keep = list(df_scans.columns) + ['acq_time'] - df_merged = _merge_scan_and_metadata(df_scans, df_meta, MERGE_STRATEGY) + df_merged = _merge_scan_and_metadata( + df_scans, df_meta, MERGE_STRATEGY + ) for _, scan_row in df_merged.iterrows(): - scan_path = Path(subj_path) / sess / scan_row["filename"] - json_path = _get_json_filename_from_scan_filename(scan_path) + scan_path = ( + Path(subj_path) / sess / scan_row["filename"] + ) + json_path = _get_json_filename_from_scan_filename( + scan_path + ) _add_json_scan_metadata(json_path, scan_row.to_json()) df_merged[columns_to_keep].to_csv(scans_tsv_path, sep="\t") return None -def create_json_metadata(bids_subjs_paths: str, bids_ids: list, xml_path: str) -> None: +def create_json_metadata( + bids_subjs_paths: str, + bids_ids: list, + xml_path: str +) -> None: """Create json metadata dictionary and add the metadata to the appropriate files in the BIDS hierarchy.""" - from clinica.iotools.converters.adni_to_bids.adni_utils import bids_id_to_loni + from clinica.iotools.converters.adni_to_bids.adni_utils import ( + bids_id_to_loni + ) loni_ids = [bids_id_to_loni(bids_id) for bids_id in bids_ids] xml_files = _read_xml_files(loni_ids, xml_path) imgs, exe = _run_parsers(xml_files) diff --git a/clinica/iotools/converters/adni_to_bids/adni_to_bids.py b/clinica/iotools/converters/adni_to_bids/adni_to_bids.py index 6f1d209609..eedc9ef513 100644 --- a/clinica/iotools/converters/adni_to_bids/adni_to_bids.py +++ b/clinica/iotools/converters/adni_to_bids/adni_to_bids.py @@ -137,12 +137,14 @@ def convert_clinical_data( if xml_path is not None: if os.path.exists(xml_path): - create_json_metadata(bids_subjs_paths, bids_ids, xml_path) - else: - cprint( - msg=f"Clinica was unable to find {xml_path}, skipping xml metadata extraction.", - lvl="warning", - ) + create_json_metadata(bids_subjs_paths, bids_ids, xml_path) + else: + cprint( + msg=( + f"Clinica was unable to find {xml_path}, " + "skipping xml metadata extraction." + ), lvl="warning", + ) def convert_images( self, diff --git a/clinica/iotools/converters/adni_to_bids/adni_to_bids_cli.py b/clinica/iotools/converters/adni_to_bids/adni_to_bids_cli.py index 68c56e6e46..ff4cd659a5 100644 --- a/clinica/iotools/converters/adni_to_bids/adni_to_bids_cli.py +++ b/clinica/iotools/converters/adni_to_bids/adni_to_bids_cli.py @@ -33,8 +33,8 @@ help="Convert only the selected modality. By default, all available modalities are converted.", ) @click.option( - "-xml", "--xml_path", help="Path to the root directory containing the xml metadata." - ) + "-xml", "--xml_path", help="Path to the root directory containing the xml metadata." +) def cli( dataset_directory: str, clinical_data_directory: str, diff --git a/clinica/iotools/converters/adni_to_bids/adni_utils.py b/clinica/iotools/converters/adni_to_bids/adni_utils.py index fca27720f0..28b3d652ef 100644 --- a/clinica/iotools/converters/adni_to_bids/adni_utils.py +++ b/clinica/iotools/converters/adni_to_bids/adni_utils.py @@ -569,15 +569,14 @@ def remove_fields_duplicated(bids_fields): def bids_id_to_loni(bids_id: str) -> Union[str, None]: - """ - Convert a subject id of the form sub-ADNI000S0000 back to original format 000_S_0000 - """ - import re - - ids = re.findall("\d+", bids_id) - if len(ids) == 2: - return ids[0] + "_S_" + ids[1] - return None + """Convert a subject id of the form sub-ADNI000S0000 + back to original format 000_S_0000 + """ + import re + ids = re.findall("\d+", bids_id) # noqa + if len(ids) == 2: + return ids[0] + "_S_" + ids[1] + return None def filter_subj_bids(df_files, location, bids_ids): diff --git a/test/tests/iotools/converters/adni_to_bids/test_adni_json.py b/test/tests/iotools/converters/adni_to_bids/test_adni_json.py index cbacaa87c6..9f2732697d 100644 --- a/test/tests/iotools/converters/adni_to_bids/test_adni_json.py +++ b/test/tests/iotools/converters/adni_to_bids/test_adni_json.py @@ -68,10 +68,10 @@ def test_check_xml_and_get_text(basic_xml_tree): from clinica.iotools.converters.adni_to_bids.adni_json import _check_xml_and_get_text xml_leaf = ET.Element("leaf") xml_leaf.text = "12" - assert _check_xml_and_get_text(xml_leaf, "leaf") == "12" - assert _check_xml_and_get_text(xml_leaf, "leaf", cast=int) == 12 + assert _check_xml_and_get_text(xml_leaf, "leaf") == "12" + assert _check_xml_and_get_text(xml_leaf, "leaf", cast=int) == 12 with pytest.raises(ValueError, match="Bad number of children for "): - _check_xml_and_get_text(basic_xml_tree, "root") + _check_xml_and_get_text(basic_xml_tree, "root") def test_read_xml_files(tmp_path): @@ -105,22 +105,22 @@ def _load_xml_from_template( ) -> str: from string import Template other_substitutes = { - "study_id": 100, - "series_id": 200, - "image_id": 300, - "proc_id": 3615, + "study_id": 100, + "series_id": 200, + "image_id": 300, + "proc_id": 3615, } temp = Path(f"./data/{template_id}_template.xml").read_text() temp = Template(temp.replace("\n", "")) return temp.safe_substitute( - project=project, modality=modality, - acq_time=acq_time, **other_substitutes + project=project, modality=modality, + acq_time=acq_time, **other_substitutes ) def _write_xml_example( base_path: Path, - template_id: str ="ADNI_123_S_4567", + template_id: str = "ADNI_123_S_4567", suffix: Optional[str] = None, **kwargs ) -> Path: @@ -147,15 +147,15 @@ def expected_image_metadata(template_id): expected = {} if template_id == "ADNI_234_S_5678": expected = { - 'image_proc_pipe': 'Grinder Pipeline', - 'image_proc_id': 3615, - 'image_proc_desc': 'MT1; GradWarp; N3m' + 'image_proc_pipe': 'Grinder Pipeline', + 'image_proc_id': 3615, + 'image_proc_desc': 'MT1; GradWarp; N3m' } elif template_id == "ADNI_345_S_6789": expected = { - 'image_proc_pipe': 'UCSD ADNI Pipeline', - 'image_proc_id': 3615, - 'image_proc_desc': 'MPR; GradWarp; N3; Scaled' + 'image_proc_pipe': 'UCSD ADNI Pipeline', + 'image_proc_id': 3615, + 'image_proc_desc': 'MPR; GradWarp; N3; Scaled' } return expected @@ -164,19 +164,24 @@ def expected_image_metadata(template_id): def test_parsing(tmp_path, template_id, expected_image_metadata): """Test function `_get_root_from_xml_path`.""" from clinica.iotools.converters.adni_to_bids.adni_json import ( - _get_root_from_xml_path, _parse_project, _parse_subject, - _parse_study, _parse_series, _parse_images, + _get_root_from_xml_path, _parse_project, _parse_subject, + _parse_study, _parse_series, _parse_images, ) expected_subject_id = template_id[5:] xml_files = { - "correct": _write_xml_example( - tmp_path, template_id=template_id, suffix="correct"), - "bad_project": _write_xml_example( - tmp_path, project="foo", - template_id=template_id, suffix="bad_project"), - "bad_study": _write_xml_example( - tmp_path, modality="bar", - template_id=template_id, suffix="bad_study"), + "correct": _write_xml_example( + tmp_path, template_id=template_id, suffix="correct" + ), + "bad_project": _write_xml_example( + tmp_path, project="foo", + template_id=template_id, + suffix="bad_project" + ), + "bad_study": _write_xml_example( + tmp_path, modality="bar", + template_id=template_id, + suffix="bad_study" + ), } roots = {k: _get_root_from_xml_path(v) for k, v in xml_files.items()} @@ -231,9 +236,9 @@ def test_parsing(tmp_path, template_id, expected_image_metadata): @pytest.fixture def expected_mprage(template_id): expected = { - "ADNI_123_S_4567": "Accelerated Sagittal MPRAGE", - "ADNI_234_S_5678": "MPRAGE GRAPPA2", - "ADNI_345_S_6789": "MP-RAGE", + "ADNI_123_S_4567": "Accelerated Sagittal MPRAGE", + "ADNI_234_S_5678": "MPRAGE GRAPPA2", + "ADNI_345_S_6789": "MP-RAGE", } return expected[template_id] @@ -309,18 +314,18 @@ def test_add_json_scan_metadata(tmp_path, keep_none): with open(json_path, "w") as fp: json.dump(existing_metadata, fp) new_metadata = { - "PulseSequenceType": None, # Kept or not depending on keep_none - "Manufacturer": "SIEMENS", - "meta_6": "meta6", # Will be filtered out - "meta_2": "foo", # also filtered out + "PulseSequenceType": None, # Kept or not depending on keep_none + "Manufacturer": "SIEMENS", + "meta_6": "meta6", # Will be filtered out + "meta_2": "foo", # also filtered out } _add_json_scan_metadata(json_path, new_metadata, keep_none=keep_none) with open(json_path, "r") as fp: merged = json.load(fp) - expected_keys = set( - ["MRAcquisitionType", "meta_2", "MagneticFieldStrength", - "PulseSequenceType", "Manufacturer"] - ) + expected_keys = set([ + "MRAcquisitionType", "meta_2", "MagneticFieldStrength", + "PulseSequenceType", "Manufacturer" + ]) if not keep_none: expected_keys.remove("PulseSequenceType") assert set(merged.keys()) == expected_keys