Skip to content

Commit

Permalink
Merge pull request #321 from hofaflo/no-sklearn
Browse files Browse the repository at this point in the history
Remove dependency on scikit-learn
  • Loading branch information
tompollard authored Sep 13, 2021
2 parents ae67d09 + c5960e2 commit b9c61aa
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 43 deletions.
6 changes: 1 addition & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ class Mock(MagicMock):
def __getattr__(cls, name):
return MagicMock()

MOCK_MODULES = ['numpy', 'matplotlib', 'matplotlib.pyplot', 'pandas', 'scipy',
'sklearn', 'sklearn.preprocessing']
MOCK_MODULES = ['numpy', 'matplotlib', 'matplotlib.pyplot', 'pandas', 'scipy']
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)

# -- General configuration ------------------------------------------------
Expand Down Expand Up @@ -177,6 +176,3 @@ def __getattr__(cls, name):
author, 'wfdb', 'One line description of project.',
'Miscellaneous'),
]



1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@ nose==1.3.7
numpy==1.18.5
pandas==1.0.3
requests==2.23.0
scikit-learn==0.22.2.post1
scipy==1.4.1
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
'numpy>=1.10.1',
'pandas>=0.17.0',
'requests>=2.8.1',
'scikit-learn>=0.18',
'scipy>=0.17.0',
],

Expand Down
20 changes: 19 additions & 1 deletion wfdb/processing/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def get_filter_gain(b, a, f_gain, fs):
The frequency at which to calculate the gain.
fs : int, float, optional
The sampling frequency of the system.
Returns
-------
gain : int, float
Expand All @@ -230,3 +230,21 @@ def get_filter_gain(b, a, f_gain, fs):
gain = abs(h[ind])

return gain


def normalize(X):
"""
Scale input vector to unit norm (vector length).
Parameters
----------
X : ndarray
The vector to normalize.
Returns
-------
ndarray
The normalized vector.
"""
return X / np.linalg.norm(X)
68 changes: 33 additions & 35 deletions wfdb/processing/qrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@

import numpy as np
from scipy import signal
from sklearn.preprocessing import normalize

from wfdb.processing.basic import get_filter_gain
from wfdb.processing.basic import get_filter_gain, normalize
from wfdb.processing.peaks import find_local_peaks
from wfdb.io.record import Record


class XQRS(object):
"""
The QRS detector class for the XQRS algorithm. The `XQRS.Conf`
class is the configuration class that stores initial parameters
The QRS detector class for the XQRS algorithm. The `XQRS.Conf`
class is the configuration class that stores initial parameters
for the detection. The `XQRS.detect` method runs the detection algorithm.
The process works as follows:
Expand Down Expand Up @@ -85,7 +84,7 @@ class Conf(object):
----------
hr_init : int, float, optional
Initial heart rate in beats per minute. Used for calculating
recent R-R intervals.
recent R-R intervals.
hr_max : int, float, optional
Hard maximum heart rate between two beats, in beats per
minute. Used for refractory period.
Expand All @@ -104,13 +103,13 @@ class Conf(object):
ref_period : int, float, optional
The QRS refractory period.
t_inspect_period : int, float, optional
The period below which a potential QRS complex is
inspected to see if it is a T-wave.
The period below which a potential QRS complex is inspected to
see if it is a T-wave. Leave as 0 for no T-wave inspection.
"""
def __init__(self, hr_init=75, hr_max=200, hr_min=25, qrs_width=0.1,
qrs_thr_init=0.13, qrs_thr_min=0, ref_period=0.2,
t_inspect_period=0.36):
t_inspect_period=0):
if hr_min < 0:
raise ValueError("'hr_min' must be >= 0")

Expand All @@ -134,7 +133,7 @@ def __init__(self, hr_init=75, hr_max=200, hr_min=25, qrs_width=0.1,
def _set_conf(self):
"""
Set configuration parameters from the Conf object into the detector
object. Time values are converted to samples, and amplitude values
object. Time values are converted to samples, and amplitude values
are in mV.
Parameters
Expand Down Expand Up @@ -288,10 +287,10 @@ def _learn_init_params(self, n_calib_beats=8):

# Question: should the signal be squared? Case for inverse QRS
# complexes
sig_segment = normalize((self.sig_f[i - self.qrs_radius:
i + self.qrs_radius]).reshape(-1, 1), axis=0)
sig_segment = normalize(self.sig_f[i - self.qrs_radius:
i + self.qrs_radius])

xcorr = np.correlate(sig_segment[:, 0], ricker_wavelet[:,0])
xcorr = np.correlate(sig_segment, ricker_wavelet[:,0])

# Classify as QRS if xcorr is large enough
if xcorr > 0.6 and i-last_qrs_ind > self.rr_min:
Expand Down Expand Up @@ -470,15 +469,15 @@ def _update_qrs(self, peak_num, backsearch=False):
The peak number of the MWI signal where the QRS is detected.
backsearch: bool, optional
Whether the QRS was found via backsearch.
Returns
-------
N/A
"""
i = self.peak_inds_i[peak_num]

# Update recent R-R interval if the beat is consecutive (do this
# Update recent R-R interval if the beat is consecutive (do this
# before updating self.last_qrs_ind)
rr_new = i - self.last_qrs_ind
if rr_new < self.rr_max:
Expand Down Expand Up @@ -514,7 +513,7 @@ def _is_twave(self, peak_num):
----------
peak_num : int
The peak number of the MWI signal where the QRS is detected.
Returns
-------
bool
Expand All @@ -530,8 +529,7 @@ def _is_twave(self, peak_num):

# Get half the QRS width of the signal to the left.
# Should this be squared?
sig_segment = normalize((self.sig_f[i - self.qrs_radius:i]
).reshape(-1, 1), axis=0)
sig_segment = normalize(self.sig_f[i - self.qrs_radius:i])
last_qrs_segment = self.sig_f[self.last_qrs_ind - self.qrs_radius:
self.last_qrs_ind]

Expand Down Expand Up @@ -901,7 +899,7 @@ def __init__(self, fs, adc_gain, hr=75,
class Peak(object):
"""
Holds all of the peak information for the QRS object.
Attributes
----------
peak_time : int, float
Expand All @@ -923,7 +921,7 @@ def __init__(self, peak_time, peak_amp, peak_type):
class Annotation(object):
"""
Holds all of the annotation information for the QRS object.
Attributes
----------
ann_time : int, float
Expand Down Expand Up @@ -1160,8 +1158,8 @@ def qfv_put(self, t, v):

def sm(self, at_t):
"""
Implements a trapezoidal low pass (smoothing) filter (with a gain
of 4*smdt) applied to input signal sig before the QRS matched
Implements a trapezoidal low pass (smoothing) filter (with a gain
of 4*smdt) applied to input signal sig before the QRS matched
filter qf(). Before attempting to 'rewind' by more than BUFLN-smdt
samples, reset smt and smt0.
Expand Down Expand Up @@ -1220,7 +1218,7 @@ def qf(self):
N/A
"""
# Do this first, to ensure that all of the other smoothed values
# Do this first, to ensure that all of the other smoothed values
# needed below are in the buffer
dv2 = self.sm(self.t + self.c.dt4)
dv2 -= self.smv_at(self.t - self.c.dt4)
Expand Down Expand Up @@ -1302,17 +1300,17 @@ def add_peak(peak_time, peak_amp, peak_type):

def peaktype(p):
"""
The neighborhood consists of all other peaks within rrmin.
Normally, "most prominent" is equivalent to "largest in
amplitude", but this is not always true. For example, consider
three consecutive peaks a, b, c such that a and b share a
neighborhood, b and c share a neighborhood, but a and c do not;
and suppose that amp(a) > amp(b) > amp(c). In this case, if
The neighborhood consists of all other peaks within rrmin.
Normally, "most prominent" is equivalent to "largest in
amplitude", but this is not always true. For example, consider
three consecutive peaks a, b, c such that a and b share a
neighborhood, b and c share a neighborhood, but a and c do not;
and suppose that amp(a) > amp(b) > amp(c). In this case, if
there are no other peaks, a is the most prominent peak in the (a, b)
neighborhood. Since b is thus identified as a non-prominent peak,
c becomes the most prominent peak in the (b, c) neighborhood.
This is necessary to permit detection of low-amplitude beats that
closely precede or follow beats with large secondary peaks (as,
neighborhood. Since b is thus identified as a non-prominent peak,
c becomes the most prominent peak in the (b, c) neighborhood.
This is necessary to permit detection of low-amplitude beats that
closely precede or follow beats with large secondary peaks (as,
for example, in R-on-T PVCs).
Parameters
Expand All @@ -1323,7 +1321,7 @@ def peaktype(p):
Returns
-------
int
Whether the input peak is the most prominent peak in its
Whether the input peak is the most prominent peak in its
neighborhood (1) or not (2).
"""
Expand Down Expand Up @@ -1534,8 +1532,8 @@ def gqrs_detect(sig=None, fs=None, d_sig=None, adc_gain=None, adc_zero=None,
"""
Detect QRS locations in a single channel ecg. Functionally, a direct port
of the GQRS algorithm from the original WFDB package. Accepts either a
physical signal, or a digital signal with known adc_gain and adc_zero. See
the notes below for a summary of the program. This algorithm is not being
physical signal, or a digital signal with known adc_gain and adc_zero. See
the notes below for a summary of the program. This algorithm is not being
developed/supported.
Parameters
Expand Down

0 comments on commit b9c61aa

Please sign in to comment.