Skip to content

Commit

Permalink
[FIX] example with load KeepControl data from OpenNeuro
Browse files Browse the repository at this point in the history
  • Loading branch information
JuliusWelzel committed Jul 18, 2024
1 parent 9179115 commit 0586e6b
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 258 deletions.
256 changes: 65 additions & 191 deletions examples/modules_02_icd_keepControl.ipynb

Large diffs are not rendered by default.

151 changes: 84 additions & 67 deletions ngmt/datasets/keepcontrol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,27 @@
import pandas as pd
from pathlib import Path
import openneuro
import os
from typing import Union, Optional
from ngmt.utils.ngmt_dataclass import NGMTRecording
from ngmt.utils.ngmt_dataclass import REQUIRED_COLUMNS

# Dict of valid tracked points for the Keep Control dataset for each tracking system
VALID_TRACKED_POINTS = {
"omc": [
'l_toe', 'l_heel', 'l_ank', 'l_sk1', 'l_sk2', 'l_sk3', 'l_sk4', 'l_th1', 'l_th2',
'l_th3', 'l_th4', 'r_toe', 'r_heel', 'r_ank', 'r_sk1', 'r_sk2', 'r_sk3', 'r_sk4',
'r_th1', 'r_th2', 'r_th3', 'r_th4', 'l_asis', 'r_asis', 'l_psis', 'r_psis',
'm_ster1', 'm_ster2', 'm_ster3', 'l_sho', 'l_ua', 'l_elbl', 'l_frm', 'l_wrr',
'l_wru', 'l_hand', 'r_sho', 'r_ua', 'r_elbl', 'r_frm', 'r_wrr', 'r_wru', 'r_hand',
'lf_hd', 'rf_hd', 'lb_hd', 'rb_hd', 'start_1', 'start_2', 'end_1', 'end_2'
],
"imu": [
'head', 'sternum', 'left_upper_arm', 'left_fore_arm', 'right_upper_arm',
'right_fore_arm', 'pelvis', 'left_thigh', 'left_shank', 'left_foot',
'right_thigh', 'right_shank', 'right_foot', 'left_ankle', 'right_ankle'
]
}

def fetch_dataset(
dataset_path: str | Path = Path(__file__).parent / "_keepcontrol",
) -> None:
Expand Down Expand Up @@ -38,95 +55,95 @@ def fetch_dataset(
return

def load_recording(
file_name: Path,
tracking_systems: str | list[str],
tracked_points: str | list[str] | dict[str, str] | dict[str, list[str]],
dataset_path: str | Path = Path(__file__).parent / "_keepcontrol",
id: str = "pp001",
task: str = "walkSlow",
tracking_systems: Union[str, list[str]] = ["imu", "omc"],
tracked_points: Optional[Union[None, str, list[str]]] = None,
):
"""
Load a recording from the Keep Control validation study.
Args:
file_name (pathlib.Path): The absolute or relative path to the data file.
tracking_systems (str or list of str) : A string or list of strings of tracking systems for which data are to be returned.
tracked_points (str or list of str or dict[str, str] or dict[str, list of str]) :
Defines for which tracked points data are to be returned.
If a string or list of strings is provided, then these will be mapped to each requested tracking system.
If a dictionary is provided, it should map each tracking system to either a single tracked point or a list of tracked points.
dataset_path (str or Path, optional): The path to the dataset. Defaults to the "_keepcontrol" directory in the same directory as this file.
id (str): The ID of the recording.
tracking_systems (str or list of str): A string or list of strings representing the tracking systems for which data should be returned.
tracked_points (None, str or list of str, optional): The tracked points of interest. If None, all tracked points will be returned. Defaults to None.
Returns:
NGMTRecording : An instance of the NGMTRecording dataclass containing the loaded data and channels.
NGMTRecording: An instance of the NGMTRecording dataclass containing the loaded data and channels.
"""

# Fetch the dataset if it does not exist
progressbar = False if not progressbar else progressbar
file_path = Path(dataset_path) / cohort / file_name
if not file_path.exists():
if not dataset_path.exists():
fetch_dataset()

# check if id contains sub or sub- substring, if so replace it with ''
id = id.replace("sub", "").replace("-", "")

# check if task contains task or task- substring, if so replace it with ''
task = task.replace("task", "").replace("-", "")


# Put tracking systems in a list
if isinstance(tracking_systems, str):
tracking_systems = [tracking_systems]

# check if tracked points has been specified, if not use all tracked points
if tracked_points is None:
tracked_points = {tracksys: VALID_TRACKED_POINTS[tracksys] for tracksys in tracking_systems}
# Tracked points will be a dictionary mapping
# each tracking system to a list of tracked points of interest
if isinstance(tracked_points, str):
tracked_points = [tracked_points]
if isinstance(tracked_points, list):
tracked_points = {tracksys: tracked_points for tracksys in tracking_systems}
for k, v in tracked_points.items():
if isinstance(v, str):
tracked_points[k] = [v]
# use the VALID_TRACKED_POINTS dictionary to get the valid tracked points for each tracking system
# return error of some tracked_points are not valid
# print which of the specified tracked points are not valid
for tracksys in tracking_systems:
if not all(tracked_point in VALID_TRACKED_POINTS[tracksys] for tracked_point in tracked_points[tracksys]):
print(f"Invalid tracked points for tracking system {tracksys}.")
print(f"Valid tracked points are: {VALID_TRACKED_POINTS[tracksys]}")
invalid_points = [tracked_point for tracked_point in tracked_points[tracksys] if tracked_point not in VALID_TRACKED_POINTS[tracksys]]
print(f"Invalid tracked points are: {invalid_points}")
return

# From the file_name, extract the tracking system
search_str = "_tracksys-"
idx_from = str(file_name).find(search_str) + len(search_str)
idx_to = idx_from + str(file_name)[idx_from:].find("_")
current_tracksys = str(file_name)[idx_from:idx_to]

# Initialize the data and channels dictionaroes
# Load data and channels for each tracking system
data_dict, channels_dict = {}, {}
for tracksys in tracking_systems:
# Set current filename
current_file_name = str(file_name).replace(
f"{search_str}{current_tracksys}", f"{search_str}{tracksys}"
)
if os.path.isfile(current_file_name):
# Read the data and channels info into a pandas DataFrame
df_data = pd.read_csv(current_file_name, header=0, sep="\t")
df_channels = pd.read_csv(
current_file_name.replace("_motion.tsv", "_channels.tsv"),
header=0,
sep="\t",
)

# Now select only for the tracked points of interest
df_data = df_data.loc[
:,
[
col
for col in df_data.columns
if any(
[
tracked_point in col
for tracked_point in tracked_points[tracksys]
]
)
],
]
if df_data.empty:
print(f"Specified tracked point not found in the tracking system {tracksys}.")

df_channels = df_channels[
df_channels["tracked_point"].str.contains("|".join(tracked_points[tracksys]))
]

# Add sampling frequency to the channels DataFrame
df_channels["sampling_frequency"] = 200

# Put data and channels in output dictionaries
col_names = [c for c in REQUIRED_COLUMNS] + [
c for c in df_channels.columns if c not in REQUIRED_COLUMNS
]
data_dict[tracksys] = df_data
channels_dict[tracksys] = df_channels[col_names]
return NGMTRecording(data=data_dict, channels=channels_dict)
# Find avaliable files for the give ID and task and tracking system
file_name = list(dataset_path.glob(f"sub-{id}/motion/sub-{id}_task-{task}_tracksys-{tracksys}_*motion.tsv"))
# check if file exists, if not print message to user and return
if not file_name:
print(f"No files found for ID {id}, task {task}, and tracking system {tracksys}.")
return
# check if multiple files are found, if so print message to user and return
if len(file_name) > 1:
print(f"Multiple files found for ID {id}, task {task}, and tracking system {tracksys}.")
return

# Load the data and channels for the tracking system
df_data = pd.read_csv(file_name[0], sep="\t")
df_channels = pd.read_csv(file_name[0].parent / f"sub-{id}_task-{task}_tracksys-{tracksys}_channels.tsv", sep="\t")

# filter the data and channels to only include the tracked points of interest
df_channels = df_channels[df_channels["tracked_point"].isin(tracked_points[tracksys])]

# only keep df_data columns that are in df_channels
col_names = df_channels["name"].values
df_data = df_data[col_names]

# transform the data and channels into a dictionary for the NGMTRecording dataclass
data_dict[tracksys] = df_data
channels_dict[tracksys] = df_channels

# construct data class
recording = NGMTRecording(data=data_dict, channels=channels_dict)

# add information about the recording to the data class
recording.add_info("Subject", id)
recording.add_info("Task", task)

return recording

0 comments on commit 0586e6b

Please sign in to comment.