Apply Black Formatting
github-actions committed Jul 19, 2024
1 parent e8ed594 commit 06b20b8
Showing 2 changed files with 122 additions and 25 deletions.
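Note: this commit is an automated run of the Black code formatter (the committer is github-actions, so a CI job presumably applies it). A minimal sketch of the normalizations visible in the hunks below (double quotes, over-long literals exploded to one element per line, magic trailing commas), assuming only that the black package is installed:

import black

# A compact, single-quoted literal in the style of the pre-commit code.
src = (
    "VALID_TRACKED_POINTS = {'imu': ['head', 'sternum', 'left_upper_arm', "
    "'right_upper_arm', 'pelvis', 'left_ankle', 'right_ankle']}\n"
)

# Black rewrites single quotes as double quotes; lines over the default
# 88-character limit are exploded, one element per line, with a trailing comma.
print(black.format_str(src, mode=black.Mode(line_length=88)))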
126 changes: 103 additions & 23 deletions kielmat/datasets/keepcontrol.py
@@ -9,20 +9,78 @@
# Dict of valid tracked points for the Keep Control dataset for each tracking system
VALID_TRACKED_POINTS = {
"omc": [
-        'l_toe', 'l_heel', 'l_ank', 'l_sk1', 'l_sk2', 'l_sk3', 'l_sk4', 'l_th1', 'l_th2',
-        'l_th3', 'l_th4', 'r_toe', 'r_heel', 'r_ank', 'r_sk1', 'r_sk2', 'r_sk3', 'r_sk4',
-        'r_th1', 'r_th2', 'r_th3', 'r_th4', 'l_asis', 'r_asis', 'l_psis', 'r_psis',
-        'm_ster1', 'm_ster2', 'm_ster3', 'l_sho', 'l_ua', 'l_elbl', 'l_frm', 'l_wrr',
-        'l_wru', 'l_hand', 'r_sho', 'r_ua', 'r_elbl', 'r_frm', 'r_wrr', 'r_wru', 'r_hand',
-        'lf_hd', 'rf_hd', 'lb_hd', 'rb_hd', 'start_1', 'start_2', 'end_1', 'end_2'
+        "l_toe",
+        "l_heel",
+        "l_ank",
+        "l_sk1",
+        "l_sk2",
+        "l_sk3",
+        "l_sk4",
+        "l_th1",
+        "l_th2",
+        "l_th3",
+        "l_th4",
+        "r_toe",
+        "r_heel",
+        "r_ank",
+        "r_sk1",
+        "r_sk2",
+        "r_sk3",
+        "r_sk4",
+        "r_th1",
+        "r_th2",
+        "r_th3",
+        "r_th4",
+        "l_asis",
+        "r_asis",
+        "l_psis",
+        "r_psis",
+        "m_ster1",
+        "m_ster2",
+        "m_ster3",
+        "l_sho",
+        "l_ua",
+        "l_elbl",
+        "l_frm",
+        "l_wrr",
+        "l_wru",
+        "l_hand",
+        "r_sho",
+        "r_ua",
+        "r_elbl",
+        "r_frm",
+        "r_wrr",
+        "r_wru",
+        "r_hand",
+        "lf_hd",
+        "rf_hd",
+        "lb_hd",
+        "rb_hd",
+        "start_1",
+        "start_2",
+        "end_1",
+        "end_2",
],
"imu": [
-        'head', 'sternum', 'left_upper_arm', 'left_fore_arm', 'right_upper_arm',
-        'right_fore_arm', 'pelvis', 'left_thigh', 'left_shank', 'left_foot',
-        'right_thigh', 'right_shank', 'right_foot', 'left_ankle', 'right_ankle'
-    ]
+        "head",
+        "sternum",
+        "left_upper_arm",
+        "left_fore_arm",
+        "right_upper_arm",
+        "right_fore_arm",
+        "pelvis",
+        "left_thigh",
+        "left_shank",
+        "left_foot",
+        "right_thigh",
+        "right_shank",
+        "right_foot",
+        "left_ankle",
+        "right_ankle",
+    ],
}

+
def fetch_dataset(
dataset_path: str | Path = Path(__file__).parent / "_keepcontrol",
) -> None:
@@ -41,7 +99,7 @@ def fetch_dataset(
if not any(dataset_path.iterdir()):
# Download the dataset using openneuro-py
openneuro.download(
dataset="ds005258", # this is the example Keep Control dataset on OpenNeuro, maintained by Julius Welzel
dataset="ds005258", # this is the example Keep Control dataset on OpenNeuro, maintained by Julius Welzel
target_dir=dataset_path,
)

@@ -53,6 +111,7 @@ def fetch_dataset(

return

+
def load_recording(
dataset_path: str | Path = Path(__file__).parent / "_keepcontrol",
id: str = "pp001",
@@ -81,14 +140,15 @@ def load_recording(
# check if task contains task or task- substring, if so replace it with ''
task = task.replace("task", "").replace("-", "")

-
# Put tracking systems in a list
if isinstance(tracking_systems, str):
tracking_systems = [tracking_systems]

# check if tracked points has been specified, if not use all tracked points
if tracked_points is None:
-        tracked_points = {tracksys: VALID_TRACKED_POINTS[tracksys] for tracksys in tracking_systems}
+        tracked_points = {
+            tracksys: VALID_TRACKED_POINTS[tracksys] for tracksys in tracking_systems
+        }
# Tracked points will be a dictionary mapping
# each tracking system to a list of tracked points of interest
if isinstance(tracked_points, str):
@@ -99,38 +159,58 @@ def load_recording(
    # return an error if some tracked_points are not valid
# print which of the specified tracked points are not valid
for tracksys in tracking_systems:
-        if not all(tracked_point in VALID_TRACKED_POINTS[tracksys] for tracked_point in tracked_points[tracksys]):
+        if not all(
+            tracked_point in VALID_TRACKED_POINTS[tracksys]
+            for tracked_point in tracked_points[tracksys]
+        ):
print(f"Invalid tracked points for tracking system {tracksys}.")
print(f"Valid tracked points are: {VALID_TRACKED_POINTS[tracksys]}")
-            invalid_points = [tracked_point for tracked_point in tracked_points[tracksys] if tracked_point not in VALID_TRACKED_POINTS[tracksys]]
+            invalid_points = [
+                tracked_point
+                for tracked_point in tracked_points[tracksys]
+                if tracked_point not in VALID_TRACKED_POINTS[tracksys]
+            ]
print(f"Invalid tracked points are: {invalid_points}")
return

-
# Load data and channels for each tracking system
data_dict, channels_dict = {}, {}
for tracksys in tracking_systems:
        # Find available files for the given ID, task, and tracking system
-        file_name = list(dataset_path.glob(f"sub-{id}/motion/sub-{id}_task-{task}_tracksys-{tracksys}_*motion.tsv"))
+        file_name = list(
+            dataset_path.glob(
+                f"sub-{id}/motion/sub-{id}_task-{task}_tracksys-{tracksys}_*motion.tsv"
+            )
+        )
# check if file exists, if not print message to user and return
if not file_name:
print(f"No files found for ID {id}, task {task}, and tracking system {tracksys}.")
print(
f"No files found for ID {id}, task {task}, and tracking system {tracksys}."
)
return
# check if multiple files are found, if so print message to user and return
if len(file_name) > 1:
print(f"Multiple files found for ID {id}, task {task}, and tracking system {tracksys}.")
print(
f"Multiple files found for ID {id}, task {task}, and tracking system {tracksys}."
)
return

# Load the data and channels for the tracking system
df_data = pd.read_csv(file_name[0], sep="\t")
-        df_channels = pd.read_csv(file_name[0].parent / f"sub-{id}_task-{task}_tracksys-{tracksys}_channels.tsv", sep="\t")
+        df_channels = pd.read_csv(
+            file_name[0].parent
+            / f"sub-{id}_task-{task}_tracksys-{tracksys}_channels.tsv",
+            sep="\t",
+        )

# filter the data and channels to only include the tracked points of interest
-        df_channels = df_channels[df_channels["tracked_point"].isin(tracked_points[tracksys])]
+        df_channels = df_channels[
+            df_channels["tracked_point"].isin(tracked_points[tracksys])
+        ]

# only keep df_data columns that are in df_channels
col_names = df_channels["name"].values
-        df_data = df_data[col_names]
+        df_data = df_data[col_names]

# transform the data and channels into a dictionary for the KielMATRecording dataclass
data_dict[tracksys] = df_data
@@ -143,4 +223,4 @@ def load_recording(
recording.add_info("Subject", id)
recording.add_info("Task", task)

-    return recording
+    return recording
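For orientation, a hedged usage sketch of the two functions this file defines, using names taken from the diff above; the task label and the data attribute on the returned KielMATRecording are assumptions, not confirmed by this commit:

from kielmat.datasets.keepcontrol import fetch_dataset, load_recording

# Download the example Keep Control dataset (OpenNeuro ds005258) into the
# package's _keepcontrol folder; a no-op if the data are already present.
fetch_dataset()

# Load only pelvis and sternum IMU channels for participant pp001.
# "walkSlow" is an illustrative task label, not verified against the dataset.
recording = load_recording(
    id="pp001",
    task="walkSlow",
    tracking_systems="imu",
    tracked_points={"imu": ["pelvis", "sternum"]},
)
print(recording.data["imu"].head())  # assumes the dataclass exposes a data dict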
21 changes: 19 additions & 2 deletions kielmat/utils/kielmat_dataclass.py
@@ -31,7 +31,18 @@
}

# See: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/motion.html#restricted-keyword-list-for-channel-component
VALID_COMPONENT_TYPES = {"x", "y", "z", "quat_x", "quat_y", "quat_z", "quat_w", "n/a", "NaN", "nan"}
VALID_COMPONENT_TYPES = {
"x",
"y",
"z",
"quat_x",
"quat_y",
"quat_z",
"quat_w",
"n/a",
"NaN",
"nan",
}

# See https://bids-specification.readthedocs.io/en/stable/modality-agnostic-files.html#participants-file
VALID_INFO_KEYS = {
@@ -95,7 +106,13 @@ def validate_channels(self):
raise TypeError(
f"Column 'name' in '{system_name}' must be of type string."
)
-            invalid_components = set([item for item in df["component"] if item not in VALID_COMPONENT_TYPES and not pd.isna(item)])
+            invalid_components = set(
+                [
+                    item
+                    for item in df["component"]
+                    if item not in VALID_COMPONENT_TYPES and not pd.isna(item)
+                ]
+            )
if invalid_components:
raise ValueError(
f"Column 'component' in '{system_name}' contains invalid values: {invalid_components}."
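To illustrate the reformatted check (not part of this commit): the comprehension collects any component values that fall outside the BIDS restricted keyword list while skipping missing entries. A self-contained sketch with a hypothetical channels table:

import pandas as pd

VALID_COMPONENT_TYPES = {"x", "y", "z", "quat_x", "quat_y", "quat_z", "quat_w", "n/a", "NaN", "nan"}

# Hypothetical channels table: "roll" is not a valid BIDS motion component.
df = pd.DataFrame(
    {
        "name": ["pelvis_ACC_x", "pelvis_GYR_roll", "pelvis_MAG_z"],
        "component": ["x", "roll", None],
    }
)

invalid_components = set(
    item
    for item in df["component"]
    if item not in VALID_COMPONENT_TYPES and not pd.isna(item)
)
print(invalid_components)  # {'roll'}; the None entry is skipped by pd.isna()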
