Skip to content

Commit

Permalink
Fix plot conditions (#257)
Browse files Browse the repository at this point in the history
* fixed plot_conditions functoni issues

* small change to viz n170 plotting example

* small change to viz p300 plotting example

* fixed to plotting issue

* modify plot command

* update example files

* fix condition label bug

* fix: set layout engine to fix colorbar error

---------

Co-authored-by: Ore O <[email protected]>
  • Loading branch information
JohnGriffiths and oreHGA authored Mar 7, 2024
1 parent 7cf7774 commit d00c2dd
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 48 deletions.
56 changes: 31 additions & 25 deletions eegnb/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from mne.channels import make_standard_montage
from mne.filter import create_filter
from matplotlib import pyplot as plt
from matplotlib import lines as mlines
from scipy import stats
from scipy.signal import lfilter, lfilter_zi

Expand Down Expand Up @@ -251,16 +252,16 @@ def plot_conditions(
channel_order = np.array(channel_order)
else:
channel_order = np.array(range(channel_count))
channel_names = np.array(epochs.ch_names)[channel_order]

if isinstance(conditions, dict):
conditions = OrderedDict(conditions)

if palette is None:
palette = sns.color_palette("hls", len(conditions) + 1)

X = epochs.get_data() * 1e6

X = X[:, channel_order]
dfX = epochs.to_data_frame()
dfX[channel_names] *= 1e6

times = epochs.times
y = pd.Series(epochs.events[:, -1])
Expand All @@ -275,12 +276,13 @@ def plot_conditions(
plot_axes.append(axes[axis_x, axis_y])
axes = plot_axes

for ch in range(channel_count):
for cond, color in zip(conditions.values(), palette):
for ch,ch_name in enumerate(channel_names):
for cond,cond_name, color in zip(conditions.values(),conditions.keys(), palette):
dfXc = dfX[dfX.condition.isin(conditions[cond_name])]
sns.lineplot(
data=pd.DataFrame(X[y.isin(cond), ch].T, index=times),
x=times,
y=ch,
data=dfXc,
x="time",
y=ch_name,
color=color,
n_boot=n_boot,
ax=axes[ch],
Expand All @@ -289,26 +291,30 @@ def plot_conditions(
axes[ch].set(xlabel='Time (s)', ylabel='Amplitude (uV)', title=epochs.ch_names[channel_order[ch]])

if diff_waveform:
diff = np.nanmean(X[y == diff_waveform[1], ch], axis=0) - np.nanmean(
X[y == diff_waveform[0], ch], axis=0
)
dfXc1 = dfX[dfX.condition.isin(conditions[diff_waveform[1]])]
dfXc2 = dfX[dfX.condition.isin(conditions[diff_waveform[0]])]
dfXc1_mn = dfXc1.set_index(['time', 'epoch'])[ch_name].unstack('epoch').mean(axis=1)
dfXc2_mn = dfXc2.set_index(['time', 'epoch'])[ch_name].unstack('epoch').mean(axis=1)
diff = (dfXc1_mn - dfXc2_mn).values
axes[ch].plot(times, diff, color="k", lw=1)

axes[ch].set_title(epochs.ch_names[channel_order[ch]])
axes[ch].set_title(ch_name)
axes[ch].set_ylim(ylim)
axes[ch].axvline(
x=0, ymin=ylim[0], ymax=ylim[1], color="k", lw=1, label="_nolegend_"
)

legs = []
for cond,cond_name,color in zip(conditions.values(),conditions.keys(), palette):
lh = mlines.Line2D([], [], color=color, marker='', ls='-', label=cond_name)
legs.append(lh)
if diff_waveform:
legend = ["{} - {}".format(diff_waveform[1], diff_waveform[0])] + list(
conditions.keys()
)
else:
legend = conditions.keys()
axes[-1].legend(
legend, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0
)
lh = mlines.Line2D([], [], color="k", marker='', ls='-',
label = "{} - {}".format(diff_waveform[1], diff_waveform[0]))
legs.append(lh)

axes[-1].legend(handles=legs,
bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0)
sns.despine()
plt.tight_layout()

Expand All @@ -328,6 +334,9 @@ def plot_highlight_regions(
Args:
x (array_like): x coordinates
y (array_like): y values of same shape as `x`
Keyword Args:
hue (array_like): values to be plotted as hue based on `hue_thresh`.
Must be of the same shape as `x` and `y`.
Keyword Args:
hue (array_like): values to be plotted as hue based on `hue_thresh`.
Must be of the same shape as `x` and `y`.
Expand Down Expand Up @@ -453,12 +462,9 @@ def check_report(eeg: EEG, n_times: int=60, pause_time=5, thres_std_low=None, th
Usage:
------
from eegnb.devices.eeg import EEG
from eegnb.analysis.utils import check_report
eeg = EEG(device='museS')
check_report(eeg)
standard deviation for a quality recording.
The thres_std_low & thres_std_high values are the
lower and upper bound of accepted
thresholds = {
standard deviation for a quality recording.
thresholds = {
Expand Down
16 changes: 10 additions & 6 deletions examples/visual_cueing/01r__cueing_singlesub_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,15 @@
print('sample drop %: ', (1 - len(epochs.events)/len(events)) * 100)

conditions = OrderedDict()
conditions['LeftCue'] = [1]
conditions['RightCue'] = [2]
#conditions['LeftCue'] = [1]
#conditions['RightCue'] = [2]
conditions['LeftCue'] = ['LeftCue']
conditions['RightCue'] = ['RightCue']
diffwave = ('LeftCue', 'RightCue')

fig, ax = plot_conditions(epochs, conditions=conditions,
ci=97.5, n_boot=1000, title='',
diff_waveform=(1, 2), ylim=(-20,20))
diff_waveform=diffwave, ylim=(-20,20))

###################################################################################################
# Spectrogram
Expand Down Expand Up @@ -242,10 +245,11 @@
print('sample drop %: ', (1 - len(epochs.events)/len(events)) * 100)

conditions = OrderedDict()
conditions['ValidTarget'] = [21,22]
conditions['InvalidTarget'] = [11,12]
conditions['ValidTarget'] = ['ValidTarget_Left', 'ValidTarget_Right']
conditions['InvalidTarget'] = ['InvalidTarget_Left', 'InvalidTarget_Right']
diffwave = ('ValidTarget', 'InvalidTarget')

fig, ax = plot_conditions(epochs, conditions=conditions,
ci=97.5, n_boot=1000, title='',
diff_waveform=(1, 2), ylim=(-20,20))
diff_waveform=diffwave, ylim=(-20,20))

Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from mne.time_frequency import tfr_morlet

# EEG-Noteooks functions
from eegnb.analysis.utils import load_data,plot_conditions
from eegnb.analysis.utils import load_data
from eegnb.datasets import fetch_dataset

# sphinx_gallery_thumbnail_number = 1
Expand Down
19 changes: 12 additions & 7 deletions examples/visual_n170/01r__n170_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

# Some standard pythonic imports
import os
from matplotlib import pyplot as plt
from collections import OrderedDict
import warnings
warnings.filterwarnings('ignore')
Expand Down Expand Up @@ -96,16 +97,20 @@
# ----------------------------

conditions = OrderedDict()
conditions['House'] = [1]
conditions['Face'] = [2]
#conditions['House'] = [1]
#conditions['Face'] = [2]
conditions['House'] = ['House']
conditions['Face'] = ['Face']
diffwav = ('Face', 'House')

fig, ax = plot_conditions(epochs, conditions=conditions,
ci=97.5, n_boot=1000, title='',
diff_waveform=None, #(1, 2))
channel_order=[1,0,2,3]) # reordering of epochs.ch_names according to [[0,2],[1,3]] of subplot axes
diff_waveform=diffwav,
channel_order=[1,0,2,3])
# reordering of epochs.ch_names according to [[0,2],[1,3]] of subplot axes

# Manually adjust the ylims
for i in [0,2]: ax[i].set_ylim([-0.5,0.5])
for i in [1,3]: ax[i].set_ylim([-1.5,2.5])

for i in [0,2]: ax[i].set_ylim([-0.5e6,0.5e6])
for i in [1,3]: ax[i].set_ylim([-1.5e6,2.5e6])
plt.tight_layout()

23 changes: 14 additions & 9 deletions examples/visual_p300/01r__p300_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

# Some standard pythonic imports
import os
from matplotlib import pyplot as plt
from collections import OrderedDict
import warnings
warnings.filterwarnings('ignore')
Expand Down Expand Up @@ -80,26 +81,30 @@

# Create an array containing the timestamps and type of each stimulus (i.e. face or house)
events = find_events(raw)
event_id = {'Non-Target': 1, 'Target': 2}
event_id = {'non-target': 1, 'target': 2}
epochs = Epochs(raw, events=events, event_id=event_id,
tmin=-0.1, tmax=0.8, baseline=None,
reject={'eeg': 100e-6}, preload=True,
tmin=-0.1, tmax=0.8, baseline=None, reject={'eeg': 100e-6}, preload=True,
verbose=False, picks=[0,1,2,3])

print('sample drop %: ', (1 - len(epochs.events)/len(events)) * 100)

epochs


###################################################################################################
# Epoch average
# ----------------------------

conditions = OrderedDict()
conditions['Non-target'] = [1]
conditions['Target'] = [2]
conditions['non-target'] = ['non-target']
conditions['target'] = ['target']
diffwav = ["non-target", "target"]

fig, ax = plot_conditions(epochs, conditions=conditions,
ci=97.5, n_boot=1000, title='',
diff_waveform=(1, 2))
channel_order=[1,0,2,3],ylim=[-2E6,2.5E6],
diff_waveform = diffwav)

# Manually adjust the ylims
for i in [0,2]: ax[i].set_ylim([-0.5e6,0.5e6])
for i in [1,3]: ax[i].set_ylim([-1.5e6,2.5e6])

plt.tight_layout()

2 changes: 2 additions & 0 deletions examples/visual_ssvep/01r__ssvep_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@
tfr.plot(picks=[4], baseline=(-0.5, -0.1), mode='logratio',
title='POz - 20 Hz stim');

# Set Layout engine to tight to fix error with using colorbar layout error
plt.figure().set_layout_engine('tight');
plt.tight_layout()

# Once again we can see clear SSVEPs at 30hz and 20hz

0 comments on commit d00c2dd

Please sign in to comment.