-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] Add option to store and return TFR taper weights #12910
base: main
Are you sure you want to change the base?
Conversation
@@ -302,12 +306,15 @@ def _make_dpss( | |||
real_offset = Wk.mean() | |||
Wk -= real_offset | |||
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel()) | |||
Ck = np.sqrt(conc[m]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This I am somewhat unsure on. The existing implementation is to just use conc
as-is, however in the MNE-Connectivity implementation that sqrt is taken: https://github.com/mne-tools/mne-connectivity/blob/97147a57eefb36a5c9680e539fdc6343a1183f20/mne_connectivity/spectral/time.py#L825
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm also somewhat confused about the design of the mne-python/mne/time_frequency/tfr.py Lines 285 to 315 in 82fc2f7
It is looping over tapers, and then over frequencies. However, the Would it not be more efficient to only loop over frequencies and take advantage of the fact that this will also return information for each taper? |
# shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release) | ||
if method == "morlet": | ||
method_kw.setdefault("zero_mean", True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unrelated to this PR, but it can be removed.
I also have a question regarding testing: for the I/O tests, we're reading Apart from this there are still some tests I need to expand. |
# always store weights for per-taper outputs | ||
if method == "multitaper" and method_kw.get("output") in ["complex", "phase"]: | ||
method_kw["return_weights"] = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I hesitate to blindly overwrite what the user might have put into their method_kw
dict, so I was going to suggest using .setdefault
here. But then I wondered, is there ever a case where the user would sensibly want to pass method_kw=dict(return_weights=False, ...)
? I'm guessing not, since when instantiating the TFR class object, the user isn't getting direct access to the return value of the method anyway. WDYT @tsbinns ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, this was my line of thought as well. Also, allowing the user to control this would mean extra logic needs to be put in place when unpacking the tfr values (i.e., whether we need to separate the tfr from the weights). I think just forcing this to True
simplifies things and would not affect the user at all.
@@ -302,12 +306,15 @@ def _make_dpss( | |||
real_offset = Wk.mean() | |||
Wk -= real_offset | |||
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel()) | |||
Ck = np.sqrt(conc[m]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review @drammock! I will sort out those remaining tests, although I'm in the process of moving at the moment so it might not be for some days. Regarding those issues I came across with TFR multitapers and converting to dataframes / plotting: would you like me to incorporate that into this PR? |
Reference issue (if any)
PR for #12851
What does this implement/fix?
Adds an option to return taper weights for complex and phase outputs of the multitaper method in
tfr_array_multitaper()
, and also ensures taper weights are stored inTFR
objects.Additional information
When working on this, I discovered a couple of other issues with the per-taper TFR implementations (#12851 (comment)), including the fact that the
TFR
object plotting methods andto_data_frame
methods do not account for a taper dimension, leading to errors. Wasn't sure if people want me to also address these here or in a separate PR.