-
Notifications
You must be signed in to change notification settings - Fork 1
/
ecg_rpeaks_dl.py
312 lines (268 loc) · 11.5 KB
/
ecg_rpeaks_dl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
"""
R peaks detection using deep learning models for single-lead ECG signal
References:
-----------
[1] Cai, Wenjie, and Danqin Hu. "QRS complex detection using novel deep learning neural networks." IEEE Access (2020).
"""
import math
import os
from itertools import repeat
from numbers import Real
from typing import NoReturn, Optional, Sequence, Union
import numpy as np
from scipy.signal import resample_poly
try:
import biosppy.signals.ecg as BSE
except:
import references.biosppy.biosppy.signals.ecg as BSE
from utils import mask_to_intervals
from .ecg_rpeaks_dl_models import load_model
# public API of this module
__all__ = [
    "seq_lab_net_detect",
]
# the two models returned by `load_model` are loaded once, when this module is imported;
# `seq_lab_net_detect` uses them as an ensemble (their predictions are averaged)
CNN_MODEL, CRNN_MODEL = load_model("keras_ecg_seq_lab_net")
def seq_lab_net_detect(sig: np.ndarray, fs: Real, correction: bool = False, **kwargs) -> np.ndarray:
    """finished, checked,

    use model of entry 0416 of CPSC2019,
    to detect R peaks in single-lead ECGs of arbitrary length

    NOTE: `sig` should have units in mV, NOT in μV!

    Parameters:
    -----------
    sig: ndarray,
        the (raw) ECG signal of arbitrary length, with units in mV
    fs: real number,
        sampling frequency of `sig`
    correction: bool, default False,
        if True, correct rpeaks to local maximum in a small nbh
        of rpeaks detected by DL model using `BSE.correct_rpeaks`
    kwargs: dict,
        optional key word arguments, including
        - verbose, int, default 0,
            print verbosity
        - batch_size, int, default None,
            batch size for feeding into the model;
            if None and the signal (resampled to 500 Hz) is longer than 20 minutes,
            a batch size of 64 is used

    Returns:
    --------
    rpeaks: ndarray,
        indices of rpeaks in `sig`

    References:
    -----------
    [1] Cai, Wenjie, and Danqin Hu. "QRS complex detection using novel deep learning neural networks." IEEE Access (2020).
    """
    verbose = kwargs.get("verbose", 0)
    batch_size = kwargs.get("batch_size", None)
    model_fs = 500
    model_granularity = 8  # 1/8 times of model_fs

    # pre-process
    sig_rsmp = _seq_lab_net_pre_process(sig, verbose=verbose)
    if fs != model_fs:
        sig_rsmp = resample_poly(sig_rsmp, up=model_fs, down=int(fs))
        # `resample_poly` places resampled sample i at original position i * down / up,
        # hence positions must be mapped back with the same (integer) `down`,
        # not with `fs` itself (they differ when `fs` is not an integer)
        rsmp_ratio = int(fs) / model_fs
    else:
        rsmp_ratio = 1

    # signals longer than 20 minutes (at model_fs) are always split into segments
    max_single_batch_half_len = 10 * 60 * model_fs
    if len(sig_rsmp) > 2 * max_single_batch_half_len:
        if batch_size is None:
            batch_size = 64
        if verbose >= 1:
            print(f"the signal is too long, hence split into segments for parallel computing of batch size {batch_size}")

    if batch_size is not None:
        model_input_len = 5000
        half_overlap_len = 256  # approximately 0.5s, should be divisible by `model_granularity`
        half_overlap_len_prob = half_overlap_len // model_granularity
        overlap_len = 2 * half_overlap_len
        forward_len = model_input_len - overlap_len
        # consecutive segments start `forward_len` samples apart and overlap by `overlap_len` samples;
        # at least one segment is used so that short signals are still fed to the models
        n_segs = max(1, -(-(len(sig_rsmp) - overlap_len) // forward_len))  # ceiling division
        # zero-pad so that the last segment is complete
        padded_len = n_segs * forward_len + overlap_len
        sig_rsmp = np.append(sig_rsmp, np.zeros((padded_len - len(sig_rsmp),)))
        n_batches = math.ceil(n_segs / batch_size)
        if verbose >= 2:
            print(f"number of batches = {n_batches}")
        seg_probs = []
        for b_idx in range(n_batches):
            b_segs = range(b_idx * batch_size, min((b_idx + 1) * batch_size, n_segs))
            b_input = np.vstack(
                [sig_rsmp[idx * forward_len : idx * forward_len + model_input_len] for idx in b_segs]
            ).reshape((-1, model_input_len, 1))
            prob_cnn = CNN_MODEL.predict(b_input)
            prob_crnn = CRNN_MODEL.predict(b_input)
            b_prob = (prob_cnn[..., 0] + prob_crnn[..., 0]) / 2
            # keep only the central (non-overlapping) part of each segment;
            # consecutive central parts tile sig_rsmp[half_overlap_len: -half_overlap_len]
            seg_probs.append(b_prob[..., half_overlap_len_prob:-half_overlap_len_prob].ravel())
            if verbose >= 1:
                print(f"{b_idx+1}/{n_batches} batches", end="\r")
        # head and tail (which might not be trustable) are filled with zero probability
        border = np.zeros(half_overlap_len_prob)
        prob = np.concatenate([border, *seg_probs, border])
    else:
        prob_cnn = CNN_MODEL.predict(sig_rsmp.reshape((1, len(sig_rsmp), 1)))
        prob_crnn = CRNN_MODEL.predict(sig_rsmp.reshape((1, len(sig_rsmp), 1)))
        prob = ((prob_cnn + prob_crnn) / 2).squeeze()

    # prob --> qrs mask --> qrs intervals --> rpeaks
    rpeaks = _seq_lab_net_post_process(prob, 0.5, verbose=verbose)
    # convert from resampled positions to original positions
    rpeaks = np.round(rsmp_ratio * rpeaks).astype(int)
    # drop rpeaks lying beyond the end of `sig` (e.g. inside the zero padding)
    rpeaks = rpeaks[rpeaks < len(sig)]

    # adjust to the "true" rpeaks,
    # i.e. the max in a small nbh of each element in `rpeaks`
    if correction:
        (rpeaks,) = BSE.correct_rpeaks(
            signal=sig,
            rpeaks=rpeaks,
            sampling_rate=fs,
            tol=0.05,
        )
    return rpeaks
def _seq_lab_net_pre_process(sig: np.ndarray, verbose: int = 0) -> np.ndarray:
    """partly finished, partly checked,

    pre-process a single-lead ECG before it is fed into the models of entry 0416 of CPSC2019

    Parameters:
    -----------
    sig: ndarray,
        the ECG signal to be pre-processed
    verbose: int, default 0,
        print verbosity (currently not used)

    Returns:
    --------
    sig_processed: ndarray,
        the processed ECG signal
    """
    # TODO:
    # To achieve better model generalization,
    # the (local?) mean of signal values is subtracted for each recording

    # the only step for now: every sample whose absolute value exceeds 20 (mV)
    # is regarded as a spike and replaced by the normal sample immediately before it
    return _remove_spikes_naive(sig)
def _seq_lab_net_post_process(
    prob: np.ndarray,
    prob_thr: float = 0.5,
    duration_thr: int = 4 * 16,
    dist_thr: Union[int, Sequence[int]] = 200,
    verbose: int = 0,
) -> np.ndarray:
    """finished, checked,

    convert the array of probability predictions into the array of indices of rpeaks

    Parameters:
    -----------
    prob: ndarray,
        the array of probabilities of qrs complex;
        each element is assumed to correspond to `model_granularity` (8) samples at 500 Hz,
        as implied by the index conversions below
    prob_thr: float, default 0.5,
        threshold of probability for predicting qrs complex
    duration_thr: int, default 4*16,
        minimum duration for a "true" qrs complex, units in ms
    dist_thr: real number or sequence of real numbers, default 200,
        if is sequence,
        (0-th element). minimum distance for two consecutive qrs complexes, units in ms;
        (1st element).(optional) maximum distance for checking missing qrs complexes, units in ms,
        e.g. [200, 1200]
        if is a single number, then is the case of (0-th element).
    verbose: int, default 0,
        print verbosity

    Returns:
    --------
    rpeaks: ndarray,
        indices of rpeaks converted from the array `prob`,
        in units of samples at the model sampling frequency (500 Hz)
    """
    model_fs = 500
    model_spacing = 1000 / model_fs  # units in ms
    model_granularity = 8  # 1/8 times of model_fs
    half_granularity = model_granularity // 2
    _prob = prob.squeeze()
    assert _prob.ndim == 1, "only support single record processing, batch processing not supported!"

    # prob --> qrs mask --> qrs intervals --> rpeaks
    mask = (_prob >= prob_thr).astype(int)
    qrs_intervals = mask_to_intervals(mask, 1)
    # threshold (default 64 ms) for the duration of clustering positive samples
    # is set to eliminate some wrong predictions; converted here to units of `_prob` indices
    _duration_thr = duration_thr / model_spacing / model_granularity
    # centre of each kept interval on the model_fs sample grid,
    # i.e. model_granularity * (itv[0] + itv[1]) / 2
    rpeaks = half_granularity * np.array(
        [itv[0] + itv[1] for itv in qrs_intervals if itv[1] - itv[0] >= _duration_thr]
    )
    if verbose >= 3:
        print(f"raw rpeak predictions = {rpeaks.tolist()}")

    # a single number (int, float, numpy scalar) is the minimum distance only
    _dist_thr = [dist_thr] if isinstance(dist_thr, Real) else list(dist_thr)
    assert 1 <= len(_dist_thr) <= 2

    def _prob_at(rpeak):
        # probability at the `_prob` index corresponding to `rpeak`
        return _prob[int(rpeak / model_granularity)]

    # filter out those rpeaks that are too close to each other (default 200 ms):
    # of two such rpeaks, the one with the lower probability is removed (the earlier one on ties).
    # A single left-to-right pass gives the same result as repeatedly removing the first
    # offending pair, because removing an rpeak only enlarges the gap between its neighbours
    # (assuming `mask_to_intervals` returns intervals in ascending order)
    min_dist = _dist_thr[0] / model_spacing
    kept = []
    for rp in rpeaks:
        if kept and rp - kept[-1] < min_dist:
            if _prob_at(kept[-1]) > _prob_at(rp):
                del_ind = len(kept)  # the current rpeak is removed
            else:
                del_ind = len(kept) - 1  # the previously kept rpeak is removed
                kept[-1] = rp
            if verbose >= 2:
                print(f"the {del_ind}-th R peak was removed since too close to another R peak")
            continue
        kept.append(rp)
    rpeaks = np.array(kept, dtype=rpeaks.dtype)

    if len(_dist_thr) == 1:
        return rpeaks

    # further search should be performed to locate where the
    # distances are greater than `_dist_thr[1]` (e.g. 1200 ms) between adjacent QRS complexes
    # if there exists at least one point that is great than `prob_thr`,
    # the threshold of the duration of clustering positive samples is reduced by 16 ms
    # and this process will continue until a new QRS candidate is found
    # or the threshold decreases to zero.
    # This is implemented by taking the longest positive interval inside the gap
    # (ties broken by the higher maximum probability, then by position).

    def _enclosing_interval(prob_ind):
        # the first qrs interval containing the `_prob` index `prob_ind`
        return next(itv for itv in qrs_intervals if itv[0] <= prob_ind <= itv[1])

    # TODO: parallel the following block
    # CAUTION !!!
    # this part can still be slow in some cases (long duration and low SNR)
    max_dist = _dist_thr[1] / model_spacing
    rpeaks_dtype = rpeaks.dtype
    rpeaks = list(rpeaks)
    r = 0
    while r < len(rpeaks) - 1:
        if rpeaks[r + 1] - rpeaks[r] < max_dist:
            r += 1
            continue
        prev_qrs = _enclosing_interval(int(rpeaks[r] / model_granularity))
        next_qrs = _enclosing_interval(int(rpeaks[r + 1] / model_granularity))
        check_start, check_end = prev_qrs[1], next_qrs[0]
        l_new_itv = mask_to_intervals(mask[check_start:check_end], 1)
        if len(l_new_itv) == 0:
            r += 1
            continue
        l_new_itv = [[itv[0] + check_start, itv[1] + check_start] for itv in l_new_itv]
        new_itv = max(
            l_new_itv,
            key=lambda itv: (itv[1] - itv[0], _prob[itv[0] : itv[1]].max()),
        )
        rpeaks.insert(r + 1, half_granularity * (new_itv[0] + new_itv[1]))
        if verbose >= 2:
            print(f"found back an rpeak inside the {r}-th RR interval")
        # `r` is not advanced: the gap between rpeaks[r] and the newly inserted rpeak
        # is checked next; the gaps before `r` are unaffected by the insertion
    return np.array(rpeaks, dtype=rpeaks_dtype)
def _remove_spikes_naive(sig: np.ndarray) -> np.ndarray:
    """finished, checked,

    remove `spikes` from `sig` using a naive method proposed in entry 0416 of CPSC2019,
    `spikes` here refers to abrupt large bumps with (abs) value larger than 20 mV,
    do NOT confuse with `spikes` in paced rhythm

    each spike sample, except the very first sample of `sig` (which is left as is),
    takes the value of the sample immediately before it in the filtered signal;
    hence a run of consecutive spikes is filled with the last sample preceding the run

    Parameters:
    -----------
    sig: ndarray,
        single-lead ECG signal with potential spikes

    Returns:
    --------
    filtered_sig: ndarray,
        ECG signal with `spikes` removed (a copy, `sig` itself is not modified)
    """
    spike_positions = np.argwhere(np.abs(sig) > 20).squeeze(-1)
    # the first sample has no predecessor, so it can never be replaced
    replaceable = spike_positions[spike_positions > 0]
    filtered_sig = sig.copy()
    # ascending order matters: a replaced sample feeds the next one in a run of spikes
    for pos in replaceable:
        filtered_sig[pos] = filtered_sig[pos - 1]
    return filtered_sig