Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functions to determine sample ranges of signals in multi-segment records #403

Merged
merged 4 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions tests/test_multi_record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import wfdb


class TestMultiRecordRanges:
"""
Test logic that deduces relevant segments/ranges for given signals.
"""

def test_contained_ranges_simple_cases(self):
record = wfdb.MultiRecord(
segments=[
wfdb.Record(sig_name=["I", "II"], sig_len=5),
wfdb.Record(sig_name=["I", "III"], sig_len=10),
],
)

assert record.contained_ranges("I") == [(0, 15)]
assert record.contained_ranges("II") == [(0, 5)]
assert record.contained_ranges("III") == [(5, 15)]

def test_contained_ranges_variable_layout(self):
record = wfdb.rdheader(
"sample-data/multi-segment/s00001/s00001-2896-10-10-00-31",
rd_segments=True,
)

assert record.contained_ranges("II") == [
(3261, 10136),
(4610865, 10370865),
(10528365, 14518365),
]
assert record.contained_ranges("V") == [
(3261, 918261),
(920865, 4438365),
(4610865, 10370865),
(10528365, 14518365),
]
assert record.contained_ranges("MCL1") == [
(10136, 918261),
(920865, 4438365),
]
assert record.contained_ranges("ABP") == [
(14428365, 14450865),
(14458365, 14495865),
]

def test_contained_ranges_fixed_layout(self):
record = wfdb.rdheader(
"sample-data/multi-segment/041s/041s",
rd_segments=True,
)

for sig_name in record.sig_name:
assert record.contained_ranges(sig_name) == [(0, 2000)]

def test_contained_combined_ranges_simple_cases(self):
record = wfdb.MultiRecord(
segments=[
wfdb.Record(sig_name=["I", "II", "V"], sig_len=5),
wfdb.Record(sig_name=["I", "III", "V"], sig_len=10),
wfdb.Record(sig_name=["I", "II", "V"], sig_len=20),
],
)

assert record.contained_combined_ranges(["I", "II"]) == [
(0, 5),
(15, 35),
]
assert record.contained_combined_ranges(["II", "III"]) == []
assert record.contained_combined_ranges(["I", "III"]) == [(5, 15)]
assert record.contained_combined_ranges(["I", "II", "V"]) == [
(0, 5),
(15, 35),
]

def test_contained_combined_ranges_variable_layout(self):
record = wfdb.rdheader(
"sample-data/multi-segment/s00001/s00001-2896-10-10-00-31",
rd_segments=True,
)

assert record.contained_combined_ranges(["II", "V"]) == [
(3261, 10136),
(4610865, 10370865),
(10528365, 14518365),
]
assert record.contained_combined_ranges(["II", "MCL1"]) == []
assert record.contained_combined_ranges(["II", "ABP"]) == [
(14428365, 14450865),
(14458365, 14495865),
]
assert record.contained_combined_ranges(["II", "V", "ABP"]) == [
(14428365, 14450865),
(14458365, 14495865),
]
assert (
record.contained_combined_ranges(["II", "V", "MCL1", "ABP"]) == []
)

def test_contained_combined_ranges_variable_layout(self):
record = wfdb.rdheader(
"sample-data/multi-segment/041s/041s",
rd_segments=True,
)

for sig_1 in record.sig_name:
for sig_2 in record.sig_name:
if sig_1 == sig_2:
continue

assert record.contained_combined_ranges([sig_1, sig_2]) == [
(0, 2000)
]
112 changes: 110 additions & 2 deletions wfdb/io/_header.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import re
from typing import List, Tuple
from typing import Collection, List, Tuple

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -858,7 +858,7 @@ def get_sig_name(self):
"""
if self.segments is None:
raise Exception(
"The MultiRecord's segments must be read in before this method is called. ie. Call rdheader() with rsegment_fieldsments=True"
"The MultiRecord's segments must be read in before this method is called. ie. Call rdheader() with rd_segments=True"
)

if self.layout == "fixed":
Expand All @@ -871,6 +871,114 @@ def get_sig_name(self):

return sig_name

def contained_ranges(self, sig_name: str) -> List[Tuple[int, int]]:
"""
Given a signal name, return the sample ranges that contain signal values,
relative to the start of the full record. Does not account for NaNs/missing
values.

This function is mainly useful for variable layout records, but can also be
used for fixed-layout records. Only works if the headers from the individual
segment records have already been read in.

Parameters
----------
sig_name : str
The name of the signal to query.

Returns
-------
ranges : List[Tuple[int, int]]
Tuple pairs which specify thee sample ranges in which the signal is contained.
The second value of each tuple pair will be one beyond the signal index.
eg. A length 1000 signal would generate a tuple of: (0, 1000), allowing
selection using signal[0:1000].

"""
if self.segments is None:
raise Exception(
"The MultiRecord's segments must be read in before this method is called. ie. Call rdheader() with rd_segments=True"
)
ranges = []
seg_start = 0

range_start = None

# TODO: Add shortcut for fixed-layout records

# Cannot process segments only because missing segments are None
# and do not contain length information.
for seg_num in range(self.n_seg):
seg_len = self.seg_len[seg_num]
segment = self.segments[seg_num]

if seg_len == 0:
continue

# Open signal range
if (
range_start is None
and segment is not None
and sig_name in segment.sig_name
):
range_start = seg_start
# Close signal range
elif range_start is not None and (
segment is None or sig_name not in segment.sig_name
):
ranges.append((range_start, seg_start))
range_start = None

seg_start += seg_len

# Account for final segment
if range_start is not None:
ranges.append((range_start, seg_start))

return ranges

def contained_combined_ranges(
self,
sig_names: Collection[str],
) -> List[Tuple[int, int]]:
"""
Given a collection of signal name, return the sample ranges that
contain all of the specified signals, relative to the start of the
full record. Does not account for NaNs/missing values.

This function is mainly useful for variable layout records, but can also be
used for fixed-layout records. Only works if the headers from the individual
segment records have already been read in.

Parameters
----------
sig_names : List[str]
The names of the signals to query.

Returns
-------
ranges : List[Tuple[int, int]]
Tuple pairs which specify thee sample ranges in which the signal is contained.
The second value of each tuple pair will be one beyond the signal index.
eg. A length 1000 signal would generate a tuple of: (0, 1000), allowing
selection using signal[0:1000].

"""
# TODO: Add shortcut for fixed-layout records

if len(sig_names) == 0:
return []

combined_ranges = self.contained_ranges(sig_names[0])

if len(sig_names) > 1:
for name in sig_names[1:]:
combined_ranges = util.overlapping_ranges(
combined_ranges, self.contained_ranges(name)
)

return combined_ranges


def wfdb_strptime(time_string: str) -> datetime.time:
"""
Expand Down
26 changes: 13 additions & 13 deletions wfdb/io/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,8 +1040,8 @@ class MultiRecord(BaseRecord, _header.MultiHeaderMixin):
`datetime.combine(base_date, base_time)`.
seg_name : str, optional
The name of the segment.
seg_len : int, optional
The length of the segment.
seg_len : List[int], optional
The length of each segment.
comments : list, optional
A list of string comments to be written to the header file.
sig_name : str, optional
Expand Down Expand Up @@ -1104,6 +1104,11 @@ def __init__(
self.seg_len = seg_len
self.sig_segments = sig_segments

if segments:
self.n_seg = len(segments)
if not seg_len:
self.seg_len = [segment.sig_len for segment in segments]

def wrsamp(self, write_dir=""):
"""
Write a multi-segment header, along with headers and dat files
Expand Down Expand Up @@ -1144,33 +1149,28 @@ def _check_segment_cohesion(self):
if self.n_seg != len(self.segments):
raise ValueError("Length of segments must match the 'n_seg' field")

for i in range(n_seg):
s = self.segments[i]
for seg_num, segment in enumerate(self.segments):

# If segment 0 is a layout specification record, check that its file names are all == '~''
if i == 0 and self.seg_len[0] == 0:
for file_name in s.file_name:
if seg_num == 0 and self.seg_len[0] == 0:
for file_name in segment.file_name:
if file_name != "~":
raise ValueError(
"Layout specification records must have all file_names named '~'"
)

# Sampling frequencies must all match the one in the master header
if s.fs != self.fs:
if segment.fs != self.fs:
raise ValueError(
"The 'fs' in each segment must match the overall record's 'fs'"
)

# Check the signal length of the segment against the corresponding seg_len field
if s.sig_len != self.seg_len[i]:
if segment.sig_len != self.seg_len[seg_num]:
raise ValueError(
"The signal length of segment "
+ str(i)
+ " does not match the corresponding segment length"
f"The signal length of segment {seg_num} does not match the corresponding segment length"
)

totalsig_len = totalsig_len + getattr(s, "sig_len")

# No need to check the sum of sig_lens from each segment object against sig_len
# Already effectively done it when checking sum(seg_len) against sig_len

Expand Down
22 changes: 21 additions & 1 deletion wfdb/io/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import math
import os

from typing import Sequence
from typing import Sequence, Tuple


def lines_to_file(file_name: str, write_dir: str, lines: Sequence[str]):
Expand Down Expand Up @@ -99,3 +99,23 @@ def upround(x, base):

"""
return base * math.ceil(float(x) / base)


def overlapping_ranges(
ranges_1: Tuple[int, int], ranges_2: Tuple[int, int]
) -> Tuple[int, int]:
"""
Given two collections of integer ranges, return a list of ranges
in which both input inputs overlap.

From: https://stackoverflow.com/q/40367461

Slightly modified so that if the end of one range exactly equals
the start of the other range, no overlap would be returned.
"""
return [
(max(first[0], second[0]), min(first[1], second[1]))
for first in ranges_1
for second in ranges_2
if max(first[0], second[0]) < min(first[1], second[1])
]