Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH, WIP: Spectrogram mode, minimal version #272

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
149 changes: 134 additions & 15 deletions mne_qt_browser/_pg_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from mne import channel_indices_by_type
from mne.annotations import _sync_onset
from mne.io.pick import _DATA_CH_TYPES_ORDER_DEFAULT, _DATA_CH_TYPES_SPLIT
from mne.time_frequency import tfr_array_morlet
from mne.utils import _check_option, _to_rgb, get_config, logger, sizeof_fmt, warn
from mne.viz import plot_sensors
from mne.viz._figure import BrowserBase
Expand All @@ -48,6 +49,7 @@
AxisItem,
FillBetweenItem,
GraphicsView,
ImageItem,
InfiniteLine,
InfLineLabel,
LinearRegionItem,
Expand Down Expand Up @@ -295,18 +297,15 @@ def func(self, *args, **kwargs):
return func


class DataTrace(PlotCurveItem):
"""Graphics-Object for single data trace."""
class BaseDataTrace:
"""Base graphics-object for single data trace."""

def __init__(self, main, ch_idx, child_idx=None, parent_trace=None):
super().__init__()
self.weakmain = weakref.ref(main)
self.mne = main.mne
del main

# Set clickable with small area around trace to make clicking easier.
self.setClickable(True, 12)

# Set default z-value to 1 to be before other items in scene
self.setZValue(1)

Expand Down Expand Up @@ -608,6 +607,80 @@ def get_ydata(self):
return self.yData + self.ypos


class DataTrace(BaseDataTrace, PlotCurveItem):
"""Graphics-Object for single line data trace."""

def __init__(self, main, ch_idx, child_idx=None, parent_trace=None):
super().__init__(main, ch_idx, child_idx=child_idx, parent_trace=parent_trace)

# Set clickable with small area around trace to make clicking easier.
self.setClickable(True, 12)


class ImageTrace(BaseDataTrace, ImageItem):
"""Graphics-Object for single line data trace."""

def __init__(self, main, ch_idx, child_idx=None, parent_trace=None):
super().__init__(main, ch_idx, child_idx=child_idx, parent_trace=parent_trace)
self.setColorMap(self.mne.cmap)

@propagate_to_children
def update_color(self):
"""Update the color of the trace."""
pass

@propagate_to_children
def update_scale(self): # noqa: D102
transform = QTransform()
tmin, tmax = self.mne.times[0], self.mne.times[-1]
transform.scale((tmax - tmin) / self.mne.times.size, 0.9 / self.mne.freqs.size)
self.setTransform(transform)

self.setLevels([0, self.mne.vmax / self.mne.scale_factor])

@propagate_to_children
def update_data(self):
"""Update data (fetch data from self.mne according to self.ch_idx)."""
if self.mne.data_precomputed:
data = self.mne.data[self.order_idx]
else:
data = self.mne.data[self.range_idx]
times = self.mne.times

# Get decim-specific time if enabled
if self.mne.decim != 1:
times = times[:: self.mne.decim_data[self.range_idx]]
data = data[..., :: self.mne.decim_data[self.range_idx]]

# For multiple color traces with epochs
# replace values from other colors with NaN.
if self.mne.is_epochs:
data = np.copy(data)
check_color = self.mne.epoch_color_ref[self.ch_idx, self.mne.epoch_idx]
bool_ixs = np.invert(np.equal(self.color, check_color).all(axis=1))
starts = self.mne.boundary_times[self.mne.epoch_idx][bool_ixs]
stops = self.mne.boundary_times[self.mne.epoch_idx + 1][bool_ixs]

for start, stop in zip(starts, stops):
data[np.logical_and(start <= times, times <= stop)] = np.nan

assert times.shape[-1] == data.shape[-1]

tfr_data = tfr_array_morlet(
data[None, None],
self.mne.info["sfreq"],
freqs=self.mne.freqs,
n_cycles=self.mne.n_cycles,
output="power",
)[0][0]
# tfr_data = rescale(tfr_data, times, (None, None), mode='zlogratio')
self.setImage(
tfr_data[::-1].T, levels=[0, self.mne.vmax / self.mne.scale_factor]
)

self.setPos(times[0], self.range_idx + self.mne.ch_start + 0.5)


class TimeAxis(AxisItem):
"""The X-Axis displaying the time."""

Expand Down Expand Up @@ -3364,6 +3437,12 @@ def __init__(self, **kwargs):
self.mne.butterfly_type_order = [
tp for tp in DATA_CH_TYPES_ORDER if tp in self.mne.ch_types
]
# Spectogram
self.mne.spectrogram = False
self.mne.freqs = np.arange(5, np.min([250, self.mne.info["sfreq"] / 4]), 5)
self.mne.n_cycles = self.mne.freqs / 2
self.mne.cmap = "CET-L18"
self.mne.vmax = 0.2
if self.mne.is_epochs:
# Stores parameters for epochs
self.mne.epoch_dur = np.diff(self.mne.boundary_times[:2])[0]
Expand Down Expand Up @@ -3869,6 +3948,11 @@ def __init__(self, **kwargs):
"slot": [self._toggle_events],
"description": ["Toggle Events visible"],
},
"f": {
"qt_key": Qt.Key_F,
"slot": [self._toggle_spectrogram],
"description": ["Toggle frequency spectrogram"],
},
"h": {
"qt_key": Qt.Key_H,
"slot": [self._toggle_epoch_histogram],
Expand Down Expand Up @@ -4161,6 +4245,9 @@ def change_nchan(self, checked=False, *, step):
self.mne.ax_vscroll.update_nchan()
self.mne.plt.setYRange(ymin, ymax, padding=0)

if self.mne.spectrogram:
self._set_spectrogram(self.mne.spectrogram)

if self.mne.fig_settings is not None:
self.mne.fig_settings._update_sensitivity_spinbox_values()

Expand Down Expand Up @@ -4577,7 +4664,7 @@ def _update_data(self):
# Apply clipping
if self.mne.clipping == "clamp":
self.mne.data = np.clip(self.mne.data, -0.5, 0.5)
elif self.mne.clipping is not None:
elif self.mne.clipping is not None and not self.mne.spectrogram:
self.mne.data = self.mne.data.copy()
self.mne.data[
abs(self.mne.data * self.mne.scale_factor) > self.mne.clipping
Expand Down Expand Up @@ -4878,23 +4965,51 @@ def _set_butterfly(self, butterfly):
padding=0,
)

if self.mne.fig_selection is not None:
# Update Selection-Dialog
self.mne.fig_selection._style_butterfly()
# update ypos and color for butterfly-mode
for trace in self.mne.traces:
trace.update_color()
trace.update_ypos()

self._draw_traces()

def _set_spectrogram(self, spectrogram):
self.mne.spectrogram = spectrogram
self._update_data()

# set yscale in case in butterfly
self.mne.ymax = len(self.mne.ch_order) + 1
self.mne.plt.setLimits(yMax=self.mne.ymax)
self.mne.plt.setYRange(
self.mne.ch_start,
self.mne.ch_start + self.mne.n_channels + 1,
padding=0,
)

# Remove traces
for trace in self.mne.traces.copy():
trace.remove()

if self.mne.spectrogram:
# Add image traces
for ch_idx in self.mne.picks:
ImageTrace(self, ch_idx)
for trace in self.mne.traces:
trace.update_scale()
self._set_scalebars_visible(False)
else:
# Add traces
for ch_idx in self.mne.picks:
DataTrace(self, ch_idx)
self._set_scalebars_visible(self.mne.scalebars_visible)

# Set vertical scrollbar visible
self.mne.ax_vscroll.setVisible(
not butterfly or self.mne.fig_selection is not None
not self.mne.butterfly or self.mne.fig_selection is not None
)

# update overview-bar
self.mne.overview_bar.update_viewrange()

# update ypos and color for butterfly-mode
for trace in self.mne.traces:
trace.update_color()
trace.update_ypos()

self._draw_traces()

self._update_ch_spinbox_values()
Expand All @@ -4903,6 +5018,10 @@ def _toggle_butterfly(self):
if self.mne.instance_type != "ica":
self._set_butterfly(not self.mne.butterfly)

def _toggle_spectrogram(self):
if self.mne.instance_type != "ica":
self._set_spectrogram(not self.mne.spectrogram)

def _toggle_dc(self):
self.mne.remove_dc = not self.mne.remove_dc
self._redraw()
Expand Down