diff --git a/mne_qt_browser/_pg_figure.py b/mne_qt_browser/_pg_figure.py index d32c49e0..e4ed6666 100644 --- a/mne_qt_browser/_pg_figure.py +++ b/mne_qt_browser/_pg_figure.py @@ -39,6 +39,7 @@ from mne import channel_indices_by_type from mne.annotations import _sync_onset from mne.io.pick import _DATA_CH_TYPES_ORDER_DEFAULT, _DATA_CH_TYPES_SPLIT +from mne.time_frequency import tfr_array_morlet from mne.utils import _check_option, _to_rgb, get_config, logger, sizeof_fmt, warn from mne.viz import plot_sensors from mne.viz._figure import BrowserBase @@ -48,6 +49,7 @@ AxisItem, FillBetweenItem, GraphicsView, + ImageItem, InfiniteLine, InfLineLabel, LinearRegionItem, @@ -295,8 +297,8 @@ def func(self, *args, **kwargs): return func -class DataTrace(PlotCurveItem): - """Graphics-Object for single data trace.""" +class BaseDataTrace: + """Base graphics-object for single data trace.""" def __init__(self, main, ch_idx, child_idx=None, parent_trace=None): super().__init__() @@ -304,9 +306,6 @@ def __init__(self, main, ch_idx, child_idx=None, parent_trace=None): self.mne = main.mne del main - # Set clickable with small area around trace to make clicking easier. - self.setClickable(True, 12) - # Set default z-value to 1 to be before other items in scene self.setZValue(1) @@ -608,6 +607,80 @@ def get_ydata(self): return self.yData + self.ypos +class DataTrace(BaseDataTrace, PlotCurveItem): + """Graphics-Object for single line data trace.""" + + def __init__(self, main, ch_idx, child_idx=None, parent_trace=None): + super().__init__(main, ch_idx, child_idx=child_idx, parent_trace=parent_trace) + + # Set clickable with small area around trace to make clicking easier. + self.setClickable(True, 12) + + +class ImageTrace(BaseDataTrace, ImageItem): + """Graphics-Object for single line data trace.""" + + def __init__(self, main, ch_idx, child_idx=None, parent_trace=None): + super().__init__(main, ch_idx, child_idx=child_idx, parent_trace=parent_trace) + self.setColorMap(self.mne.cmap) + + @propagate_to_children + def update_color(self): + """Update the color of the trace.""" + pass + + @propagate_to_children + def update_scale(self): # noqa: D102 + transform = QTransform() + tmin, tmax = self.mne.times[0], self.mne.times[-1] + transform.scale((tmax - tmin) / self.mne.times.size, 0.9 / self.mne.freqs.size) + self.setTransform(transform) + + self.setLevels([0, self.mne.vmax / self.mne.scale_factor]) + + @propagate_to_children + def update_data(self): + """Update data (fetch data from self.mne according to self.ch_idx).""" + if self.mne.data_precomputed: + data = self.mne.data[self.order_idx] + else: + data = self.mne.data[self.range_idx] + times = self.mne.times + + # Get decim-specific time if enabled + if self.mne.decim != 1: + times = times[:: self.mne.decim_data[self.range_idx]] + data = data[..., :: self.mne.decim_data[self.range_idx]] + + # For multiple color traces with epochs + # replace values from other colors with NaN. + if self.mne.is_epochs: + data = np.copy(data) + check_color = self.mne.epoch_color_ref[self.ch_idx, self.mne.epoch_idx] + bool_ixs = np.invert(np.equal(self.color, check_color).all(axis=1)) + starts = self.mne.boundary_times[self.mne.epoch_idx][bool_ixs] + stops = self.mne.boundary_times[self.mne.epoch_idx + 1][bool_ixs] + + for start, stop in zip(starts, stops): + data[np.logical_and(start <= times, times <= stop)] = np.nan + + assert times.shape[-1] == data.shape[-1] + + tfr_data = tfr_array_morlet( + data[None, None], + self.mne.info["sfreq"], + freqs=self.mne.freqs, + n_cycles=self.mne.n_cycles, + output="power", + )[0][0] + # tfr_data = rescale(tfr_data, times, (None, None), mode='zlogratio') + self.setImage( + tfr_data[::-1].T, levels=[0, self.mne.vmax / self.mne.scale_factor] + ) + + self.setPos(times[0], self.range_idx + self.mne.ch_start + 0.5) + + class TimeAxis(AxisItem): """The X-Axis displaying the time.""" @@ -3364,6 +3437,12 @@ def __init__(self, **kwargs): self.mne.butterfly_type_order = [ tp for tp in DATA_CH_TYPES_ORDER if tp in self.mne.ch_types ] + # Spectogram + self.mne.spectrogram = False + self.mne.freqs = np.arange(5, np.min([250, self.mne.info["sfreq"] / 4]), 5) + self.mne.n_cycles = self.mne.freqs / 2 + self.mne.cmap = "CET-L18" + self.mne.vmax = 0.2 if self.mne.is_epochs: # Stores parameters for epochs self.mne.epoch_dur = np.diff(self.mne.boundary_times[:2])[0] @@ -3869,6 +3948,11 @@ def __init__(self, **kwargs): "slot": [self._toggle_events], "description": ["Toggle Events visible"], }, + "f": { + "qt_key": Qt.Key_F, + "slot": [self._toggle_spectrogram], + "description": ["Toggle frequency spectrogram"], + }, "h": { "qt_key": Qt.Key_H, "slot": [self._toggle_epoch_histogram], @@ -4161,6 +4245,9 @@ def change_nchan(self, checked=False, *, step): self.mne.ax_vscroll.update_nchan() self.mne.plt.setYRange(ymin, ymax, padding=0) + if self.mne.spectrogram: + self._set_spectrogram(self.mne.spectrogram) + if self.mne.fig_settings is not None: self.mne.fig_settings._update_sensitivity_spinbox_values() @@ -4577,7 +4664,7 @@ def _update_data(self): # Apply clipping if self.mne.clipping == "clamp": self.mne.data = np.clip(self.mne.data, -0.5, 0.5) - elif self.mne.clipping is not None: + elif self.mne.clipping is not None and not self.mne.spectrogram: self.mne.data = self.mne.data.copy() self.mne.data[ abs(self.mne.data * self.mne.scale_factor) > self.mne.clipping @@ -4878,23 +4965,51 @@ def _set_butterfly(self, butterfly): padding=0, ) - if self.mne.fig_selection is not None: - # Update Selection-Dialog - self.mne.fig_selection._style_butterfly() + # update ypos and color for butterfly-mode + for trace in self.mne.traces: + trace.update_color() + trace.update_ypos() + + self._draw_traces() + + def _set_spectrogram(self, spectrogram): + self.mne.spectrogram = spectrogram + self._update_data() + + # set yscale in case in butterfly + self.mne.ymax = len(self.mne.ch_order) + 1 + self.mne.plt.setLimits(yMax=self.mne.ymax) + self.mne.plt.setYRange( + self.mne.ch_start, + self.mne.ch_start + self.mne.n_channels + 1, + padding=0, + ) + + # Remove traces + for trace in self.mne.traces.copy(): + trace.remove() + + if self.mne.spectrogram: + # Add image traces + for ch_idx in self.mne.picks: + ImageTrace(self, ch_idx) + for trace in self.mne.traces: + trace.update_scale() + self._set_scalebars_visible(False) + else: + # Add traces + for ch_idx in self.mne.picks: + DataTrace(self, ch_idx) + self._set_scalebars_visible(self.mne.scalebars_visible) # Set vertical scrollbar visible self.mne.ax_vscroll.setVisible( - not butterfly or self.mne.fig_selection is not None + not self.mne.butterfly or self.mne.fig_selection is not None ) # update overview-bar self.mne.overview_bar.update_viewrange() - # update ypos and color for butterfly-mode - for trace in self.mne.traces: - trace.update_color() - trace.update_ypos() - self._draw_traces() self._update_ch_spinbox_values() @@ -4903,6 +5018,10 @@ def _toggle_butterfly(self): if self.mne.instance_type != "ica": self._set_butterfly(not self.mne.butterfly) + def _toggle_spectrogram(self): + if self.mne.instance_type != "ica": + self._set_spectrogram(not self.mne.spectrogram) + def _toggle_dc(self): self.mne.remove_dc = not self.mne.remove_dc self._redraw()