Skip to content

Commit

Permalink
Update docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Oct 22, 2024
1 parent 45c6a0b commit 82fc2f7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 22 deletions.
5 changes: 3 additions & 2 deletions mne/time_frequency/multitaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,8 @@ def tfr_array_multitaper(
coherence across trials.
return_weights : bool, default False
If True, return the taper weights. Only applies if ``output="complex"``.
If True, return the taper weights. Only applies if ``output='complex'`` or
``'phase'``.
.. versionadded:: 1.9.0
Expand All @@ -528,7 +529,7 @@ def tfr_array_multitaper(
contain the average power and the imaginary values contain the
inter-trial coherence: :math:`out = power_{avg} + i * ITC`.
weights : array of shape (n_tapers, n_freqs)
The taper weights. Only returned if ``output="complex"`` and
The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and
``return_weights=True``.
See Also
Expand Down
50 changes: 30 additions & 20 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,9 +1215,6 @@ def __init__(
f'{classname} got unsupported parameter value{_pl(problem)} '
f'{" and ".join(problem)}.'
)
# shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release)
if method == "morlet":
method_kw.setdefault("zero_mean", True)
# check method
valid_methods = ["morlet", "multitaper"]
if isinstance(inst, BaseEpochs):
Expand Down Expand Up @@ -2697,42 +2694,55 @@ def to_data_frame(
"""
# check pandas once here, instead of in each private utils function
pd = _check_pandas_installed() # noqa
# triage for Epoch-derived or unaggregated spectra
from_epo = isinstance(self, EpochsTFR)
unagg_mt = "taper" in self._dims
# arg checking
valid_index_args = ["time", "freq"]
if isinstance(self, EpochsTFR):
if from_epo:
valid_index_args.extend(["epoch", "condition"])
valid_time_formats = ["ms", "timedelta"]
index = _check_pandas_index_arguments(index, valid_index_args)
time_format = _check_time_format(time_format, valid_time_formats)
# get data
picks = _picks_to_idx(self.info, picks, "all", exclude=())
data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True)
axis = self._dims.index("channel")
if not isinstance(self, EpochsTFR):
ch_axis = self._dims.index("channel")
if not from_epo:
data = data[np.newaxis] # add singleton "epochs" axis
axis += 1
n_epochs, n_picks, n_freqs, n_times = data.shape
# reshape to (epochs*freqs*times) x signals
data = np.moveaxis(data, axis, -1)
data = data.reshape(n_epochs * n_freqs * n_times, n_picks)
ch_axis += 1
if not unagg_mt:
data = np.expand_dims(data, -3) # add singleton "tapers" axis
n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape
# reshape to (epochs*tapers*freqs*times) x signals
data = np.moveaxis(data, ch_axis, -1)
data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks)
# prepare extra columns / multiindex
mindex = list()
default_index = list()
times = _convert_times(times, time_format, self.info["meas_date"])
times = np.tile(times, n_epochs * n_freqs)
freqs = np.tile(np.repeat(freqs, n_times), n_epochs)
times = np.tile(times, n_epochs * n_freqs * n_tapers)
freqs = np.tile(np.repeat(freqs, n_times * n_tapers), n_epochs)
mindex.append(("time", times))
mindex.append(("freq", freqs))
if isinstance(self, EpochsTFR):
mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs)))
if from_epo:
mindex.append(
("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers))
)
rev_event_id = {v: k for k, v in self.event_id.items()}
conditions = [rev_event_id[k] for k in self.events[:, 2]]
mindex.append(("condition", np.repeat(conditions, n_times * n_freqs)))
mindex.append(
("condition", np.repeat(conditions, n_times * n_freqs * n_tapers))
)
default_index.extend(["condition", "epoch"])
default_index.extend(["freq", "time"])
if unagg_mt:
name = "taper"
taper_nums = np.tile(np.arange(n_tapers), n_epochs * n_freqs * n_times)
mindex.append((name, taper_nums))
default_index.append(name)
assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:])
# build DataFrame
if isinstance(self, EpochsTFR):
default_index = ["condition", "epoch", "freq", "time"]
else:
default_index = ["freq", "time"]
df = _build_data_frame(
self, data, picks, long_format, mindex, index, default_index=default_index
)
Expand Down

0 comments on commit 82fc2f7

Please sign in to comment.