Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] PPG plots improvements #883

Merged
merged 14 commits into from
Aug 22, 2023
12 changes: 9 additions & 3 deletions docs/functions/ppg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ Preprocessing
"""""""""""""
.. autofunction:: neurokit2.ppg.ppg_clean

*ppg_findpeaks()*
*ppg_peaks()*
"""""""""""""""""
.. autofunction:: neurokit2.ppg.ppg_findpeaks
.. autofunction:: neurokit2.ppg.ppg_peaks


Analysis
Expand All @@ -42,9 +42,15 @@ Analysis
.. autofunction:: neurokit2.ppg.ppg_intervalrelated


Miscellaneous
^^^^^^^^^^^^^^^^
*ppg_findpeaks()*
""""""""""""""""""""""""
.. autofunction:: neurokit2.ecg.ppg_findpeaks


*Any function appearing below this point is not explicitly part of the documentation and should be added. Please open an issue if there is one.*

.. automodule:: neurokit2.ppg
:members:
:exclude-members: ppg_process, ppg_analyze, ppg_simulate, ppg_plot, ppg_clean, ppg_findpeaks, ppg_rate, ppg_eventrelated, ppg_intervalrelated
:exclude-members: ppg_process, ppg_analyze, ppg_simulate, ppg_plot, ppg_clean, ppg_peaks, ppg_findpeaks, ppg_rate, ppg_eventrelated, ppg_intervalrelated
22 changes: 15 additions & 7 deletions neurokit2/ecg/ecg_peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@


def ecg_peaks(
ecg_cleaned, sampling_rate=1000, method="neurokit", correct_artifacts=False, **kwargs
ecg_cleaned,
sampling_rate=1000,
method="neurokit",
correct_artifacts=False,
**kwargs
):
"""**Find R-peaks in an ECG signal**

Find R-peaks in an ECG signal using the specified method. The method accepts unfiltered ECG
signals as input, although it is expected that a filtered (cleaned) ECG will result in better
results.
Find R-peaks in an ECG signal using the specified method. You can pass an unfiltered ECG
signals as input, but typically a filtered ECG (cleaned using ``ecg_clean()``) will result in
better results.

Different algorithms for peak-detection include:

Expand Down Expand Up @@ -49,7 +53,7 @@ def ecg_peaks(
ecg_cleaned : Union[list, np.array, pd.Series]
The cleaned ECG channel as returned by ``ecg_clean()``.
sampling_rate : int
The sampling frequency of ``ecg_signal`` (in Hz, i.e., samples/second). Defaults to 1000.
The sampling frequency of ``ecg_cleaned`` (in Hz, i.e., samples/second). Defaults to 1000.
method : string
The algorithm to be used for R-peak detection.
correct_artifacts : bool
Expand Down Expand Up @@ -250,7 +254,9 @@ def ecg_peaks(
engineering & technology, 43(3), 173-181.

"""
rpeaks = ecg_findpeaks(ecg_cleaned, sampling_rate=sampling_rate, method=method, **kwargs)
rpeaks = ecg_findpeaks(
ecg_cleaned, sampling_rate=sampling_rate, method=method, **kwargs
)

if correct_artifacts:
_, rpeaks = signal_fixpeaks(
Expand All @@ -259,7 +265,9 @@ def ecg_peaks(

rpeaks = {"ECG_R_Peaks": rpeaks}

instant_peaks = signal_formatpeaks(rpeaks, desired_length=len(ecg_cleaned), peak_indices=rpeaks)
instant_peaks = signal_formatpeaks(
rpeaks, desired_length=len(ecg_cleaned), peak_indices=rpeaks
)
signals = instant_peaks
info = rpeaks
info["sampling_rate"] = sampling_rate # Add sampling rate in dict info
Expand Down
72 changes: 47 additions & 25 deletions neurokit2/ecg/ecg_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from ..epochs import epochs_to_df
from ..signal import signal_fixpeaks
from ..stats import rescale
from .ecg_segment import ecg_segment
from .ecg_segment import _ecg_segment_plot, ecg_segment


def ecg_plot(ecg_signals, rpeaks=None, sampling_rate=None, show_type="default"):
def ecg_plot(ecg_signals, rpeaks=None, sampling_rate=1000, show_type="default"):
"""**Visualize ECG data**

Plot ECG signals and R-peaks.
Expand All @@ -24,9 +24,7 @@ def ecg_plot(ecg_signals, rpeaks=None, sampling_rate=None, show_type="default"):
The samples at which the R-peak occur. Dict returned by
``ecg_process()``. Defaults to ``None``.
sampling_rate : int
The sampling frequency of the ECG (in Hz, i.e., samples/second). Needs to be supplied if the
data should be plotted over time in seconds. Otherwise the data is plotted over samples.
Defaults to ``None``. Must be specified to plot artifacts.
The sampling frequency of ``ecg_cleaned`` (in Hz, i.e., samples/second). Defaults to 1000.
show_type : str
Visualize the ECG data with ``"default"`` or visualize artifacts thresholds with
``"artifacts"`` produced by ``ecg_fixpeaks()``, or ``"full"`` to visualize both.
Expand Down Expand Up @@ -77,8 +75,12 @@ def ecg_plot(ecg_signals, rpeaks=None, sampling_rate=None, show_type="default"):
# Prepare figure and set axes.
if show_type in ["default", "full"]:
if sampling_rate is not None:
x_axis = np.linspace(0, ecg_signals.shape[0] / sampling_rate, ecg_signals.shape[0])
gs = matplotlib.gridspec.GridSpec(2, 2, width_ratios=[1 - 1 / np.pi, 1 / np.pi])
x_axis = np.linspace(
0, ecg_signals.shape[0] / sampling_rate, ecg_signals.shape[0]
)
gs = matplotlib.gridspec.GridSpec(
2, 2, width_ratios=[1 - 1 / np.pi, 1 / np.pi]
)
fig = plt.figure(constrained_layout=False)
ax0 = fig.add_subplot(gs[0, :-1])
ax1 = fig.add_subplot(gs[1, :-1])
Expand Down Expand Up @@ -138,40 +140,56 @@ def ecg_plot(ecg_signals, rpeaks=None, sampling_rate=None, show_type="default"):
handles, labels = ax0.get_legend_handles_labels()
order = [2, 0, 1, 3]
ax0.legend(
[handles[idx] for idx in order], [labels[idx] for idx in order], loc="upper right"
[handles[idx] for idx in order],
[labels[idx] for idx in order],
loc="upper right",
)

# Plot heart rate.
ax1.set_title("Heart Rate")
ax1.set_ylabel("Beats per minute (bpm)")

ax1.plot(x_axis, ecg_signals["ECG_Rate"], color="#FF5722", label="Rate", linewidth=1.5)
ax1.plot(
x_axis,
ecg_signals["ECG_Rate"],
color="#FF5722",
label="Rate",
linewidth=1.5,
)
rate_mean = ecg_signals["ECG_Rate"].mean()
ax1.axhline(y=rate_mean, label="Mean", linestyle="--", color="#FF9800")

ax1.legend(loc="upper right")

# Plot individual heart beats.
if sampling_rate is not None:
ax2.set_title("Individual Heart Beats")
# TODO: how can we directly insert the figure from ecg_segment?
# fig_beats = ecg_segment(
# ecg_signals["ECG_Clean"], peaks, sampling_rate, show="return"
# )
# ax2 = fig_beats.axes[0]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now ecg_segment(..., show="return") returns a Figure, but how can we insert that figure into the axis so that we don't have to re-plot it? 🤔 @danibene any ideas from the top of your head?

matplotlib is the worse

Copy link
Collaborator

@danibene danibene Aug 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not off the top of my head, if I understood your question correctly it seems like it's not possible to insert axes from another figure according to the answer here: https://stackoverflow.com/questions/66369315/how-to-assign-axes-from-one-figure-to-axes-from-another-figure

Copy link
Member Author

@DominiqueMakowski DominiqueMakowski Aug 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok thanks, that was helpful, went with the ax kwarg similar to other packages


heartbeats = ecg_segment(ecg_signals["ECG_Clean"], peaks, sampling_rate)
heartbeats = epochs_to_df(heartbeats)
# Recreate individual heartbeat figure
ax2.set_title("Individual Heart Beats")

heartbeats_pivoted = heartbeats.pivot(index="Time", columns="Label", values="Signal")
heartbeats = ecg_segment(ecg_signals["ECG_Clean"], peaks, sampling_rate)
heartbeats = epochs_to_df(heartbeats)

heartbeats_pivoted = heartbeats.pivot(
index="Time", columns="Label", values="Signal"
)

ax2.plot(heartbeats_pivoted)
ax2.plot(heartbeats_pivoted)

cmap = iter(
plt.cm.YlOrRd(
np.linspace(0, 1, num=int(heartbeats["Label"].nunique()))
) # pylint: disable=E1101
) # Aesthetics of heart beats
cmap = iter(
plt.cm.YlOrRd(
np.linspace(0, 1, num=int(heartbeats["Label"].nunique()))
) # pylint: disable=E1101
) # Aesthetics of heart beats

lines = []
for x, color in zip(heartbeats_pivoted, cmap):
(line,) = ax2.plot(heartbeats_pivoted[x], color=color)
lines.append(line)
lines = []
for x, color in zip(heartbeats_pivoted, cmap):
(line,) = ax2.plot(heartbeats_pivoted[x], color=color)
lines.append(line)

# Plot artifacts
if show_type in ["artifacts", "full"]:
Expand All @@ -185,5 +203,9 @@ def ecg_plot(ecg_signals, rpeaks=None, sampling_rate=None, show_type="default"):
_, rpeaks = ecg_peaks(ecg_signals["ECG_Clean"], sampling_rate=sampling_rate)

fig = signal_fixpeaks(
rpeaks, sampling_rate=sampling_rate, iterative=True, show=True, method="Kubios"
rpeaks,
sampling_rate=sampling_rate,
iterative=True,
show=True,
method="Kubios",
)
65 changes: 46 additions & 19 deletions neurokit2/ecg/ecg_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def ecg_segment(ecg_cleaned, rpeaks=None, sampling_rate=1000, show=False):
rpeaks : dict
The samples at which the R-peaks occur. Dict returned by ``ecg_peaks()``. Defaults to ``None``.
sampling_rate : int
The sampling frequency of ``ecg_signal`` (in Hz, i.e., samples/second). Defaults to 1000.
The sampling frequency of ``ecg_cleaned`` (in Hz, i.e., samples/second). Defaults to 1000.
show : bool
If ``True``, will return a plot of heartbeats. Defaults to ``False``.

Expand All @@ -41,14 +41,18 @@ def ecg_segment(ecg_cleaned, rpeaks=None, sampling_rate=1000, show=False):
ecg = nk.ecg_simulate(duration=15, sampling_rate=1000, heart_rate=80, noise = 0.05)
@savefig p_ecg_segment.png scale=100%
qrs_epochs = nk.ecg_segment(ecg, rpeaks=None, sampling_rate=1000, show=True)
@suppress
plt.close()

"""
# Sanitize inputs
if rpeaks is None:
_, rpeaks = ecg_peaks(ecg_cleaned, sampling_rate=sampling_rate, correct_artifacts=True)
_, rpeaks = ecg_peaks(
ecg_cleaned, sampling_rate=sampling_rate, correct_artifacts=True
)
rpeaks = rpeaks["ECG_R_Peaks"]

epochs_start, epochs_end = _ecg_segment_window(
epochs_start, epochs_end, average_hr = _ecg_segment_window(
rpeaks=rpeaks, sampling_rate=sampling_rate, desired_length=len(ecg_cleaned)
)
heartbeats = epochs_create(
Expand All @@ -64,34 +68,57 @@ def ecg_segment(ecg_cleaned, rpeaks=None, sampling_rate=1000, show=False):
after_last_index = heartbeats[last_heartbeat_key]["Index"] < len(ecg_cleaned)
heartbeats[last_heartbeat_key].loc[after_last_index, "Signal"] = np.nan

if show:
heartbeats_plot = epochs_to_df(heartbeats)
heartbeats_pivoted = heartbeats_plot.pivot(index="Time", columns="Label", values="Signal")
plt.plot(heartbeats_pivoted)
plt.xlabel("Time (s)")
plt.title("Individual Heart Beats")
cmap = iter(
plt.cm.YlOrRd(np.linspace(0, 1, num=int(heartbeats_plot["Label"].nunique())))
) # pylint: disable=no-member
lines = []
for x, color in zip(heartbeats_pivoted, cmap):
(line,) = plt.plot(heartbeats_pivoted[x], color=color)
lines.append(line)
if show is not False:
fig = _ecg_segment_plot(heartbeats, ytitle="ECG", heartrate=average_hr)
if show == "return":
return fig

return heartbeats


def _ecg_segment_window(heart_rate=None, rpeaks=None, sampling_rate=1000, desired_length=None):
# =============================================================================
# Internals
# =============================================================================
def _ecg_segment_plot(heartbeats, ytitle="ECG", heartrate=0):
df = epochs_to_df(heartbeats)
# Average heartbeat
mean_heartbeat = df.drop(["Index", "Label"], axis=1).groupby("Time").mean()
df_pivoted = df.pivot(index="Time", columns="Label", values="Signal")

# Prepare plot
fig = plt.figure()

plt.title(f"Individual Heart Beats (average heart rate: {heartrate:0.1f} bpm)")
plt.xlabel("Time (s)")
plt.ylabel(ytitle)

# Add Vertical line at 0
plt.axvline(x=0, color="grey", linestyle="--")

# Plot average heartbeat
plt.plot(mean_heartbeat.index, mean_heartbeat, color="red", linewidth=10)

# Plot all heartbeats
plt.plot(df_pivoted, color="grey", linewidth=2 / 3)

return fig


def _ecg_segment_window(
heart_rate=None, rpeaks=None, sampling_rate=1000, desired_length=None
):
# Extract heart rate
if heart_rate is not None:
heart_rate = np.mean(heart_rate)
if rpeaks is not None:
heart_rate = np.mean(
signal_rate(rpeaks, sampling_rate=sampling_rate, desired_length=desired_length)
signal_rate(
rpeaks, sampling_rate=sampling_rate, desired_length=desired_length
)
)

# Modulator
# Note: this is based on quick internal testing but could be improved
m = heart_rate / 60

# Window
Expand All @@ -104,4 +131,4 @@ def _ecg_segment_window(heart_rate=None, rpeaks=None, sampling_rate=1000, desire
epochs_start = epochs_start - c
epochs_end = epochs_end + c

return epochs_start, epochs_end
return epochs_start, epochs_end, heart_rate
2 changes: 2 additions & 0 deletions neurokit2/ppg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .ppg_findpeaks import ppg_findpeaks
from .ppg_intervalrelated import ppg_intervalrelated
from .ppg_methods import ppg_methods
from .ppg_peaks import ppg_peaks
from .ppg_plot import ppg_plot
from .ppg_process import ppg_process
from .ppg_simulate import ppg_simulate
Expand All @@ -16,6 +17,7 @@
"ppg_simulate",
"ppg_clean",
"ppg_findpeaks",
"ppg_peaks",
"ppg_rate",
"ppg_process",
"ppg_plot",
Expand Down
16 changes: 12 additions & 4 deletions neurokit2/ppg/ppg_findpeaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
from ..signal import signal_smooth


def ppg_findpeaks(ppg_cleaned, sampling_rate=1000, method="elgendi", show=False, **kwargs):
def ppg_findpeaks(
ppg_cleaned, sampling_rate=1000, method="elgendi", show=False, **kwargs
):
"""**Find systolic peaks in a photoplethysmogram (PPG) signal**

Low-level function used by :func:`ppg_peaks` to identify peaks in a PPG signal using a
different set of algorithms. Use the main function and see its documentation for details.

Parameters
----------
ppg_cleaned : Union[list, np.array, pd.Series]
Expand Down Expand Up @@ -71,7 +76,9 @@ def ppg_findpeaks(ppg_cleaned, sampling_rate=1000, method="elgendi", show=False,
elif method in ["msptd", "bishop2018", "bishop"]:
peaks, _ = _ppg_findpeaks_bishop(ppg_cleaned, show=show, **kwargs)
else:
raise ValueError("`method` not found. Must be one of the following: 'elgendi', 'bishop'.")
raise ValueError(
"`method` not found. Must be one of the following: 'elgendi', 'bishop'."
)

# Prepare output.
info = {"PPG_Peaks": peaks}
Expand Down Expand Up @@ -130,12 +137,13 @@ def _ppg_findpeaks_elgendi(

# Identify systolic peaks within waves (ignore waves that are too short).
num_waves = min(beg_waves.size, end_waves.size)
min_len = int(np.rint(peakwindow * sampling_rate)) # this is threshold 2 in the paper
min_len = int(
np.rint(peakwindow * sampling_rate)
) # this is threshold 2 in the paper
min_delay = int(np.rint(mindelay * sampling_rate))
peaks = [0]

for i in range(num_waves):

beg = beg_waves[i]
end = end_waves[i]
len_wave = end - beg
Expand Down
Loading
Loading