# cond_tfr_contcat.py
import mne
import numpy as np
from mne.time_frequency import read_tfrs
from os.path import isdir
import pickle
import matplotlib.pyplot as plt
plt.ion()
import matplotlib
# Global matplotlib font settings (bold, size 20), applied through the "font"
# rc group so that every figure created afterwards picks them up.
font = {"weight": "bold", "size": 20}
matplotlib.rc("font", **font)
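'''
Plot the fitted model coefficients of a continuous variable, and of its
interactions with stimulation type and duration, as time-frequency maps
(one panel per model term).
'''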
def cond2vec(exog_names, params, keys_cond):
    """Build a 0/1 indicator vector over the model's exogenous terms.

    Parameters
    ----------
    exog_names : sized iterable of str
        Names of the exogenous terms, in model order. Any sized iterable is
        accepted; it is copied to a list internally.
    params : iterable
        Keys of ``keys_cond`` naming the terms to switch on.
    keys_cond : mapping
        Maps every entry of ``params`` to a name contained in ``exog_names``.

    Returns
    -------
    numpy.ndarray
        Float vector of length ``len(exog_names)`` holding 1 at the position
        of every selected term and 0 everywhere else.

    Raises
    ------
    KeyError
        If an entry of ``params`` is not a key of ``keys_cond``.
    ValueError
        If a mapped name does not occur in ``exog_names``.
    """
    # Copy once so that inputs without an ``index`` method (dict views,
    # arrays, ...) are handled the same way as plain lists.
    names = list(exog_names)
    out_vec = np.zeros(len(names))
    for param in params:
        out_vec[names.index(keys_cond[param])] = 1
    return out_vec
# --- Host-dependent data location -------------------------------------------
# The data root is chosen by which home directory exists on this machine.
if isdir("/home/jev"):
    root_dir = "/home/jev/hdd/sfb/"
elif isdir("/home/jeff"):
    root_dir = "/home/jeff/hdd/jeff/sfb/"
else:
    # Fail here with a clear message instead of a NameError on `root_dir`
    # a few lines further down.
    raise FileNotFoundError("Neither /home/jev nor /home/jeff exists; "
                            "cannot locate the data root directory.")
proc_dir = root_dir+"proc/"

# --- Analysis settings ------------------------------------------------------
durs = ["30s", "2m", "5m"]
conds = ["sham","fix","eig"]
osc = "SO"
baseline = "zscore"
use_group = "nogroup"
badsubjs = "all_subj"
cont_var = "Age"

# Colour limits of the coefficient maps depend on the baseline that was used.
if baseline == "zscore":
    vmin, vmax = -.1, .1
elif baseline == "logmean":
    vmin, vmax = -.35, .35
else:
    # Without this branch vmin/vmax stay undefined and plotting fails later
    # with an unrelated NameError.
    raise ValueError("Unknown baseline {!r}; expected 'zscore' or "
                     "'logmean'.".format(baseline))

# Model term name -> human-readable panel title. The keys are looked up
# verbatim in the fitted models' exog_names, so they must match exactly.
cond_keys = {
    cont_var: cont_var,
    cont_var+":C(StimType, Treatment('sham'))[T.eig]": "Eigenfrequency",
    cont_var+":C(StimType, Treatment('sham'))[T.fix]": "Fixed frequency",
    cont_var+":C(Dur, Treatment('30s'))[T.2m]": "2m",
    cont_var+":C(Dur, Treatment('30s'))[T.5m]": "5m",
    cont_var+":C(StimType, Treatment('sham'))[T.eig]:C(Dur, Treatment('30s'))[T.2m]": "Eigenfrequency 2m",
    cont_var+":C(StimType, Treatment('sham'))[T.fix]:C(Dur, Treatment('30s'))[T.2m]": "Fixed frequency 2m",
    cont_var+":C(StimType, Treatment('sham'))[T.eig]:C(Dur, Treatment('30s'))[T.5m]": "Eigenfrequency 5m",
    cont_var+":C(StimType, Treatment('sham'))[T.fix]:C(Dur, Treatment('30s'))[T.5m]": "Fixed frequency 5m",
}
# Reverse lookup: panel title -> model term name.
keys_cond = {v:k for k,v in cond_keys.items()}
# --- Load TFR template and the matching epochs ------------------------------
tfr = read_tfrs("{}grand_central_{}-tfr.h5".format(proc_dir, baseline))[0]
tfr_avg = tfr.average()
epo = mne.read_epochs(proc_dir+"grand_central_finfo-epo.fif")
# Select the epochs of the requested oscillation type and bring them onto the
# TFR's sampling rate and time span.
e = epo["OscType=='{}'".format(osc)]
e.resample(tfr.info["sfreq"], n_jobs="cuda")
e.crop(tmin=tfr.times[0], tmax=tfr.times[-1])
if osc == "deltO":
    tfr_avg.crop(tmin=-0.75, tmax=0.75)
    epo.crop(tmin=-0.75, tmax=0.75)
    # `e` is the object that is averaged into the ERP below, so it has to be
    # cropped as well; cropping only `epo` leaves the ERP on the full window.
    e.crop(tmin=-0.75, tmax=0.75)

# calculate global ERP min and max for scaling later on
evo = e.average()
ev_min, ev_max = evo.data.min(), evo.data.max()
# get osc ERP and normalise to [0, 1], then map onto [12, 16]
# NOTE(review): presumably this places the trace inside the 12-16 range of
# the frequency axis of the TFR panels - confirm.
evo_data = evo.data
evo_data = (evo_data - ev_min) / (ev_max - ev_min)
evo_data = evo_data*4 + 12

stat_conds = list(cond_keys.keys())
# Template whose data are overwritten with model coefficients for each panel.
tfr_c = tfr_avg.copy()
dat_shape = tfr_c.data.shape[1:]

# --- Load the fitted models -------------------------------------------------
infile = "{}main_fits_{}_{}_{}_{}_cont_{}.pickle".format(proc_dir, baseline,
                                                         osc, badsubjs,
                                                         use_group, cont_var)
# NOTE: unpickling can execute arbitrary code; only load files produced by
# this project's own fitting script.
with open(infile, "rb") as f:
    fits = pickle.load(f)
exog_names = fits["exog_names"]
modfit = fits["fits"]

# --- One panel per model term -----------------------------------------------
fig, axes = plt.subplots(3, 3, figsize=(38.4,21.6))
axes = [ax for axe in axes for ax in axe]
for en_idx, en in enumerate(cond_keys):
    # Position of this term in the models; identical for every fit, so look
    # it up once instead of three times per fit.
    term_idx = exog_names.index(en)
    # Rows: 0 = coefficient, 1 = t-value, 2 = p-value; one column per fit.
    data = np.zeros((3, len(modfit)))
    for mf_idx, mf in enumerate(modfit):
        data[0, mf_idx] = mf.params[term_idx]
        data[1, mf_idx] = mf.tvalues[term_idx]
        data[2, mf_idx] = mf.pvalues[term_idx]
    # Fits are stored flattened in column-major order over the TFR grid.
    pvals = data[2,].reshape(*dat_shape, order="F")
    pvals[np.isnan(pvals)] = 1  # failed fits count as non-significant
    mask = pvals<0.05
    if "Intercept" in en:
        mask = None
    dat = data[0,].reshape(*dat_shape, order="F")
    dat[np.isnan(dat)] = 0
    tfr_c.data[0,] = dat
    tfr_c.plot(picks="central", axes=axes[en_idx], colorbar=False,
               vmin=vmin, vmax=vmax, cmap="viridis", mask=mask,
               mask_style="contour")
    # Overlay the scaled ERP on its own time axis so that x and y always have
    # the same length, also after the "deltO" crop.
    axes[en_idx].plot(evo.times, evo_data[0,],
                      color="gray", alpha=0.8,
                      linewidth=10)
    axes[en_idx].set_title(cond_keys[en])

# NOTE(review): the title says "Eigenfrequency" although the plotted
# continuous variable is `cont_var` - confirm this wording is intended.
fig.suptitle("{} Power by Eigenfrequency".format(osc))
fig.tight_layout()
fig.savefig("../images/lmmtfr_contcat_{}_{}_{}_{}.tif".format(cont_var, osc, badsubjs, use_group))