Skip to content

Commit

Permalink
Add functions to determine sample ranges of signals in multi-segment …
Browse files Browse the repository at this point in the history
…records (#403)

* Add contained_ranges method to calculate sample ranges that contain a signal, in multi-segment records

* Add contained_combined_ranges function

* Add tests for new ranges functions

* Fix logic to account for empty segments and derive some fields in MultiRecord constructor
  • Loading branch information
cx1111 authored Jun 30, 2022
1 parent 14df878 commit a183c4d
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 16 deletions.
113 changes: 113 additions & 0 deletions tests/test_multi_record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import wfdb


class TestMultiRecordRanges:
"""
Test logic that deduces relevant segments/ranges for given signals.
"""

def test_contained_ranges_simple_cases(self):
record = wfdb.MultiRecord(
segments=[
wfdb.Record(sig_name=["I", "II"], sig_len=5),
wfdb.Record(sig_name=["I", "III"], sig_len=10),
],
)

assert record.contained_ranges("I") == [(0, 15)]
assert record.contained_ranges("II") == [(0, 5)]
assert record.contained_ranges("III") == [(5, 15)]

def test_contained_ranges_variable_layout(self):
record = wfdb.rdheader(
"sample-data/multi-segment/s00001/s00001-2896-10-10-00-31",
rd_segments=True,
)

assert record.contained_ranges("II") == [
(3261, 10136),
(4610865, 10370865),
(10528365, 14518365),
]
assert record.contained_ranges("V") == [
(3261, 918261),
(920865, 4438365),
(4610865, 10370865),
(10528365, 14518365),
]
assert record.contained_ranges("MCL1") == [
(10136, 918261),
(920865, 4438365),
]
assert record.contained_ranges("ABP") == [
(14428365, 14450865),
(14458365, 14495865),
]

def test_contained_ranges_fixed_layout(self):
record = wfdb.rdheader(
"sample-data/multi-segment/041s/041s",
rd_segments=True,
)

for sig_name in record.sig_name:
assert record.contained_ranges(sig_name) == [(0, 2000)]

def test_contained_combined_ranges_simple_cases(self):
record = wfdb.MultiRecord(
segments=[
wfdb.Record(sig_name=["I", "II", "V"], sig_len=5),
wfdb.Record(sig_name=["I", "III", "V"], sig_len=10),
wfdb.Record(sig_name=["I", "II", "V"], sig_len=20),
],
)

assert record.contained_combined_ranges(["I", "II"]) == [
(0, 5),
(15, 35),
]
assert record.contained_combined_ranges(["II", "III"]) == []
assert record.contained_combined_ranges(["I", "III"]) == [(5, 15)]
assert record.contained_combined_ranges(["I", "II", "V"]) == [
(0, 5),
(15, 35),
]

def test_contained_combined_ranges_variable_layout(self):
record = wfdb.rdheader(
"sample-data/multi-segment/s00001/s00001-2896-10-10-00-31",
rd_segments=True,
)

assert record.contained_combined_ranges(["II", "V"]) == [
(3261, 10136),
(4610865, 10370865),
(10528365, 14518365),
]
assert record.contained_combined_ranges(["II", "MCL1"]) == []
assert record.contained_combined_ranges(["II", "ABP"]) == [
(14428365, 14450865),
(14458365, 14495865),
]
assert record.contained_combined_ranges(["II", "V", "ABP"]) == [
(14428365, 14450865),
(14458365, 14495865),
]
assert (
record.contained_combined_ranges(["II", "V", "MCL1", "ABP"]) == []
)

def test_contained_combined_ranges_variable_layout(self):
record = wfdb.rdheader(
"sample-data/multi-segment/041s/041s",
rd_segments=True,
)

for sig_1 in record.sig_name:
for sig_2 in record.sig_name:
if sig_1 == sig_2:
continue

assert record.contained_combined_ranges([sig_1, sig_2]) == [
(0, 2000)
]
112 changes: 110 additions & 2 deletions wfdb/io/_header.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import re
from typing import List, Tuple
from typing import Collection, List, Tuple

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -858,7 +858,7 @@ def get_sig_name(self):
"""
if self.segments is None:
raise Exception(
"The MultiRecord's segments must be read in before this method is called. ie. Call rdheader() with rsegment_fieldsments=True"
"The MultiRecord's segments must be read in before this method is called. ie. Call rdheader() with rd_segments=True"
)

if self.layout == "fixed":
Expand All @@ -871,6 +871,114 @@ def get_sig_name(self):

return sig_name

def contained_ranges(self, sig_name: str) -> List[Tuple[int, int]]:
"""
Given a signal name, return the sample ranges that contain signal values,
relative to the start of the full record. Does not account for NaNs/missing
values.
This function is mainly useful for variable layout records, but can also be
used for fixed-layout records. Only works if the headers from the individual
segment records have already been read in.
Parameters
----------
sig_name : str
The name of the signal to query.
Returns
-------
ranges : List[Tuple[int, int]]
Tuple pairs which specify thee sample ranges in which the signal is contained.
The second value of each tuple pair will be one beyond the signal index.
eg. A length 1000 signal would generate a tuple of: (0, 1000), allowing
selection using signal[0:1000].
"""
if self.segments is None:
raise Exception(
"The MultiRecord's segments must be read in before this method is called. ie. Call rdheader() with rd_segments=True"
)
ranges = []
seg_start = 0

range_start = None

# TODO: Add shortcut for fixed-layout records

# Cannot process segments only because missing segments are None
# and do not contain length information.
for seg_num in range(self.n_seg):
seg_len = self.seg_len[seg_num]
segment = self.segments[seg_num]

if seg_len == 0:
continue

# Open signal range
if (
range_start is None
and segment is not None
and sig_name in segment.sig_name
):
range_start = seg_start
# Close signal range
elif range_start is not None and (
segment is None or sig_name not in segment.sig_name
):
ranges.append((range_start, seg_start))
range_start = None

seg_start += seg_len

# Account for final segment
if range_start is not None:
ranges.append((range_start, seg_start))

return ranges

def contained_combined_ranges(
self,
sig_names: Collection[str],
) -> List[Tuple[int, int]]:
"""
Given a collection of signal name, return the sample ranges that
contain all of the specified signals, relative to the start of the
full record. Does not account for NaNs/missing values.
This function is mainly useful for variable layout records, but can also be
used for fixed-layout records. Only works if the headers from the individual
segment records have already been read in.
Parameters
----------
sig_names : List[str]
The names of the signals to query.
Returns
-------
ranges : List[Tuple[int, int]]
Tuple pairs which specify thee sample ranges in which the signal is contained.
The second value of each tuple pair will be one beyond the signal index.
eg. A length 1000 signal would generate a tuple of: (0, 1000), allowing
selection using signal[0:1000].
"""
# TODO: Add shortcut for fixed-layout records

if len(sig_names) == 0:
return []

combined_ranges = self.contained_ranges(sig_names[0])

if len(sig_names) > 1:
for name in sig_names[1:]:
combined_ranges = util.overlapping_ranges(
combined_ranges, self.contained_ranges(name)
)

return combined_ranges


def wfdb_strptime(time_string: str) -> datetime.time:
"""
Expand Down
26 changes: 13 additions & 13 deletions wfdb/io/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,8 +1080,8 @@ class MultiRecord(BaseRecord, _header.MultiHeaderMixin):
`datetime.combine(base_date, base_time)`.
seg_name : str, optional
The name of the segment.
seg_len : int, optional
The length of the segment.
seg_len : List[int], optional
The length of each segment.
comments : list, optional
A list of string comments to be written to the header file.
sig_name : str, optional
Expand Down Expand Up @@ -1144,6 +1144,11 @@ def __init__(
self.seg_len = seg_len
self.sig_segments = sig_segments

if segments:
self.n_seg = len(segments)
if not seg_len:
self.seg_len = [segment.sig_len for segment in segments]

def wrsamp(self, write_dir=""):
"""
Write a multi-segment header, along with headers and dat files
Expand Down Expand Up @@ -1184,33 +1189,28 @@ def _check_segment_cohesion(self):
if self.n_seg != len(self.segments):
raise ValueError("Length of segments must match the 'n_seg' field")

for i in range(n_seg):
s = self.segments[i]
for seg_num, segment in enumerate(self.segments):

# If segment 0 is a layout specification record, check that its file names are all == '~''
if i == 0 and self.seg_len[0] == 0:
for file_name in s.file_name:
if seg_num == 0 and self.seg_len[0] == 0:
for file_name in segment.file_name:
if file_name != "~":
raise ValueError(
"Layout specification records must have all file_names named '~'"
)

# Sampling frequencies must all match the one in the master header
if s.fs != self.fs:
if segment.fs != self.fs:
raise ValueError(
"The 'fs' in each segment must match the overall record's 'fs'"
)

# Check the signal length of the segment against the corresponding seg_len field
if s.sig_len != self.seg_len[i]:
if segment.sig_len != self.seg_len[seg_num]:
raise ValueError(
"The signal length of segment "
+ str(i)
+ " does not match the corresponding segment length"
f"The signal length of segment {seg_num} does not match the corresponding segment length"
)

totalsig_len = totalsig_len + getattr(s, "sig_len")

# No need to check the sum of sig_lens from each segment object against sig_len
# Already effectively done it when checking sum(seg_len) against sig_len

Expand Down
22 changes: 21 additions & 1 deletion wfdb/io/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import math
import os

from typing import Sequence
from typing import Sequence, Tuple


def lines_to_file(file_name: str, write_dir: str, lines: Sequence[str]):
Expand Down Expand Up @@ -99,3 +99,23 @@ def upround(x, base):
"""
return base * math.ceil(float(x) / base)


def overlapping_ranges(
ranges_1: Tuple[int, int], ranges_2: Tuple[int, int]
) -> Tuple[int, int]:
"""
Given two collections of integer ranges, return a list of ranges
in which both input inputs overlap.
From: https://stackoverflow.com/q/40367461
Slightly modified so that if the end of one range exactly equals
the start of the other range, no overlap would be returned.
"""
return [
(max(first[0], second[0]), min(first[1], second[1]))
for first in ranges_1
for second in ranges_2
if max(first[0], second[0]) < min(first[1], second[1])
]

0 comments on commit a183c4d

Please sign in to comment.