diff --git a/eegnb/analysis/utils.py b/eegnb/analysis/utils.py index 119500d8..4a0fa53d 100644 --- a/eegnb/analysis/utils.py +++ b/eegnb/analysis/utils.py @@ -4,10 +4,10 @@ import logging from collections import OrderedDict from glob import glob -from typing import Union, List, Dict -from collections import Iterable +from typing import Union, List#, Dict +# from collections import Iterable from time import sleep, time -from numpy.core.fromnumeric import std +# from numpy.core.fromnumeric import std import keyboard import os @@ -277,14 +277,16 @@ def plot_conditions( for ch in range(channel_count): for cond, color in zip(conditions.values(), palette): - sns.tsplot( - X[y.isin(cond), ch], - time=times, + sns.lineplot( + data=pd.DataFrame(X[y.isin(cond), ch].T, index=times), + x=times, + y=ch, color=color, n_boot=n_boot, - ci=ci, ax=axes[ch], + errorbar=('ci',ci) ) + axes[ch].set(xlabel='Time (s)', ylabel='Amplitude (uV)', title=epochs.ch_names[channel_order[ch]]) if diff_waveform: diff = np.nanmean(X[y == diff_waveform[1], ch], axis=0) - np.nanmean( @@ -298,11 +300,6 @@ def plot_conditions( x=0, ymin=ylim[0], ymax=ylim[1], color="k", lw=1, label="_nolegend_" ) - axes[0].set_xlabel("Time (s)") - axes[0].set_ylabel("Amplitude (uV)") - axes[-1].set_xlabel("Time (s)") - axes[1].set_ylabel("Amplitude (uV)") - if diff_waveform: legend = ["{} - {}".format(diff_waveform[1], diff_waveform[0])] + list( conditions.keys() diff --git a/examples/visual_cueing/01r__cueing_singlesub_analysis.py b/examples/visual_cueing/01r__cueing_singlesub_analysis.py index 9ba11022..8d367165 100644 --- a/examples/visual_cueing/01r__cueing_singlesub_analysis.py +++ b/examples/visual_cueing/01r__cueing_singlesub_analysis.py @@ -14,7 +14,7 @@ # # Some standard pythonic imports -import os,sys,glob,numpy as np,pandas as pd +import os,numpy as np#,sys,glob,pandas as pd from collections import OrderedDict import warnings warnings.filterwarnings('ignore') @@ -22,7 +22,7 @@ import matplotlib.patches as patches # MNE functions -from mne import Epochs,find_events, concatenate_raws +from mne import Epochs,find_events#, concatenate_raws from mne.time_frequency import tfr_morlet # EEG-Notebooks functions @@ -73,7 +73,7 @@ # One way to analyze the SSVEP is to plot the power spectral density, or PSD. SSVEPs should appear as peaks in power for certain frequencies. We expect clear peaks in the spectral domain at the stimulation frequencies of 30 and 20 Hz. # -raw.plot_psd(); +raw.compute_psd().plot(); # Should see the electrical noise at 60 Hz, and maybe a peak at the red and blue channels between 7-14 Hz (Alpha) @@ -84,8 +84,8 @@ # Most ERP components are composed of lower frequency fluctuations in the EEG signal. Thus, we can filter out all frequencies between 1 and 30 hz in order to increase our ability to detect them. # -raw.filter(1,30, method='iir') -raw.plot_psd(fmin=1, fmax=30); +raw.filter(1,30, method='iir'); +raw.compute_psd(fmin=1, fmax=30).plot(); ################################################################################################### # Epoching diff --git a/examples/visual_ssvep/01r__ssvep_viz.py b/examples/visual_ssvep/01r__ssvep_viz.py index d812681c..974e3e6a 100644 --- a/examples/visual_ssvep/01r__ssvep_viz.py +++ b/examples/visual_ssvep/01r__ssvep_viz.py @@ -26,7 +26,7 @@ # MNE functions from mne import Epochs,find_events -from mne.time_frequency import psd_welch,tfr_morlet +from mne.time_frequency import tfr_morlet # EEG-Notebooks functions from eegnb.analysis.utils import load_data,plot_conditions @@ -88,8 +88,14 @@ # Next, we can compare the PSD of epochs specifically during 20hz and 30hz stimulus presentation f, axs = plt.subplots(2, 1, figsize=(10, 10)) -psd1, freq1 = psd_welch(epochs['30 Hz'], n_fft=1028, n_per_seg=256 * 3, picks='all') -psd2, freq2 = psd_welch(epochs['20 Hz'], n_fft=1028, n_per_seg=256 * 3, picks='all') + +welch_params=dict(method='welch', + n_fft=1028, + n_per_seg=256 * 3, + picks='all') + +psd1, freq1 = epochs['30 Hz'].compute_psd(**welch_params).get_data(return_freqs=True) +psd2, freq2 = epochs['20 Hz'].compute_psd(**welch_params).get_data(return_freqs=True) psd1 = 10 * np.log10(psd1) psd2 = 10 * np.log10(psd2) diff --git a/requirements.txt b/requirements.txt index 568058b3..770bd491 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ scikit-learn>=0.23.2 pandas>=1.1.4 numpy>=1.19.4 mne>=0.20.8 -seaborn==0.9.0 +seaborn>=0.9.0 pyriemann>=0.2.7 jupyter muselsl>=2.0.2