-
Notifications
You must be signed in to change notification settings - Fork 0
/
tools_signal.py
executable file
·187 lines (161 loc) · 6.74 KB
/
tools_signal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""
-----------------------------------------------------------------------
Harmoni: a Novel Method for Eliminating Spurious Neuronal Interactions due to the Harmonic Components in Neuronal Data
Mina Jamshidi Idaji, Juanli Zhang, Tilman Stephani, Guido Nolte, Klaus-Robert Mueller, Arno Villringer, Vadim V. Nikulin
https://doi.org/10.1101/2021.10.06.463319
-----------------------------------------------------------------------
(c) Mina Jamshidi ([email protected]) @ Neurology Dept, MPI CBS, 2021
https://github.com/minajamshidi
Please cite the above paper if you use this code in your research.
License: MIT License
-----------------------------------------------------------------------
"""
import numpy as np
import matplotlib.pyplot as plt
# -------------------------------- -------------------------------- --------------------------------
# general
# -------------------------------- -------------------------------- --------------------------------
def dB(data, coeff=10):
    """
    Convert values to decibels: coeff * log10(data).

    :param data: scalar or ndarray of (positive) values
    :param coeff: multiplier of the log; 10 is the usual choice for power-like
                  quantities, 20 for amplitude-like ones. default=10
    :return: data expressed in dB
    """
    log_values = np.log10(data)
    return coeff * log_values
def dBinv(data, coeff=10):
    """
    Inverse of dB: convert decibel values back to linear scale, 10 ** (data / coeff).

    :param data: scalar or ndarray of values in dB
    :param coeff: the multiplier that was used in dB() (10 or 20 typically). default=10
    :return: data in linear units
    """
    # use the same coefficient as the forward conversion so dBinv(dB(x, c), c) == x
    return 10 ** (data / coeff)
def zero_pad_to_pow2(x, axis=1):
    """
    Zero-pad signals along ``axis`` to the next power-of-two length (for fast FFT computation).

    :param x: ndarray, e.g. [n_signals x n_samples]; any number of dimensions is accepted
    :param axis: axis along which to pad (negative values count from the end), default=1
    :return: (y, n_zp) -- the zero-padded array (same dtype as x) and the number of
             zeros appended at the end of ``axis``
    """
    x = np.asarray(x)
    n_samp = x.shape[axis]
    n_zp = int(2 ** np.ceil(np.log2(n_samp))) - n_samp
    # pad only at the end of the requested axis; every other axis is left untouched
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (0, n_zp)
    y = np.pad(x, pad_width, mode='constant')
    return y, n_zp
# -------------------------------- -------------------------------- --------------------------------
# frequency domain and complex signals
# -------------------------------- -------------------------------- --------------------------------
def fft_(x, fs, axis=1, n_fft=None):
    """
    One-sided FFT of the first signal in x.

    :param x: 1-D signal, or ndarray of signals with time along ``axis``; complex input
              is reduced to its real part
    :param fs: sampling frequency, in Hz
    :param axis: time axis of x (ignored for 1-D input), default=1
    :param n_fft: FFT length; defaults to the next power of two >= the number of samples
    :return: (freq, x_f) -- the n_fft // 2 frequency bins from 0 up to (excluding) fs / 2,
             and the corresponding complex FFT coefficients of the FIRST signal only
    """
    from scipy.fftpack import fft
    x = np.asarray(x)
    if np.iscomplexobj(x):
        x = np.real(x)
    if x.ndim == 1:
        x = x[np.newaxis, :]
    else:
        # put the time axis last so that every signal is a row
        x = np.moveaxis(x, axis, -1)
    n_sample = x.shape[-1]
    n_fft = int(2 ** np.ceil(np.log2(n_sample))) if n_fft is None else n_fft
    x_f = fft(x, n_fft, axis=-1)
    n_fft2 = n_fft // 2
    # build the grid from the bin count rather than np.arange(0, fs/2, step): a float
    # step can make arange return one element more than x_f has
    freq = np.arange(n_fft2) * (fs / n_fft)
    # only the first signal is returned (callers such as plot_fft rely on a 1-D spectrum)
    x_f = x_f.reshape(-1, n_fft)[0, :n_fft2]
    return freq, x_f
def plot_fft(x, fs, axis=1, n_fft=None):
    """
    Plot the FFT magnitude of a signal, as computed by fft_ (first signal only).

    :param x: signal(s); see fft_ for the accepted layouts
    :param fs: sampling frequency, in Hz
    :param axis: time axis of x, forwarded to fft_
    :param n_fft: FFT length, forwarded to fft_
    :return: None -- draws on the current matplotlib axes
    """
    spectrum_freqs, spectrum = fft_(x, fs, axis=axis, n_fft=n_fft)
    magnitude = np.abs(spectrum).ravel()
    plt.plot(spectrum_freqs, magnitude)
    plt.title('Magnitude of FFT')
    plt.grid()
def hilbert_(x, axis=1):
    """
    Analytic signal of x, computed with scipy's Hilbert transform on a copy that is
    zero-padded to the next power-of-two length (fast FFT) and then cropped back.

    :param x: array_like, real signal data; complex input is returned unchanged
    :param axis: axis along which the transform is computed, default=1
    :return: x_h : analytic signal of x (1-D input comes back as a 2-D array)
    """
    if np.iscomplexobj(x):
        # already complex: treated as an analytic signal and passed through
        return x
    from scipy.signal import hilbert
    if x.ndim == 1:
        # add a singleton dimension so that ``axis`` indexes the time axis
        x = np.expand_dims(x, 0 if axis == 1 else 1)
    padded, n_pad = zero_pad_to_pow2(x, axis=axis)
    analytic = hilbert(np.real(padded), axis=axis)
    if n_pad > 0:
        # remove the samples that correspond to the zero padding
        if axis == 1:
            analytic = analytic[:, :-n_pad]
        else:
            analytic = analytic[:-n_pad, :]
    return analytic
def psd(data, fs, f_max=None, overlap_perc=0.5, freq_res=0.5, axis=1, plot=True, dB1=True,
        fig='new', interactivePlot=True, clab=None):
    """
    Estimate (and optionally plot) the power spectral density with Welch's method.

    :param data: ndarray [n_chan x n_samples] (or 1-D [n_samples]); can be multi-channel.
                 Complex input is reduced to its real part.
    :param fs: sampling frequency, in Hz
    :param f_max: if given, frequencies above f_max are dropped from the output and the plot
    :param overlap_perc: overlap between consecutive Welch segments, as a fraction (0-1)
                         of the segment length
    :param freq_res: requested frequency resolution, in Hz; the segment length is
                     fs / freq_res rounded up to a power of two
    :param axis: time axis of data (forced to 0 for 1-D data), default=1
    :param plot: if True, plot the spectrum
    :param dB1: if True, plot in dB (10*log10); the returned pxx is always linear
    :param fig: 'new' to create a new figure, or a (figure, axes) pair to draw into
    :param interactivePlot: if True, picking a line prints its channel label / index
    :param clab: optional sequence of channel labels used by the interactive picker
    :return: (f, pxx) if plot is False, otherwise (f, pxx, (fig, ax, lines))
    """
    from scipy.signal import welch
    if np.iscomplexobj(data):
        data = np.real(data)
    if data.ndim == 1:
        axis = 0
    # welch expects integer segment / overlap sizes
    nfft = int(2 ** np.ceil(np.log2(fs / freq_res)))
    noverlap = int(np.floor(overlap_perc * nfft))
    f, pxx = welch(data, fs=fs, nfft=nfft, nperseg=nfft, noverlap=noverlap, axis=axis)
    if f_max is not None:
        keep = f <= f_max
        # welch places the frequency axis at ``axis``; compress handles negative axes too
        pxx = np.compress(keep, pxx, axis=axis)
        f = f[keep]
    if plot:
        if fig == 'new':
            fig = plt.figure()
            ax = plt.subplot(111)
        else:
            fig, ax = fig[0], fig[1]
        y = dB(pxx.T) if dB1 else pxx.T
        lines = ax.plot(f, y, lw=1, picker=1)
        if interactivePlot:
            def onpick(event):
                # identify the channel by the line's position in this call's line list;
                # artists from other plots on the same figure are ignored
                try:
                    n_line = lines.index(event.artist)
                except ValueError:
                    return
                if clab is not None:
                    print(clab[n_line])
                else:
                    print('channel ' + str(n_line))
            fig.canvas.mpl_connect('pick_event', onpick)
        # label the axes we actually drew into (not whatever plt considers current)
        ax.set_ylabel('PSD (dB)' if dB1 else 'PSD')
        ax.set_xlabel('Frequency (Hz)')
        ax.grid(True, ls='dotted')
        return f, pxx, (fig, ax, lines)
    return f, pxx
# -------------------------------- -------------------------------- --------------------------------
# filtering and filters
# -------------------------------- -------------------------------- --------------------------------
def morlet_filter(data, sfreq, freq_min, freq_max, freq_res=0.5, n_jobs=1, n_cycles='auto'):
    """
    Morlet-wavelet filtering of multi-channel data on a linearly spaced frequency grid.

    :param data: np.ndarray . [channel x time]
    :param sfreq: sampling frequency
    :param freq_min: lowest frequency of the grid, in Hz
    :param freq_max: highest frequency of the grid, in Hz (included)
    :param freq_res: frequency resolution (spacing of the grid), in Hz
    :param n_jobs: number of jobs forwarded to mne's tfr_array_morlet
    :param n_cycles: number of wavelet cycles per frequency; 'auto' uses freqs / 2
    :return: (tfr, freqs) -- complex TF representation of data and the frequency grid.
             NOTE(review): tfr is the first epoch of tfr_array_morlet's output, presumably
             [channel x freq x time] -- confirm against the installed mne version.
    """
    from mne.time_frequency import tfr_array_morlet
    nchan, nsample = data.shape
    # wrap as a single epoch: [1 x channel x time]
    data = np.reshape(data, (1, nchan, nsample))
    # round away float noise before truncating: e.g. (8.6 - 8) / 0.2 evaluates to
    # 2.999..., which a bare int() would truncate to 2 and silently drop a bin
    n_steps = (freq_max - freq_min) / freq_res
    freq_n = int(np.round(n_steps, 9)) + 1
    freqs = np.linspace(freq_min, freq_max, num=freq_n, endpoint=True)
    # isinstance guard: comparing an ndarray of cycles to a string is not a usable bool
    if isinstance(n_cycles, str) and n_cycles == 'auto':
        n_cycles = freqs / 2.
    data_tfr = tfr_array_morlet(data, sfreq, freqs, n_cycles=n_cycles, zero_mean=True,
                                use_fft=True, decim=1, output='complex', n_jobs=n_jobs, verbose=None)
    return data_tfr[0, :, :], freqs
def filtfilt_mirror(b, a, data, axis=-1):
    """
    Zero-phase (forward-backward) filtering with sign-inverted mirror padding.

    Along ``axis`` the data is extended to [-flip(x), x, -flip(x)], filtered with
    scipy's filtfilt, and only the middle (original-length) segment is returned --
    presumably to keep filtfilt's edge transients out of the returned samples.

    :param b: numerator coefficients of the filter
    :param a: denominator coefficients of the filter
    :param data: ndarray to be filtered
    :param axis: axis along which to filter (negative values count from the end), default=-1
    :return: filtered data, same shape as data
    """
    from scipy.signal import filtfilt
    data = np.asarray(data)
    # np.flip validates the axis (raises for out-of-range values) before we normalise it
    data_flip_neg = -np.flip(data, axis=axis)
    axis = axis % data.ndim
    n_sample = data.shape[axis]
    data_mirror = np.concatenate((data_flip_neg, data, data_flip_neg), axis=axis)
    data_filt = filtfilt(b, a, data_mirror, axis=axis)
    # keep only the middle segment, i.e. the samples that correspond to the input
    return np.take(data_filt, np.arange(n_sample, 2 * n_sample), axis=axis)