Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fetch keep control data #99

Merged
merged 2 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ create_changelog.py
changelog.yml

# example data
kiemat/datasets/_*
kielmat/datasets/_*
ngmt/datasets/_*
132 changes: 73 additions & 59 deletions examples/modules_02_icd.ipynb

Large diffs are not rendered by default.

180 changes: 119 additions & 61 deletions kielmat/datasets/keepcontrol.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,146 @@
import numpy as np
import pandas as pd
import pathlib
import os
from pathlib import Path
import openneuro
from typing import Union, Optional
from kielmat.utils.kielmat_dataclass import KielMATRecording
from kielmat.utils.kielmat_dataclass import REQUIRED_COLUMNS

# Dict of valid tracked points for the Keep Control dataset for each tracking system
VALID_TRACKED_POINTS = {
"omc": [
'l_toe', 'l_heel', 'l_ank', 'l_sk1', 'l_sk2', 'l_sk3', 'l_sk4', 'l_th1', 'l_th2',
'l_th3', 'l_th4', 'r_toe', 'r_heel', 'r_ank', 'r_sk1', 'r_sk2', 'r_sk3', 'r_sk4',
'r_th1', 'r_th2', 'r_th3', 'r_th4', 'l_asis', 'r_asis', 'l_psis', 'r_psis',
'm_ster1', 'm_ster2', 'm_ster3', 'l_sho', 'l_ua', 'l_elbl', 'l_frm', 'l_wrr',
'l_wru', 'l_hand', 'r_sho', 'r_ua', 'r_elbl', 'r_frm', 'r_wrr', 'r_wru', 'r_hand',
'lf_hd', 'rf_hd', 'lb_hd', 'rb_hd', 'start_1', 'start_2', 'end_1', 'end_2'
],
"imu": [
'head', 'sternum', 'left_upper_arm', 'left_fore_arm', 'right_upper_arm',
'right_fore_arm', 'pelvis', 'left_thigh', 'left_shank', 'left_foot',
'right_thigh', 'right_shank', 'right_foot', 'left_ankle', 'right_ankle'
]
}

def fetch_dataset(
dataset_path: str | Path = Path(__file__).parent / "_keepcontrol",
) -> None:
"""Fetch the Keep Control dataset from the OpenNeuro repository.
Args:
progressbar (bool, optional): Whether to display a progressbar. Defaults to True.
dataset_path (str | Path, optional): The path where the dataset is stored. Defaults to Path(__file__).parent/"_keepcontrol".
"""
dataset_path = Path(dataset_path) if isinstance(dataset_path, str) else dataset_path

# Check if target folder exists, if not create it
if not dataset_path.exists():
dataset_path.parent.joinpath("_keepcontrol").mkdir(parents=True, exist_ok=True)

# check if the dataset has already been downloaded (directory is not empty), if not download it
if not any(dataset_path.iterdir()):
# Download the dataset using openneuro-py
openneuro.download(
dataset="ds005258", # this is the example Keep Control dataset on OpenNeuro, maintained by Julius Welzel
target_dir=dataset_path,
)

print(f"Dataset has been downloaded to {dataset_path}")

# print message to user if dataset has already been downloaded
else:
print(f"Dataset has already been downloaded to {dataset_path}")

return

def load_recording(
file_name: pathlib.Path,
tracking_systems: str | list[str],
tracked_points: str | list[str] | dict[str, str] | dict[str, list[str]],
dataset_path: str | Path = Path(__file__).parent / "_keepcontrol",
id: str = "pp001",
task: str = "walkSlow",
tracking_systems: Union[str, list[str]] = ["imu", "omc"],
tracked_points: Optional[Union[None, str, list[str]]] = None,
):
"""
Load a recording from the Keep Control validation study.

Args:
file_name (pathlib.Path): The absolute or relative path to the data file.
tracking_systems (str or list of str) : A string or list of strings of tracking systems for which data are to be returned.
tracked_points (str or list of str or dict[str, str] or dict[str, list of str]) :
Defines for which tracked points data are to be returned.
If a string or list of strings is provided, then these will be mapped to each requested tracking system.
If a dictionary is provided, it should map each tracking system to either a single tracked point or a list of tracked points.

dataset_path (str or Path, optional): The path to the dataset. Defaults to the "_keepcontrol" directory in the same directory as this file.
id (str): The ID of the recording.
tracking_systems (str or list of str): A string or list of strings representing the tracking systems for which data should be returned.
tracked_points (None, str or list of str, optional): The tracked points of interest. If None, all tracked points will be returned. Defaults to None.
Returns:
KielMATRecording : An instance of the KielMATRecording dataclass containing the loaded data and channels.
KielMATRecording: An instance of the KielMATRecording dataclass containing the loaded data and channels.
"""

# Fetch the dataset if it does not exist
if not dataset_path.exists():
fetch_dataset()

# check if id contains sub or sub- substring, if so replace it with ''
id = id.replace("sub", "").replace("-", "")

# check if task contains task or task- substring, if so replace it with ''
task = task.replace("task", "").replace("-", "")


# Put tracking systems in a list
if isinstance(tracking_systems, str):
tracking_systems = [tracking_systems]

# check if tracked points has been specified, if not use all tracked points
if tracked_points is None:
tracked_points = {tracksys: VALID_TRACKED_POINTS[tracksys] for tracksys in tracking_systems}
# Tracked points will be a dictionary mapping
# each tracking system to a list of tracked points of interest
if isinstance(tracked_points, str):
tracked_points = [tracked_points]
if isinstance(tracked_points, list):
tracked_points = {tracksys: tracked_points for tracksys in tracking_systems}
for k, v in tracked_points.items():
if isinstance(v, str):
tracked_points[k] = [v]
# use the VALID_TRACKED_POINTS dictionary to get the valid tracked points for each tracking system
# return error of some tracked_points are not valid
# print which of the specified tracked points are not valid
for tracksys in tracking_systems:
if not all(tracked_point in VALID_TRACKED_POINTS[tracksys] for tracked_point in tracked_points[tracksys]):
print(f"Invalid tracked points for tracking system {tracksys}.")
print(f"Valid tracked points are: {VALID_TRACKED_POINTS[tracksys]}")
invalid_points = [tracked_point for tracked_point in tracked_points[tracksys] if tracked_point not in VALID_TRACKED_POINTS[tracksys]]
print(f"Invalid tracked points are: {invalid_points}")
return

# From the file_name, extract the tracking system
search_str = "_tracksys-"
idx_from = str(file_name).find(search_str) + len(search_str)
idx_to = idx_from + str(file_name)[idx_from:].find("_")
current_tracksys = str(file_name)[idx_from:idx_to]

# Initialize the data and channels dictionaroes
# Load data and channels for each tracking system
data_dict, channels_dict = {}, {}
for tracksys in tracking_systems:
# Set current filename
current_file_name = str(file_name).replace(
f"{search_str}{current_tracksys}", f"{search_str}{tracksys}"
)
if os.path.isfile(current_file_name):
# Read the data and channels info into a pandas DataFrame
df_data = pd.read_csv(current_file_name, header=0, sep="\t")
df_channels = pd.read_csv(
current_file_name.replace("_motion.tsv", "_channels.tsv"),
header=0,
sep="\t",
)

# Now select only for the tracked points of interest
df_data = df_data.loc[
:,
[
col
for col in df_data.columns
if any(
[
tracked_point in col
for tracked_point in tracked_points[tracksys]
]
)
],
]
df_channels = df_channels[
(df_channels["tracked_point"].isin(tracked_points[tracksys]))
]

# Put data and channels in output dictionaries
col_names = [c for c in REQUIRED_COLUMNS] + [
c for c in df_channels.columns if c not in REQUIRED_COLUMNS
]
data_dict[tracksys] = df_data
channels_dict[tracksys] = df_channels[col_names]
return KielMATRecording(data=data_dict, channels=channels_dict)
# Find avaliable files for the give ID and task and tracking system
file_name = list(dataset_path.glob(f"sub-{id}/motion/sub-{id}_task-{task}_tracksys-{tracksys}_*motion.tsv"))
# check if file exists, if not print message to user and return
if not file_name:
print(f"No files found for ID {id}, task {task}, and tracking system {tracksys}.")
return
# check if multiple files are found, if so print message to user and return
if len(file_name) > 1:
print(f"Multiple files found for ID {id}, task {task}, and tracking system {tracksys}.")
return

# Load the data and channels for the tracking system
df_data = pd.read_csv(file_name[0], sep="\t")
df_channels = pd.read_csv(file_name[0].parent / f"sub-{id}_task-{task}_tracksys-{tracksys}_channels.tsv", sep="\t")

# filter the data and channels to only include the tracked points of interest
df_channels = df_channels[df_channels["tracked_point"].isin(tracked_points[tracksys])]

# only keep df_data columns that are in df_channels
col_names = df_channels["name"].values
df_data = df_data[col_names]

# transform the data and channels into a dictionary for the KielMATRecording dataclass
data_dict[tracksys] = df_data
channels_dict[tracksys] = df_channels

# construct data class
recording = KielMATRecording(data=data_dict, channels=channels_dict)

# add information about the recording to the data class
recording.add_info("Subject", id)
recording.add_info("Task", task)

return recording