From a183c4d52d9b8838be43774db35fbd0ce4b2b089 Mon Sep 17 00:00:00 2001
From: cx1111
Date: Thu, 30 Jun 2022 15:39:08 -0700
Subject: [PATCH] Add functions to determine sample ranges of signals in multi-segment records (#403)

* Add contained_ranges method to calculate sample ranges that contain a signal, in multi-segment records

* Add contained_combined_ranges function

* Add tests for new ranges functions

* Fix logic to account for empty segments and derive some fields in MultiRecord constructor
---
 tests/test_multi_record.py | 113 +++++++++++++++++++++++++++++++++++++
 wfdb/io/_header.py         | 112 +++++++++++++++++++++++++++++++++++-
 wfdb/io/record.py          |  26 ++++-----
 wfdb/io/util.py            |  22 +++++++-
 4 files changed, 257 insertions(+), 16 deletions(-)
 create mode 100644 tests/test_multi_record.py

diff --git a/tests/test_multi_record.py b/tests/test_multi_record.py
new file mode 100644
index 00000000..683c2c44
--- /dev/null
+++ b/tests/test_multi_record.py
@@ -0,0 +1,113 @@
+import wfdb
+
+
+class TestMultiRecordRanges:
+    """
+    Test logic that deduces relevant segments/ranges for given signals.
+    """
+
+    def test_contained_ranges_simple_cases(self):
+        record = wfdb.MultiRecord(
+            segments=[
+                wfdb.Record(sig_name=["I", "II"], sig_len=5),
+                wfdb.Record(sig_name=["I", "III"], sig_len=10),
+            ],
+        )
+
+        assert record.contained_ranges("I") == [(0, 15)]
+        assert record.contained_ranges("II") == [(0, 5)]
+        assert record.contained_ranges("III") == [(5, 15)]
+
+    def test_contained_ranges_variable_layout(self):
+        record = wfdb.rdheader(
+            "sample-data/multi-segment/s00001/s00001-2896-10-10-00-31",
+            rd_segments=True,
+        )
+
+        assert record.contained_ranges("II") == [
+            (3261, 10136),
+            (4610865, 10370865),
+            (10528365, 14518365),
+        ]
+        assert record.contained_ranges("V") == [
+            (3261, 918261),
+            (920865, 4438365),
+            (4610865, 10370865),
+            (10528365, 14518365),
+        ]
+        assert record.contained_ranges("MCL1") == [
+            (10136, 918261),
+            (920865, 4438365),
+        ]
+        assert record.contained_ranges("ABP") == [
+            (14428365, 14450865),
+            (14458365, 14495865),
+        ]
+
+    def test_contained_ranges_fixed_layout(self):
+        record = wfdb.rdheader(
+            "sample-data/multi-segment/041s/041s",
+            rd_segments=True,
+        )
+
+        for sig_name in record.sig_name:
+            assert record.contained_ranges(sig_name) == [(0, 2000)]
+
+    def test_contained_combined_ranges_simple_cases(self):
+        record = wfdb.MultiRecord(
+            segments=[
+                wfdb.Record(sig_name=["I", "II", "V"], sig_len=5),
+                wfdb.Record(sig_name=["I", "III", "V"], sig_len=10),
+                wfdb.Record(sig_name=["I", "II", "V"], sig_len=20),
+            ],
+        )
+
+        assert record.contained_combined_ranges(["I", "II"]) == [
+            (0, 5),
+            (15, 35),
+        ]
+        assert record.contained_combined_ranges(["II", "III"]) == []
+        assert record.contained_combined_ranges(["I", "III"]) == [(5, 15)]
+        assert record.contained_combined_ranges(["I", "II", "V"]) == [
+            (0, 5),
+            (15, 35),
+        ]
+
+    def test_contained_combined_ranges_variable_layout(self):
+        record = wfdb.rdheader(
+            "sample-data/multi-segment/s00001/s00001-2896-10-10-00-31",
+            rd_segments=True,
+        )
+
+        assert record.contained_combined_ranges(["II", "V"]) == [
+            (3261, 10136),
+            (4610865, 10370865),
+            (10528365, 14518365),
+        ]
+        assert record.contained_combined_ranges(["II", "MCL1"]) == []
+        assert record.contained_combined_ranges(["II", "ABP"]) == [
+            (14428365, 14450865),
+            (14458365, 14495865),
+        ]
+        assert record.contained_combined_ranges(["II", "V", "ABP"]) == [
+            (14428365, 14450865),
+            (14458365, 14495865),
+        ]
+        assert (
+            record.contained_combined_ranges(["II", "V", "MCL1", "ABP"]) == []
+        )
+
+    def test_contained_combined_ranges_fixed_layout(self):
+        record = wfdb.rdheader(
+            "sample-data/multi-segment/041s/041s",
+            rd_segments=True,
+        )
+
+        for sig_1 in record.sig_name:
+            for sig_2 in record.sig_name:
+                if sig_1 == sig_2:
+                    continue
+
+                assert record.contained_combined_ranges([sig_1, sig_2]) == [
+                    (0, 2000)
+                ]
diff --git a/wfdb/io/_header.py b/wfdb/io/_header.py
index 15c8065a..142a69f7 100644
--- a/wfdb/io/_header.py
+++ b/wfdb/io/_header.py
@@ -1,6 +1,6 @@
 import datetime
 import re
-from typing import List, Tuple
+from typing import List, Sequence, Tuple
 
 import numpy as np
 import pandas as pd
@@ -858,7 +858,7 @@ def get_sig_name(self):
         """
         if self.segments is None:
             raise Exception(
-                "The MultiRecord's segments must be read in before this method is called. ie. Call rdheader() with rsegment_fieldsments=True"
+                "The MultiRecord's segments must be read in before this method is called. ie. Call rdheader() with rd_segments=True"
             )
 
         if self.layout == "fixed":
@@ -871,6 +871,114 @@ def get_sig_name(self):
 
         return sig_name
 
+    def contained_ranges(self, sig_name: str) -> List[Tuple[int, int]]:
+        """
+        Given a signal name, return the sample ranges that contain signal values,
+        relative to the start of the full record. Does not account for NaNs/missing
+        values.
+
+        This function is mainly useful for variable layout records, but can also be
+        used for fixed-layout records. Only works if the headers from the individual
+        segment records have already been read in.
+
+        Parameters
+        ----------
+        sig_name : str
+            The name of the signal to query.
+
+        Returns
+        -------
+        ranges : List[Tuple[int, int]]
+            Tuple pairs which specify the sample ranges in which the signal is contained.
+            The second value of each tuple pair will be one beyond the signal index.
+            eg. A length 1000 signal would generate a tuple of: (0, 1000), allowing
+            selection using signal[0:1000].
+
+        """
+        if self.segments is None:
+            raise Exception(
+                "The MultiRecord's segments must be read in before this method is called. ie. Call rdheader() with rd_segments=True"
+            )
+        ranges = []
+        seg_start = 0
+
+        range_start = None
+
+        # TODO: Add shortcut for fixed-layout records
+
+        # Cannot process segments only because missing segments are None
+        # and do not contain length information.
+        for seg_num in range(self.n_seg):
+            seg_len = self.seg_len[seg_num]
+            segment = self.segments[seg_num]
+
+            if seg_len == 0:
+                continue
+
+            # Open signal range
+            if (
+                range_start is None
+                and segment is not None
+                and sig_name in segment.sig_name
+            ):
+                range_start = seg_start
+            # Close signal range
+            elif range_start is not None and (
+                segment is None or sig_name not in segment.sig_name
+            ):
+                ranges.append((range_start, seg_start))
+                range_start = None
+
+            seg_start += seg_len
+
+        # Account for final segment
+        if range_start is not None:
+            ranges.append((range_start, seg_start))
+
+        return ranges
+
+    def contained_combined_ranges(
+        self,
+        sig_names: Sequence[str],
+    ) -> List[Tuple[int, int]]:
+        """
+        Given a sequence of signal names, return the sample ranges that
+        contain all of the specified signals, relative to the start of the
+        full record. Does not account for NaNs/missing values.
+
+        This function is mainly useful for variable layout records, but can also be
+        used for fixed-layout records. Only works if the headers from the individual
+        segment records have already been read in.
+
+        Parameters
+        ----------
+        sig_names : Sequence[str]
+            The names of the signals to query.
+
+        Returns
+        -------
+        ranges : List[Tuple[int, int]]
+            Tuple pairs which specify the sample ranges in which all of the signals are contained.
+            The second value of each tuple pair will be one beyond the signal index.
+            eg. A length 1000 signal would generate a tuple of: (0, 1000), allowing
+            selection using signal[0:1000].
+
+        """
+        # TODO: Add shortcut for fixed-layout records
+
+        if len(sig_names) == 0:
+            return []
+
+        combined_ranges = self.contained_ranges(sig_names[0])
+
+        if len(sig_names) > 1:
+            for name in sig_names[1:]:
+                combined_ranges = util.overlapping_ranges(
+                    combined_ranges, self.contained_ranges(name)
+                )
+
+        return combined_ranges
+
 
 def wfdb_strptime(time_string: str) -> datetime.time:
     """
diff --git a/wfdb/io/record.py b/wfdb/io/record.py
index 6e564917..8d096ab4 100644
--- a/wfdb/io/record.py
+++ b/wfdb/io/record.py
@@ -1080,8 +1080,8 @@ class MultiRecord(BaseRecord, _header.MultiHeaderMixin):
         `datetime.combine(base_date, base_time)`.
     seg_name : str, optional
         The name of the segment.
-    seg_len : int, optional
-        The length of the segment.
+    seg_len : List[int], optional
+        The length of each segment.
     comments : list, optional
         A list of string comments to be written to the header file.
     sig_name : str, optional
@@ -1144,6 +1144,11 @@ def __init__(
         self.seg_len = seg_len
         self.sig_segments = sig_segments
 
+        if segments:
+            self.n_seg = len(segments)
+            if not seg_len:
+                self.seg_len = [segment.sig_len for segment in segments]
+
     def wrsamp(self, write_dir=""):
         """
         Write a multi-segment header, along with headers and dat files
@@ -1184,33 +1189,28 @@ def _check_segment_cohesion(self):
         if self.n_seg != len(self.segments):
             raise ValueError("Length of segments must match the 'n_seg' field")
 
-        for i in range(n_seg):
-            s = self.segments[i]
+        for seg_num, segment in enumerate(self.segments):
 
             # If segment 0 is a layout specification record, check that its file names are all == '~''
-            if i == 0 and self.seg_len[0] == 0:
-                for file_name in s.file_name:
+            if seg_num == 0 and self.seg_len[0] == 0:
+                for file_name in segment.file_name:
                     if file_name != "~":
                         raise ValueError(
                             "Layout specification records must have all file_names named '~'"
                         )
 
             # Sampling frequencies must all match the one in the master header
-            if s.fs != self.fs:
+            if segment.fs != self.fs:
                 raise ValueError(
                     "The 'fs' in each segment must match the overall record's 'fs'"
                 )
 
             # Check the signal length of the segment against the corresponding seg_len field
-            if s.sig_len != self.seg_len[i]:
+            if segment.sig_len != self.seg_len[seg_num]:
                 raise ValueError(
-                    "The signal length of segment "
-                    + str(i)
-                    + " does not match the corresponding segment length"
+                    f"The signal length of segment {seg_num} does not match the corresponding segment length"
                 )
 
-            totalsig_len = totalsig_len + getattr(s, "sig_len")
-
             # No need to check the sum of sig_lens from each segment object against sig_len
             # Already effectively done it when checking sum(seg_len) against sig_len
 
diff --git a/wfdb/io/util.py b/wfdb/io/util.py
index 24e4495c..12ecde33 100644
--- a/wfdb/io/util.py
+++ b/wfdb/io/util.py
@@ -4,7 +4,7 @@
 import math
 import os
 
-from typing import Sequence
+from typing import List, Sequence, Tuple
 
 
 def lines_to_file(file_name: str, write_dir: str, lines: Sequence[str]):
@@ -99,3 +99,23 @@ def upround(x, base):
 
     """
     return base * math.ceil(float(x) / base)
+
+
+def overlapping_ranges(
+    ranges_1: Sequence[Tuple[int, int]], ranges_2: Sequence[Tuple[int, int]]
+) -> List[Tuple[int, int]]:
+    """
+    Given two collections of integer ranges, return a list of ranges
+    in which both inputs overlap.
+
+    From: https://stackoverflow.com/q/40367461
+
+    Slightly modified so that if the end of one range exactly equals
+    the start of the other range, no overlap would be returned.
+    """
+    return [
+        (max(first[0], second[0]), min(first[1], second[1]))
+        for first in ranges_1
+        for second in ranges_2
+        if max(first[0], second[0]) < min(first[1], second[1])
+    ]