-
Notifications
You must be signed in to change notification settings - Fork 131
/
post_process.py
149 lines (132 loc) · 6.21 KB
/
post_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Post-processing utilities for calculating heart rate using FFT or peak detection.
The file also includes helper functions such as detrend, power2db, etc.
"""
import numpy as np
import scipy
import scipy.io
from scipy.signal import butter
from scipy.sparse import spdiags
from copy import deepcopy
def _next_power_of_2(x):
    """Return the smallest power of two that is >= x (returns 1 when x == 0)."""
    if x == 0:
        return 1
    # (x - 1).bit_length() is the exponent of the next power of two for x >= 1.
    return 1 << (x - 1).bit_length()
def _detrend(input_signal, lambda_value):
    """Detrend PPG signal with a smoothness-priors (second-difference) regulariser.

    The slow trend ``z`` is the solution of
    ``(I + lambda_value**2 * D.T @ D) z = x``, where ``D`` is the
    (N-2) x N second-order difference matrix. The returned signal is ``x - z``.
    This is algebraically the same as ``(I - inv(I + ...)) @ x``, but it solves
    the linear system instead of forming an explicit inverse.

    Args:
        input_signal (np.ndarray): signal to detrend; samples along axis 0.
        lambda_value (float): regularisation weight; larger values penalise
            curvature of the trend more strongly, giving a smoother trend.

    Returns:
        np.ndarray: detrended signal with the same shape as ``input_signal``.
    """
    signal_length = input_signal.shape[0]
    if signal_length < 3:
        # No second difference exists, so the trend equals the signal itself
        # and nothing remains after subtracting it.
        return np.zeros(input_signal.shape)
    ones = np.ones(signal_length)
    # Each row of D is [1, -2, 1], placed on diagonals 0, 1 and 2.
    D = spdiags(np.array([ones, -2 * ones, ones]), np.array([0, 1, 2]),
                signal_length - 2, signal_length).toarray()
    system = np.identity(signal_length) + (lambda_value ** 2) * np.dot(D.T, D)
    # Solving the system is cheaper and better conditioned than inverting it.
    trend = np.linalg.solve(system, input_signal)
    return input_signal - trend
def power2db(mag):
    """Express a power (ratio) in decibels, i.e. ``10 * log10(mag)``."""
    decibels = np.log10(mag) * 10
    return decibels
def _calculate_fft_hr(ppg_signal, fs=60, low_pass=0.75, high_pass=2.5):
    """Calculate heart rate based on PPG using Fast Fourier transform (FFT).

    The periodogram is computed with zero padding to the next power of two.
    The frequency with the largest power inside [low_pass, high_pass] Hz is
    returned in beats per minute.

    Args:
        ppg_signal (np.array): PPG signal.
        fs (int or float): sampling rate in Hz.
        low_pass (float): lower edge of the searched band, in Hz.
        high_pass (float): upper edge of the searched band, in Hz.

    Returns:
        float: heart rate in beats per minute.
    """
    batched = np.expand_dims(ppg_signal, 0)
    nfft = _next_power_of_2(batched.shape[1])
    freqs, power = scipy.signal.periodogram(batched, fs=fs, nfft=nfft, detrend=False)
    in_band = np.flatnonzero((freqs >= low_pass) & (freqs <= high_pass))
    # Index the flattened spectrum, matching how np.take addressed it.
    band_power = power.ravel()[in_band]
    peak_freq = freqs[in_band[np.argmax(band_power)]]
    return peak_freq * 60
def _calculate_peak_hr(ppg_signal, fs):
    """Calculate heart rate based on PPG using peak detection.

    The heart rate is derived from the mean spacing, in samples, between
    consecutive local maxima found by ``scipy.signal.find_peaks``.

    Args:
        ppg_signal (np.array): PPG signal.
        fs (int or float): sampling rate in Hz.

    Returns:
        float: heart rate in beats per minute. Returns NaN when fewer than two
        peaks are found, because no inter-beat interval can be measured.
    """
    ppg_peaks, _ = scipy.signal.find_peaks(ppg_signal)
    if len(ppg_peaks) < 2:
        # np.mean of an empty diff would warn ("Mean of empty slice") and
        # still produce NaN; return it explicitly instead.
        return np.nan
    mean_interval_s = np.mean(np.diff(ppg_peaks)) / fs
    return 60 / mean_interval_s
def _compute_macc(pred_signal, gt_signal):
"""Calculate maximum amplitude of cross correlation (MACC) by computing correlation at all time lags.
Args:
pred_ppg_signal(np.array): predicted PPG signal
label_ppg_signal(np.array): ground truth, label PPG signal
Returns:
MACC(float): Maximum Amplitude of Cross-Correlation
"""
pred = deepcopy(pred_signal)
gt = deepcopy(gt_signal)
pred = np.squeeze(pred)
gt = np.squeeze(gt)
min_len = np.min((len(pred), len(gt)))
pred = pred[:min_len]
gt = gt[:min_len]
lags = np.arange(0, len(pred)-1, 1)
tlcc_list = []
for lag in lags:
cross_corr = np.abs(np.corrcoef(
pred, np.roll(gt, lag))[0][1])
tlcc_list.append(cross_corr)
macc = max(tlcc_list)
return macc
def _calculate_SNR(pred_ppg_signal, hr_label, fs=30, low_pass=0.75, high_pass=2.5):
    """Calculate the SNR of a predicted PPG signal around the ground-truth HR.

    The periodogram of the prediction is zero-padded to the next power of two.
    "Signal" is the sum of squared periodogram values within +/-6 beats/min of
    the first and second harmonics of ``hr_label``. "Noise" is the sum of
    squared values in [low_pass, high_pass] Hz outside both harmonic windows.
    The harmonic windows themselves are not clipped to that band.

    Args:
        pred_ppg_signal (np.array): predicted PPG signal.
        hr_label (float): ground-truth heart rate in beats per minute.
        fs (int or float): sampling rate of the video.
        low_pass (float): lower edge of the noise band, in Hz.
        high_pass (float): upper edge of the noise band, in Hz.

    Returns:
        SNR (float): signal-to-noise ratio in dB, or 0 when the noise term is 0.
    """
    deviation = 6 / 60  # +/-6 beats/min tolerance, converted to Hz
    first_hz = hr_label / 60
    second_hz = 2 * first_hz

    batched = np.expand_dims(pred_ppg_signal, 0)
    nfft = _next_power_of_2(batched.shape[1])
    freqs, power = scipy.signal.periodogram(batched, fs=fs, nfft=nfft, detrend=False)
    power = np.squeeze(power)

    def _window(center):
        return (freqs >= (center - deviation)) & (freqs <= (center + deviation))

    near_first = _window(first_hz)
    near_second = _window(second_hz)
    rest = (freqs >= low_pass) & (freqs <= high_pass) & ~near_first & ~near_second

    energy_first = np.sum(power[near_first] ** 2)
    energy_second = np.sum(power[near_second] ** 2)
    energy_rest = np.sum(power[rest] ** 2)

    if energy_rest == 0:  # avoid the divide-by-zero runtime warning
        return 0
    return power2db((energy_first + energy_second) / energy_rest)
def calculate_metric_per_video(predictions, labels, fs=30, diff_flag=True, use_bandpass=True, hr_method='FFT'):
    """Calculate video-level HR, SNR and MACC for one prediction/label pair.

    Args:
        predictions: predicted signal for one video.
        labels: ground-truth signal for the same video.
        fs: sampling rate in Hz.
        diff_flag: if True, both inputs are first derivatives of the PPG and
            are integrated with a cumulative sum before detrending.
        use_bandpass: if True, apply a zero-phase (filtfilt) first-order
            Butterworth band-pass over 0.75-2.5 Hz (45-150 beats per min).
        hr_method: 'FFT' for spectral-peak HR or 'Peak' for peak-interval HR.

    Returns:
        tuple: (hr_label, hr_pred, SNR, macc).

    Raises:
        ValueError: if ``hr_method`` is neither 'FFT' nor 'Peak'.
    """
    if diff_flag:
        # Integrate derivative signals back to PPG before detrending.
        predictions, labels = np.cumsum(predictions), np.cumsum(labels)
    predictions, labels = (_detrend(sig, 100) for sig in (predictions, labels))

    if use_bandpass:
        # Cutoffs are normalised to the Nyquist frequency (fs / 2).
        b, a = butter(1, [0.75 / fs * 2, 2.5 / fs * 2], btype='bandpass')
        predictions, labels = (
            scipy.signal.filtfilt(b, a, np.double(sig)) for sig in (predictions, labels)
        )

    macc = _compute_macc(predictions, labels)

    if hr_method == 'FFT':
        estimate_hr = _calculate_fft_hr
    elif hr_method == 'Peak':
        estimate_hr = _calculate_peak_hr
    else:
        raise ValueError('Please use FFT or Peak to calculate your HR.')
    hr_pred = estimate_hr(predictions, fs=fs)
    hr_label = estimate_hr(labels, fs=fs)

    SNR = _calculate_SNR(predictions, hr_label, fs=fs)
    return hr_label, hr_pred, SNR, macc