#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 2 16:21:51 2022
@author: saad
Here we compute gradient of outputs with respect to input. This gradient sort
of tells us the 'spikes/R*/rod' i.e. how many spikes will be generated by change
in the input (R*/rod/sec). So sort of gives us how sensitive a particular RGC
is to changes in inputs. By taking the derivative of this derivative, we can estimate
in what direction to change the input that would result in higher firing, or
in what direction to change the output that will result in lower firing rate.
"""
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.regularizers import l2
import multiprocessing
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['svg.fonttype'] = 'none'
import os
import h5py
from model.load_savedModel import load
from model.data_handler import load_h5Dataset, prepare_data_cnn2d, prepare_data_pr_cnn2d, unroll_data
from model.performance import getModelParams, model_evaluate_new,paramsToName, get_weightsDict, get_weightsOfLayer
from model import utils_si
from model.models import modelFileName
from model.featureMaps import spatRF2DFit, get_strf
import model.featureMaps  # needed for the model.featureMaps.<fn> attribute access used below
import model.gradient_tools
from pyret.filtertools import sta, decompose
import gc
from collections import namedtuple
Exptdata = namedtuple('Exptdata', ['X', 'y'])
Exptdata_spikes = namedtuple('Exptdata_spikes',['X','y','spikes'])
import time
import seaborn
import pandas as pd
from tqdm import tqdm
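# The header's core idea in miniature: the gradient of a scalar model output with
# respect to the input is the local sensitivity of the response to each input
# dimension. A minimal, self-contained sketch (toy model and toy input, defined
# here for illustration only and not part of the analysis below):
def _toy_sensitivity_demo():
    _toy_model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation='softplus')])
    _toy_input = tf.Variable(tf.random.normal((1, 5)))
    with tf.GradientTape() as tape:
        _out = _toy_model(_toy_input)
    # d(output)/d(input): one sensitivity value per input dimension
    return tape.gradient(_out, _toy_input)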
# Enable memory growth
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        print(e)
# Define experiment details
data_pers = 'kiersten'
expDate = 'monkey01'
subFold = 'gradient_analysis'
fname_stas = '/home/saad/postdoc_db/analyses/data_kiersten/monkey01/db_files/datasets/monkey01_STAs_allLightLevels_8ms_Rstar.h5'
path_dataset = '/home/saad/postdoc_db/analyses/data_kiersten/monkey01/gradient_analysis/datasets/'
dataset_model = 'scot-3-30-Rstar'
path_save = '/home/saad/postdoc_db/analyses/data_kiersten/monkey01/gradient_analysis/'
path_grads = '/home/saad/postdoc_db/analyses/data_kiersten/monkey01/gradient_analysis/gradients'
path_mdl = '/home/saad/data2/analyses/data_kiersten/monkey01/ICLR2023' # CNS
mdl_subFold = '' #'LayerNorm_MultiAxis'
mdl_names = ('CNN_2D_NORM','PRFR_CNN2D_RODS') #'CNN_2D_NORM' #'PRFR_CNN2D_RODS'
paramName_mdl = {}
paramName_mdl['PRFR_CNN2D_RODS'] = 'U-37_P-180_T-120_C1-08-09_C2-16-07_C3-18-05_BN-1_MP-1_LR-0.001_TRSAMPS-040_TR-01'
paramName_mdl['CNN_2D_NORM'] = 'U-37_T-120_C1-08-09_C2-16-07_C3-18-05_BN-1_MP-1_LR-0.001_TRSAMPS-040_TR-01'
# Load models into RAM
mdl_dict = {}
for select_mdl in mdl_names:
    fold_mdl = os.path.join(path_mdl,dataset_model,mdl_subFold,select_mdl,paramName_mdl[select_mdl])
    fname_performanceFile = os.path.join(fold_mdl,'performance',expDate+'_'+paramName_mdl[select_mdl]+'.h5')
    # Load the performance file
    f = h5py.File(fname_performanceFile,'r')
    perf_model = {}
    for key in f['model_performance'].keys():
        perf_model[key] = np.array(f['model_performance'][key])
    rgb = utils_si.h5_tostring(f['uname_selectedUnits'])
    perf_model['uname_selectedUnits'] = rgb
    f.close()
    idx_bestEpoch = np.nanargmax(perf_model['fev_medianUnits_allEpochs'])
    # plt.plot(perf_model['fev_medianUnits_allEpochs'])
    mdl = load(os.path.join(fold_mdl,paramName_mdl[select_mdl]))
    fname_bestWeight = 'weights_'+paramName_mdl[select_mdl]+'_epoch-%03d' % (idx_bestEpoch+1)
    try:
        mdl.load_weights(os.path.join(fold_mdl,fname_bestWeight))
    except Exception:
        mdl.load_weights(os.path.join(fold_mdl,fname_bestWeight+'.h5'))
    weights_dict = get_weightsDict(mdl)
    mdl_dict[select_mdl] = mdl
# %% Load all the datasets on which the model is to be evaluated and for which we have to compute gradients
nsamps_dur = -1  # amount of data to load, in minutes (-1 loads everything)
dataset_eval = ('scot-30-Rstar','scot-3-Rstar','scot-0.3-Rstar')
data_alldsets = {}
for d in dataset_eval:
    data_alldsets[d] = {}
    name_datasetFile = expDate+'_dataset_train_val_test_'+d+'.h5'
    fname_data_train_val_test = os.path.join(path_dataset,name_datasetFile)
    data_train,data_val,_,_,dataset_rr,parameters,_ = load_h5Dataset(fname_data_train_val_test,nsamps_train=nsamps_dur)
    # Model information needed to arrange the data (fname_performanceFile carries over from the last model loaded above)
    params_model = getModelParams(os.path.split(fname_performanceFile)[-1])
    temporal_width = params_model['T']
    pr_temporal_width = params_model['P']
    samp_interval = 1  # the stimulus in these datasets is upsampled, so just downsample it
    nsamps = np.floor(data_train.X.shape[0]/samp_interval).astype('int')  # e.g. 60,000
    assert nsamps <= data_train.X.shape[0]
    idx_samps = np.arange(0,nsamps*samp_interval,samp_interval)  # index of the samples that we will extract
    data_train = Exptdata_spikes(data_train.X[idx_samps],data_train.y[idx_samps],data_train.spikes[idx_samps])
    data_train = prepare_data_cnn2d(data_train,pr_temporal_width,np.arange(data_train.y.shape[1]))
    data = data_train
    data_alldsets[d]['raw'] = data
    data_alldsets[d]['idx_samps'] = idx_samps
    del data_train
data_alldsets['spat_dims'] = (data.X.shape[-2],data.X.shape[-1])
data_alldsets['temporal_dim'] = data.X.shape[1]
del data
# %% Extract performance for each model at each dataset
path_dataset = '/home/saad/postdoc_db/analyses/data_kiersten/monkey01/gradient_analysis/datasets/'
correctMedian = True
perf_datasets = {}
for select_mdl in mdl_names:
    perf_datasets[select_mdl] = {}
    for d in dataset_eval:
        perf_datasets[select_mdl][d] = {}
        name_datasetFile = expDate+'_dataset_train_val_test_'+d+'.h5'
        fname_data_train_val_test = os.path.join(path_dataset,name_datasetFile)
        _,data_val,_,_,dataset_rr,_,resp_orig = load_h5Dataset(fname_data_train_val_test,LOAD_TR=False)
        resp_orig = resp_orig['train']
        resp_orig[resp_orig==0] = np.nan
        # Load model information that we need to arrange the data
        fold_mdl = os.path.join(path_mdl,dataset_model,mdl_subFold,select_mdl,paramName_mdl[select_mdl])
        fname_performanceFile = os.path.join(fold_mdl,'performance',expDate+'_'+paramName_mdl[select_mdl]+'.h5')
        params_model = getModelParams(os.path.split(fname_performanceFile)[-1])
        temporal_width = params_model['T']
        pr_temporal_width = params_model.get('P',temporal_width)  # CNN model names carry no 'P'; fall back to 'T'
        # Arrange data as per model inputs
        if select_mdl[:6]=='CNN_2D':
            obs_rate_allStimTrials_d1 = dataset_rr['stim_0']['val'][:,temporal_width:,:]
            data_val = prepare_data_cnn2d(data_val,temporal_width,np.arange(data_val.y.shape[1]))
        elif select_mdl[:8]=='PR_CNN2D' or select_mdl[:10]=='PRFR_CNN2D' or select_mdl[:8]=='BP_CNN2D':
            obs_rate_allStimTrials_d1 = dataset_rr['stim_0']['val'][:,pr_temporal_width:,:]
            data_val = prepare_data_cnn2d(data_val,pr_temporal_width,np.arange(data_val.y.shape[1]))
        obs_rate = data_val.y
        if correctMedian:
            # Scale responses so that the median response of the evaluation dataset matches the training dataset
            fname_data_train_val_test_training = os.path.join(path_mdl,'datasets',('monkey01'+'_dataset_train_val_test_'+dataset_model+'.h5'))
            _,_,_,data_quality,_,_,resp_med_d1 = load_h5Dataset(fname_data_train_val_test_training)
            resp_med_d1 = np.nanmedian(resp_med_d1['train'],axis=0)
            resp_med_d2 = np.nanmedian(resp_orig,axis=0)
            resp_mulFac = resp_med_d2/resp_med_d1
            obs_rate_allStimTrials_d1 = obs_rate_allStimTrials_d1/resp_mulFac
            obs_rate = obs_rate/resp_mulFac
        pred_rate = mdl_dict[select_mdl].predict(data_val.X,batch_size=100)
        # FEV: fraction of explainable variance explained (1 = the model captures all the repeat-reliable variance)
        fev_d1_allUnits, _, predCorr_d1_allUnits, _ = model_evaluate_new(obs_rate_allStimTrials_d1,pred_rate,0,RR_ONLY=False,lag=0)
        perf_datasets[select_mdl][d]['fev_allUnits'] = fev_d1_allUnits
        perf_datasets[select_mdl][d]['corr_allUnits'] = predCorr_d1_allUnits
        _ = gc.collect()
# %% Get index of the top common units across all models to be analyzed
"""
This part is useful to select units with top performance. Because then I can
extract gradients in parts and not for all units at once.
The main variables to use from this cell are:
- fev_unitsToExtract
- uname_unitsToExtract
- idx_unitsToExtract
- n_units
"""
n_units = 37
fev_stack = np.zeros(len(perf_model['uname_selectedUnits']))
idx_fev_stack = np.zeros(len(perf_model['uname_selectedUnits']))
for select_mdl in mdl_names:
    fold_mdl = os.path.join(path_mdl,dataset_model,mdl_subFold,select_mdl,paramName_mdl[select_mdl])
    fname_performanceFile = os.path.join(fold_mdl,'performance',expDate+'_'+paramName_mdl[select_mdl]+'.h5')
    f = h5py.File(fname_performanceFile,'r')
    uname_all_inData = np.array(f['uname_selectedUnits'],dtype='bytes')
    uname_all_inData = np.asarray(list(utils_si.h5_tostring(uname_all_inData)))
    f.close()
    for d in dataset_eval:
        fev_allUnits_bestEpoch = perf_datasets[select_mdl][d]['fev_allUnits']
        idx_fev_sorted = np.argsort(-1*fev_allUnits_bestEpoch)  # descending order
        fev_stack = np.vstack((fev_stack,fev_allUnits_bestEpoch))
        idx_fev_stack = np.vstack((idx_fev_stack,idx_fev_sorted))
fev_stack = fev_stack[1:].T         # drop the zeros seed row; columns are (model,dataset) combinations
idx_fev_stack = idx_fev_stack[1:].T
n_search = 37
idx_fev = idx_fev_stack[:n_search]
rgb = np.intersect1d(idx_fev[:,0],idx_fev[:,1]).astype('int32')
idx_unitsToExtract = rgb[:n_units]
# idx_unitsToExtract = np.array([7,9,11,12]) # u-4
# idx_unitsToExtract = np.array([2,3,4,5,8]) # u-5
# idx_unitsToExtract = np.array([10, 14, 15, 16, 17, 19]) # u-6
# idx_unitsToExtract = np.array([13,18,20,23,27,28,32]) # u-7
fev_unitsToExtract = fev_stack[idx_unitsToExtract]
uname_unitsToExtract = uname_all_inData[idx_unitsToExtract]
print(uname_unitsToExtract)
n_units = len(uname_unitsToExtract)
# %% Compute gradients for all the datasets (and models?)
"""
Because extracting gradients require gpu memory, we have to extract gradients
in batches. Each batch is of batch_size. For efficient processing, we first
calculate gradients for each batch, then those gradients are stored in a list.
The list iterates over batches. Then when we have iterated over all the batches
i.e. we have a list the same size as total_batches, we concatenate everything into
a single large matrix.
This section outputs data_alldsets. Structure is:
data_alldsets
----- dataset_name
------ grads_all --> [n_outputUnits,temporal_width,pixels_y,pixels_x,samples]
------ stim_mat --> [x_samples,temporal_width,pixels_y,pixels_x]
Gradients are computed within GradientTape framework. This allows TF to 'record'
relevant operations in the forward pass. Then during backward pass, TF traverses
this list of operations in reverse order to compute gradients.
"""
path_grads = '/mnt/phd/postdoc/analyses/'
temporal_width_grads = 40
select_mdl = 'CNN_2D_NORM'  # 'PRFR_CNN2D_RODS' # 'CNN_2D_NORM'
save_grads = False
mdl_totake = mdl_dict[select_mdl]
tempWidth_inp = mdl_totake.input.shape[1]
weights_dense_orig = mdl_totake.layers[-2].get_weights()
counter_gc = 0
n_units = len(idx_unitsToExtract)
# idx_unitsToExtract = np.arange(n_units)
for d in dataset_eval:
    if save_grads:
        fname_gradsFile = os.path.join(path_grads,'grads_'+select_mdl+'_'+d+'_'+str(nsamps)+'_u-'+str(n_units)+'.h5')
        if os.path.exists(fname_gradsFile):
            fname_gradsFile = fname_gradsFile[:-3]+'_1.h5'
    data = data_alldsets[d]['raw']
    if select_mdl == 'CNN_2D_NORM':
        data = Exptdata(data.X[:,-temporal_width:,:,:],data.y)
    batch_size = 256
    nsamps = data.X.shape[0]
    total_batches = int(np.floor(nsamps/batch_size))
    grads_shape = (n_units,None,temporal_width_grads,data.X.shape[2],data.X.shape[3])
    stim_shape = (None,)
    t_start = time.time()
    if save_grads:
        f_grads = h5py.File(fname_gradsFile,'a')
        grp = model.gradient_tools.init_GradDataset(f_grads,select_mdl,d,grads_shape,stim_shape,batchsize=batch_size)
    for i in range(total_batches):
        counter_gc += 1
        print(' List: Batch %d of %d'%(i+1,total_batches))
        idx_chunk = np.arange(i*batch_size,(i+1)*batch_size)
        data_select_X = data.X[idx_chunk][:,-tempWidth_inp:]
        stim_chunk = None  # np.array(data_select_X).astype('float16')
        inp = tf.Variable(data_select_X, dtype=tf.float32, name='input')
        grads_chunk_allUnits = np.zeros((len(idx_unitsToExtract),batch_size,temporal_width_grads,data.X.shape[-2],data.X.shape[-1]),dtype='float32')
        t_batch_start = time.time()
        for u in range(len(idx_unitsToExtract)):
            # Build a single-output copy of the model for this unit so the gradient of
            # its response alone can be taken with respect to the input
            idx_unitToModel = np.atleast_1d(idx_unitsToExtract[u])
            n_out = idx_unitToModel.shape[0]
            y = Dense(n_out, kernel_initializer='normal', kernel_regularizer=l2(1e-3))(mdl_totake.layers[-3].output)
            outputs = Activation('softplus',dtype='float32',name='new_activation')(y)
            mdl_new = Model(mdl_totake.inputs,outputs)
            # Copy this unit's weights from the original dense layer into the new head
            a = weights_dense_orig[0][:,idx_unitToModel]
            b = weights_dense_orig[1][idx_unitToModel]
            mdl_new.layers[-2].set_weights([a,b])
            with tf.GradientTape(persistent=False,watch_accessed_variables=True) as tape:
                out = mdl_new(inp,training=False)
            grads_chunk = tape.gradient(out, inp)
            grads_chunk = np.array(grads_chunk[:,-temporal_width_grads:,:,:])
            grads_chunk_allUnits[u] = grads_chunk
        if save_grads:
            model.gradient_tools.append_GradDataset(f_grads,grp,grads_chunk_allUnits,stim_chunk)
        if counter_gc == 250:
            _ = gc.collect()
            counter_gc = 0
        t_batch = time.time()-t_batch_start
        print('Batch time: %0.2f min'%(t_batch/60))
    t_dur = time.time()-t_start
    print('Total time: %0.2f min'%(t_dur/60))
    if save_grads:
        grp.create_dataset('idx_data',data=data_alldsets[d]['idx_samps'])
        grp.create_dataset('unames',data=uname_unitsToExtract.astype('bytes'))
        grp.create_dataset('fev',data=fev_unitsToExtract)
        f_grads.close()
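# Minimal sketch of reading a saved gradients file back (illustrative only; the
# actual layout is the one created by model.gradient_tools.init_GradDataset above):
def _read_grads_demo(fname, mdl_name, dset_name, unit_idx=0):
    with h5py.File(fname, 'r') as f_demo:
        # grads dataset is [units, samples, time, y, x]; take the first 100 samples
        return np.array(f_demo[mdl_name][dset_name]['grads'][unit_idx, :100])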
# %% STA vs gradient comparisons
datasets_plot = ('scot-3-Rstar',)  # ('scot-3-Rstar','scot-0.3-Rstar')
mdls_toplot = ('CNN_2D_NORM','PRFR_CNN2D_RODS',)  # PRFR_CNN2D_RODS # CNN_2D_NORM
USE_SSD = False  # location of gradients
if USE_SSD:
    path_gradFiles = '/home/saad/postdoc_db/analyses/data_kiersten/monkey01/gradient_analysis/gradients/'
else:
    path_gradFiles = '/home/saad/data_hdd/analyses/data_kiersten/monkey01/gradient_analysis/gradients/'
# path_gradFiles = '/mnt/phd/postdoc/analyses/data_kiersten/monkey01/gradient_analysis/gradients/'
path_save_fig = os.path.join(path_save,'sta_vs_lsta')
if not os.path.exists(path_save_fig):
    os.makedirs(path_save_fig)
frametime = 1  # 8
temporal_width_grads = 50
temp_window = 50
sig_fac = 1.5
range_tempFilt = np.arange(temporal_width_grads-temp_window,temporal_width_grads)
u_arr = [0]
num_samps = len(idx_samps)
n_units = 7  # suffix of the gradients file
for u in u_arr:  # np.arange(0,len(perf_model['uname_selectedUnits'])):
    # Allocate [y,x,datasets,models] and [time,datasets,models] containers
    spatRF_sta = np.zeros((data_alldsets['spat_dims'][0],data_alldsets['spat_dims'][1],len(datasets_plot),len(mdl_names)))
    tempRF_sta = np.zeros((range_tempFilt.shape[0],len(datasets_plot),len(mdl_names)))
    spatRF_singImg = np.zeros((data_alldsets['spat_dims'][0],data_alldsets['spat_dims'][1],len(datasets_plot),len(mdl_names)))
    tempRF_singImg = np.zeros((range_tempFilt.shape[0],len(datasets_plot),len(mdl_names)))
    spatRF_gradAvg_acrossImgs = np.zeros((data_alldsets['spat_dims'][0],data_alldsets['spat_dims'][1],len(datasets_plot),len(mdl_names)))
    tempRF_gradAvg_acrossImgs = np.zeros((range_tempFilt.shape[0],len(datasets_plot),len(mdl_names)))
    spatRF_indiv_avg_acrossImgs = np.zeros((data_alldsets['spat_dims'][0],data_alldsets['spat_dims'][1],len(datasets_plot),len(mdl_names)))
    tempRF_indiv_avg_acrossImgs = np.zeros((range_tempFilt.shape[0],len(datasets_plot),len(mdl_names)))
    tempRF_indiv = np.zeros((range_tempFilt.shape[0],num_samps,len(datasets_plot),len(mdl_names)))
    for m in range(len(mdls_toplot)):
        select_mdl = mdls_toplot[m]
        ctr_d = -1
        for d in datasets_plot:
            fname_gradsFile = os.path.join(path_gradFiles,'grads_'+select_mdl+'_'+d+'_'+str(num_samps)+'_u-'+str(n_units)+'.h5')
            f_grads = h5py.File(fname_gradsFile,'r')
            uname_all_grads = np.array(f_grads[select_mdl][d]['unames'],'bytes')
            uname_all_grads = utils_si.h5_tostring(uname_all_grads)
            uname = uname_all_grads[u]
            print(uname)
            select_rgc_dataset = np.where(uname==uname_all_inData)[0][0]
            ctr_d += 1
            data = data_alldsets[d]['raw']
            # Method 1: compute the STA as the response-weighted average of the stimulus (model independent)
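            # Schematically (a hedged illustration of the response-weighted average,
            # not the pre-computed pipeline loaded here): STA(tau) = sum_t r(t)*s(t-tau) / sum_t r(t);
            # for stim [n,t,y,x] and resp [n] this is np.tensordot(resp, stim, axes=(0,0))/resp.sum().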
            f = h5py.File(fname_stas,'r')
            spatial_feat = np.array(f[d[:-6]][uname]['spatial_feature'])
            temporal_feat = np.array(f[d[:-6]][uname]['temporal_feature'])
            f.close()
            peaksearch_win = np.arange(temporal_feat.shape[0]-40,temporal_feat.shape[0])
            idx_tempPeak = np.argmax(np.abs(temporal_feat[peaksearch_win]))  # only search for the peak in the final 40 time points
            idx_tempPeak = idx_tempPeak + peaksearch_win[0]
            sign = np.sign(temporal_feat[idx_tempPeak])
            if sign<0:
                # Make the temporal RF positive; polarity is carried by the spatial RF
                spatial_feat = spatial_feat*sign
                temporal_feat = temporal_feat*sign
            spatRF_sta[:,:,ctr_d,m] = spatial_feat
            tempRF_sta[:,ctr_d,m] = temporal_feat[-temp_window:]
            tempRF_sta[:,ctr_d,m] = tempRF_sta[:,ctr_d,m]/tempRF_sta[:,ctr_d,m].max()
            # Method 2: compute the LSTA from the model gradients for just one input sample
            select_img = 50  # 768 # 712
            spatRF, tempRF = model.featureMaps.decompose(f_grads[select_mdl][d]['grads'][u,select_img,-temp_window:,:,:])
            rf_coords,rf_fit_img,rf_params,_ = spatRF2DFit(spatRF,tempRF=0,sig_fac=sig_fac,rot=True,sta=0,tempRF_sig=False)
            mean_rfCent = np.abs(np.nanmean(rf_fit_img))
            # Normalize the spatial RF by the mean of the fitted RF center so gain is carried by the temporal part
            spatRF_singImg[:,:,ctr_d,m] = spatRF/mean_rfCent
            tempRF_singImg[:,ctr_d,m] = tempRF*mean_rfCent
            tempRF_singImg[:,ctr_d,m] = tempRF_singImg[:,ctr_d,m]/tempRF_singImg[:,ctr_d,m].max()
            f_grads.close()
    # Common color and time axes for the figure
    vmin = np.min((spatRF_singImg.min(),spatRF_indiv_avg_acrossImgs.min()))
    vmax = np.max((spatRF_singImg.max(),spatRF_indiv_avg_acrossImgs.max()))
    tmin = np.nanmin((tempRF_sta.min(),tempRF_singImg.min()))-0.05
    tmax = np.nanmax((tempRF_sta.max(),tempRF_singImg.max()))+0.05
    cmap_name = 'gray'  # 'cool'
    temp_axis = np.arange(temp_window)
    temp_axis = np.flip(temp_axis*frametime)
    n_ticks = 10
    ticks_x = np.arange(0,temp_axis.shape[0],5)
    ticks_x[0] = 0
    ticks_x_labels = temp_axis[ticks_x]
    font_tick = 14
    font_title = 14
    txt_title = 'Train: %s\nTest: %s\n%s'%(dataset_model,d,uname)
    plots_idx = np.array([[0,3],[1,4],[2,5]])  # [row,column] -> index into the raveled axes
    fig,axs = plt.subplots(2,len(datasets_plot)*len(mdls_toplot)+1,figsize=(30,15))
    axs = np.ravel(axs)
    fig.suptitle(txt_title,size=28)
    ctr_d = -1
    for d in datasets_plot:
        ctr_d += 1
        idx_p = plots_idx[0,0]
        axs[idx_p].set_title('Conventional STA',fontsize=font_title)
        axs[idx_p].imshow(spatRF_sta[:,:,ctr_d,m],aspect='auto',cmap=cmap_name)
        axs[idx_p].axes.xaxis.set_visible(False)
        axs[idx_p].axes.yaxis.set_visible(False)
        idx_p = plots_idx[0,1]
        axs[idx_p].plot(tempRF_sta[:,ctr_d,m])
        axs[idx_p].set_xlabel('Time prior to spike (frames)',size=font_tick)
        axs[idx_p].set_xticks(ticks_x)
        axs[idx_p].set_xticklabels(ticks_x_labels)
        axs[idx_p].set_ylim(tmin,tmax)
        axs[idx_p].set_ylabel('R*/rod/sec',size=font_tick)
        axs[idx_p].tick_params(axis='both',labelsize=font_tick)
        for m in range(len(mdls_toplot)):
            select_mdl = mdls_toplot[m]
            idx_p = plots_idx[m+1,0]
            axs[idx_p].set_title('single sample',fontsize=font_title)
            axs[idx_p].imshow(spatRF_singImg[:,:,ctr_d,m],aspect='auto',cmap=cmap_name)  # ,vmin=-vmax,vmax=-vmin)
            axs[idx_p].axes.xaxis.set_visible(False)
            axs[idx_p].axes.yaxis.set_visible(False)
            idx_p = plots_idx[m+1,1]
            txt_subtitle = '%s | %s | FEV = %02d%%'%(select_mdl,d[5:],perf_datasets[select_mdl][d]['fev_allUnits'][select_rgc_dataset]*100)
            axs[idx_p].set_title(txt_subtitle,fontsize=font_title)
            axs[idx_p].plot(tempRF_singImg[:,ctr_d,m])
            axs[idx_p].set_xlabel('Time prior to spike (frames)',size=font_tick)
            axs[idx_p].set_xticks(ticks_x)
            axs[idx_p].set_ylim(tmin,tmax)
            axs[idx_p].set_xticklabels(ticks_x_labels)
            axs[idx_p].tick_params(axis='both',labelsize=font_tick)
            axs[idx_p].set_ylabel('spikes/R*/rod',size=font_tick)
    _ = gc.collect()
    # fname_fig = '%s_characterize' % uname
    # fname_fig = os.path.join(path_save_fig,fname_fig)
    # fig.savefig(fname_fig+'.png',dpi=150)
    # fig.savefig(fname_fig+'.svg')
    # plt.close(fig)
# %% TEMP RF BINNING
"""
For each cell:
1. Load the gradients
2. Decompose STRF into spatial and temporal
3. Find the peaks (gain)
4. Bin the temporal RFs by their gain and calc average temporal RF per bin
5. Get list of idx of movies within each bin
6. Perform rev corr on data using movies within each bin and get temporal RF per bin
7.
Create h5 file with following structure
Model
|--- LightLevel
| ----- unit
| ----- tempRF_grads_binned
| ----- gain
.
.
"""
path_save_fig = os.path.join(path_save,'STRFs')
if not os.path.exists(path_save_fig):
    os.makedirs(path_save_fig)
SAVE_FIGS = False
select_mdl = 'PRFR_CNN2D_RODS'  # ('PRFR_CNN2D_RODS','CNN_2D_NORM')
select_lightLevel = 'scot-30-Rstar'
select_lightLevel_sta = select_lightLevel
n_units = 7  # corresponds to the suffix of the gradients file
USE_SSD = False  # location of the gradients file
nbins = 10  # number of bins to group the temporal RF gains into
ONLY_LARGEGRADS = False  # True will only select gradients above a specific threshold
dsFac = 4  # downsampling factor, because the original stimulus was upsampled by 4
temp_window = 40
temp_window_ds = int(temp_window/dsFac)
sig_fac = 1.5  # for the spatial RF std
rfExtractNPixs = 10  # edge size in pixels of the window around the RF center
timeBin = 8
num_samps_toload = 400000  # 149000 # 392000. Note this counts from the beginning; provide indices for a start offset
batch_size = 20000
if batch_size<num_samps_toload:
    total_batches = int(np.ceil(num_samps_toload/batch_size))
    idx_batchStart = np.linspace(0,num_samps_toload,total_batches,dtype='int32')
else:
    idx_batchStart = np.array([0,num_samps_toload])
    total_batches = 2
labels_rf_params = ['rfSize','rfAngle','spatloc','cent_x','cent_y','polarity','gain','biphasic','t_zero','t_trough','t_zero_peakTrough','amp_trough','t_peak']
binning_param = 'gain'
# u_arr = np.arange(0,20) # np.arange(20,len(perf_model['uname_selectedUnits']))
u_arr = np.arange(n_units)
gradFile_suffix = '_u-%d'%(n_units)
num_samps = len(idx_samps)
if USE_SSD:
    path_gradFiles = '/home/saad/postdoc_db/analyses/data_kiersten/monkey01/gradient_analysis/gradients/'
else:
    path_gradFiles = '/home/saad/data_hdd/analyses/data_kiersten/monkey01/gradient_analysis/gradients/'
path_gradFiles = '/mnt/phd/postdoc/analyses/data_kiersten/monkey01/gradient_analysis/gradients/'  # note: overrides the USE_SSD choice above
fname_gradsFile = 'grads_'+select_mdl+'_'+select_lightLevel+'_'+str(num_samps)+gradFile_suffix+'.h5'  # 393229 # 149940
fname_gradsFile = os.path.join(path_gradFiles,fname_gradsFile)
print(fname_gradsFile)
f_grads = h5py.File(fname_gradsFile,'r')
for u in u_arr:
    uname_all_grads = np.array(f_grads[select_mdl][select_lightLevel]['unames'],'bytes')
    uname_all_grads = utils_si.h5_tostring(uname_all_grads)
    uname = uname_all_grads[u]
    select_rgc_dataset = np.where(uname==uname_all_inData)[0][0]
    print(uname)
    idx_sampsInFullMat = idx_samps[:num_samps_toload]
    # idx_sampsInFullMat = idx_sampsInFullMat+40
    # ---- Load the pre-calculated STA
    f_stas = h5py.File(fname_stas,'r')
    spatRF_fullSTA = np.array(f_stas[select_lightLevel[:-6]][uname]['spatial_feature'])
    tempRF_fullSTA = np.array(f_stas[select_lightLevel[:-6]][uname]['temporal_feature'])
    f_stas.close()
    peaksearch_win = np.arange(tempRF_fullSTA.shape[0]-60,tempRF_fullSTA.shape[0])
    idx_tempPeak = np.argmax(np.abs(tempRF_fullSTA[peaksearch_win]))  # only search for the peak in the final 60 time points
    idx_tempPeak = idx_tempPeak + peaksearch_win[0]
    sign = np.sign(tempRF_fullSTA[idx_tempPeak])
    if sign<0:
        spatRF_fullSTA = spatRF_fullSTA*sign
        tempRF_fullSTA = tempRF_fullSTA*sign
    tempRF_fullSTA = tempRF_fullSTA[-temp_window:]
    tempRF_fullSTA = tempRF_fullSTA/tempRF_fullSTA.max()
    idx_tempPeak = -1*(temp_window - np.argmax(np.abs(tempRF_fullSTA)))  # negative index of the peak, counted from the end
    rf_coords,rf_fit_img,rf_params,_ = model.featureMaps.spatRF2DFit(spatRF_fullSTA,tempRF=0,sig_fac=sig_fac,rot=True,sta=0,tempRF_sig=False)
    RF_midpoint_x = rf_params['x0']
    RF_midpoint_y = rf_params['y0']
    rfExtractIdx_x = (np.max((round(RF_midpoint_x-0.5*rfExtractNPixs),0)),np.min((round(RF_midpoint_x+0.5*rfExtractNPixs),spatRF_fullSTA.shape[1]-1)))
    rfExtractIdx_y = (np.max((round(RF_midpoint_y-0.5*rfExtractNPixs),0)),np.min((round(RF_midpoint_y+0.5*rfExtractNPixs),spatRF_fullSTA.shape[0]-1)))
    spat_dims = np.array([rfExtractNPixs,rfExtractNPixs])
    # ---- Load gradients
    # grads_all = np.zeros((num_samps_toload,temp_window,spat_dims[0],spat_dims[1]),dtype='float16')
    spatRF_grand = np.zeros((num_samps_toload,spat_dims[0],spat_dims[1]))  # [imgs,y,x]
    tempRF_grand = np.zeros((num_samps_toload,temp_window))  # [imgs,time]
    rf_params_grand = np.zeros((num_samps_toload,len(labels_rf_params)),dtype='float64')  # [imgs,params]; column order follows labels_rf_params
    rf_coords_grand = np.zeros((1000,2,num_samps_toload),dtype='float32')  # [points,(x,y),imgs]
    rf_coords_grand[:] = np.nan
    for batch in range(total_batches-1):
        t_start = time.time()
        print('Batch %d of %d'%(batch+1,total_batches-1))
        idx_chunk = np.arange(idx_batchStart[batch],idx_batchStart[batch+1])
        # Load the gradient chunk, cropped to the window around the RF center
        grads_chunk = f_grads[select_mdl][select_lightLevel]['grads'][u,idx_chunk,-temp_window:,rfExtractIdx_y[0]:rfExtractIdx_y[1],rfExtractIdx_x[0]:rfExtractIdx_x[1]]
        # Estimate the spatial RF; we need this to then extract the temporal component
        spatRF_chunk = grads_chunk[:,idx_tempPeak-1,:,:]  # spatial RF as the time slice at the peak
        spatRF_chunk_flatten = spatRF_chunk.reshape(spatRF_chunk.shape[0],-1)
        cent_idx_min_max = np.array([np.argmin(spatRF_chunk_flatten,axis=1),np.argmax(spatRF_chunk_flatten,axis=1)])
        min_max_spatRF = np.argmax(np.abs([np.min(spatRF_chunk_flatten,axis=1),np.max(spatRF_chunk_flatten,axis=1)]),axis=0)
        cent_idx = np.zeros(spatRF_chunk_flatten.shape[0])
        cent_idx[min_max_spatRF==1] = cent_idx_min_max[1,min_max_spatRF==1]
        cent_idx[min_max_spatRF==0] = cent_idx_min_max[0,min_max_spatRF==0]
        cent_idx = cent_idx.astype(int)
        rgb = grads_chunk[:,-temp_window:,:,:]
        rgb = rgb.reshape(rgb.shape[0],rgb.shape[1],-1)
        tempRF_chunk = rgb[np.arange(rgb.shape[0]),:,cent_idx]
        sign = np.sign(tempRF_chunk[:,idx_tempPeak-1])
        if np.sum(sign<0)>0:
            tempRF_chunk[sign<0,:] = tempRF_chunk[sign<0,:]*sign[sign<0][:,None]  # make the temporal RF positive; negative peaks are reflected in the spatial RF
        # Normalize the spatial RF by its mean so any gain changes are reflected purely in the temporal part
        mean_rfCent = np.nanmean(np.abs(spatRF_chunk_flatten),axis=-1)
        spatRF_chunk = spatRF_chunk/mean_rfCent[:,None,None]
        tempRF_chunk = tempRF_chunk*mean_rfCent[:,None]
        rf_params_chunk = np.zeros((tempRF_chunk.shape[0],len(labels_rf_params)))
        rf_params_chunk[:] = np.nan
        rf_params_chunk[:,6] = np.nanmax(tempRF_chunk,axis=1)   # gain
        rf_params_chunk[:,11] = np.nanmin(tempRF_chunk,axis=1)  # amp_trough
        rf_params_chunk[:,12] = np.argmax(tempRF_chunk,axis=1)  # t_peak
        spatRF_grand[idx_chunk,:,:] = spatRF_chunk
        tempRF_grand[idx_chunk,:] = tempRF_chunk
        rf_params_grand[idx_chunk,:] = rf_params_chunk
        t_end = time.time()-t_start
        print('%0.2f minutes'%(t_end/60))
    # Just consider all grads for the time being
    bool_largeGrads = np.ones(num_samps_toload,'bool')
    print(bool_largeGrads.sum())
    _ = gc.collect()
    params_plot = ['gain',]
    idx_params_select = [p for p in range(len(labels_rf_params)) if labels_rf_params[p] in params_plot]
    n_cols = 2
    n_rows = int(np.ceil(len(idx_params_select)/n_cols))
    plots_idx = np.arange(0,n_rows*n_cols)
    txt_title = '%s - properties distribution'%uname
    fig2,axs = plt.subplots(n_rows,n_cols,figsize=(20,10))
    axs = np.ravel(axs)
    fig2.suptitle(txt_title,size=22)
    cnt = -1
    for param in idx_params_select:
        cnt += 1
        axs[cnt].hist(rf_params_grand[:,param])
        ax_title = '%s'%labels_rf_params[param]
        axs[cnt].set_title(ax_title,size=12)
    # Plot the RF param as a function of time, against the recorded response
    idx_param = [p for p in range(len(labels_rf_params)) if labels_rf_params[p] == 'gain'][0]
    rgb = rf_params_grand[:,idx_param].copy()
    # rgb = rgb - np.nanmean(rgb)
    t = np.arange(0,rf_params_grand.shape[0])*timeBin/1000
    idx_datapoints = np.arange(4500,5500)
    fontsize = 12
    fig,axs = plt.subplots(1,1,figsize=(15,5))
    axs.plot(t[idx_datapoints],rgb[idx_datapoints]/rgb[idx_datapoints].max())
    axs.plot(t[idx_datapoints],data_alldsets[select_lightLevel_sta]['raw'].y[idx_datapoints,select_rgc_dataset]/
             data_alldsets[select_lightLevel_sta]['raw'].y[idx_datapoints,select_rgc_dataset].max())
    axs.set_xlabel('Time (s)',fontsize=fontsize)
    axs.set_ylabel(labels_rf_params[idx_param],fontsize=fontsize)
    # ---- Find binning edges
    """
    Equal-count binning:
    - idx_bin_edges gives the bin edges such that each bin holds total_samps/nbins samples
    - idx_sorted is the index of the data sorted low to high
    """
    idx_binning_param = [p for p in range(len(labels_rf_params)) if labels_rf_params[p] == binning_param][0]
    data_tobin = rf_params_grand[:,idx_binning_param]
    idx_sorted = np.argsort(data_tobin)
    idx_sorted = idx_sorted[np.where(bool_largeGrads[idx_sorted])[0]]  # keep only samples flagged as large gradients
    data_sorted = data_tobin[idx_sorted]
    idx_bin_edges = np.arange(0,idx_sorted.shape[0],np.floor(idx_sorted.shape[0]/nbins),dtype='int')
    if len(idx_bin_edges)<nbins+1:
        idx_bin_edges = np.concatenate((idx_bin_edges,np.array([idx_sorted.shape[0]])))
    else:
        idx_bin_edges[-1] = idx_sorted.shape[0]-1
    # plt.plot(data_sorted)
    # ---- Initialize binned variables
    spatRF_grads_binned_grand = np.zeros((spat_dims[0],spat_dims[1],nbins)); spatRF_grads_binned_grand[:] = np.nan
    tempRF_grads_binned_grand = np.zeros((temp_window,nbins)); tempRF_grads_binned_grand[:] = np.nan
    rf_params_grads_binned_grand = np.zeros((nbins,*rf_params_grand.shape[1:]),dtype='float64')
    rf_coords_grads_binned_grand = np.zeros((629,2,nbins),dtype='float64')
    data_real_binned_grand = np.zeros(nbins)
    spatRF_real_binned_grand = np.empty((spat_dims[0],spat_dims[1],nbins)); spatRF_real_binned_grand[:] = np.nan
    tempRF_real_binned_grand = np.empty((temp_window_ds,nbins)); tempRF_real_binned_grand[:] = np.nan
    rf_params_real_binned_grand = np.zeros((nbins,*rf_params_grand.shape[1:]),dtype='float64')
    rf_coords_real_binned_grand = np.zeros((629,2,nbins),dtype='float64')
    avgMovie_binned_grand = np.empty((temp_window,spat_dims[0],spat_dims[1],nbins),dtype='float32')
    avgMovie_binned_grand[:] = np.nan
    sta_grads_binned_grand = np.empty((temp_window,spat_dims[0],spat_dims[1],nbins),dtype='float32')
    sta_grads_binned_grand[:] = np.nan
    sta_real_binned_grand = np.empty((temp_window_ds,spat_dims[0],spat_dims[1],nbins),dtype='float32')
    sta_real_binned_grand[:] = np.nan
    temp_win_gradsBin = np.arange(10,30)
    # ---- Gradient STRF binning
    for i in tqdm(range(len(idx_bin_edges)-1),desc='Gradient binning'):
        idx_totake = idx_sorted[idx_bin_edges[i]:idx_bin_edges[i+1]]
        # Metrics for the binned grads
        rf_params_grads_binned_grand[i,:] = np.nanmean(rf_params_grand[idx_totake,:],axis=0,keepdims=True)
        # Average the spatial and temporal RFs of the samples falling in this bin
        spatRF = np.nanmean(spatRF_grand[idx_totake,:,:],axis=0)
        tempRF = np.nanmean(tempRF_grand[idx_totake,:],axis=0)
        rf_coords,rf_fit_img,rf_params,_ = model.featureMaps.spatRF2DFit(spatRF,tempRF=0,sig_fac=3,rot=True,sta=0,tempRF_sig=False)
        mean_rfCent = np.nanmean(np.abs(rf_fit_img))
        spatRF = spatRF/mean_rfCent
        tempRF = tempRF*mean_rfCent
        rf_coords_grads_binned_grand[:,:,i] = rf_coords
        spatRF_grads_binned_grand[:,:,i] = spatRF
        tempRF_grads_binned_grand[:,i] = tempRF
        # sta_grads_binned_grand[:,:,:,i] = np.mean(f_grads[select_mdl][select_lightLevel]['grads'][select_rgc,np.sort(idx_totake),:,:,:],axis=0)
    tempRF_grads_binned_grand_norm = tempRF_grads_binned_grand/np.nanmax(tempRF_grads_binned_grand,axis=(0,1),keepdims=True)  # should maybe normalize later, after removing the last bin?
    # winSize_x = 20
    # winSize_y = 20
    # RF_midpoint_x = int(rf_params_grads_binned_grand[int(nbins/2),3])
    # RF_midpoint_y = int(rf_params_grads_binned_grand[int(nbins/2),4])
    # win_x = (np.max((round(RF_midpoint_x-0.5*winSize_x),0)),np.min((round(RF_midpoint_x+0.5*winSize_x),spatRF.shape[1]-1)))
    # win_y = (np.max((round(RF_midpoint_y-0.5*winSize_y),0)),np.min((round(RF_midpoint_y+0.5*winSize_y),spatRF.shape[0]-1)))
    # plt.imshow(spatRF);plt.plot(rf_coords[:,0],rf_coords[:,1],'r');plt.show()
    # vmin = spatRF_grads_binned_grand.min()
    # vmax = spatRF_grads_binned_grand.max()
    # b=4;plt.imshow(spatRF_grads_binned_grand[:,:,b],cmap='gray',vmin=vmin,vmax=vmax);plt.plot(rf_coords_grads_binned_grand[:,0,b],rf_coords_grads_binned_grand[:,1,b],'b');plt.xlim(win_x);plt.ylim(win_y)
    # idx=np.array([0,nbins-1]);plt.plot(rf_coords_grads_binned_grand[:,0,idx],rf_coords_grads_binned_grand[:,1,idx],'b');plt.xlim(win_x);plt.ylim(win_y);ax=plt.gca();ax.set_aspect('equal')
    # idx=np.array([2,3,4,5,6,7,8,9]);plt.plot(tempRF_grads_binned_grand_norm[:,idx]);plt.show()
    # ---- Real (data) STRF binning
    for i in tqdm(range(len(idx_bin_edges)-1),desc='Data STA'):
        idx_totake = idx_sorted[idx_bin_edges[i]:idx_bin_edges[i+1]]
        # idx_totake = np.arange(idx_bin_edges[i],idx_bin_edges[i+1])
        stim = data_alldsets[select_lightLevel_sta]['raw'].X[idx_totake,-temp_window:,rfExtractIdx_y[0]:rfExtractIdx_y[1],rfExtractIdx_x[0]:rfExtractIdx_x[1]]
        spikes_totake = data_alldsets[select_lightLevel_sta]['raw'].spikes[idx_totake,select_rgc_dataset]
        resp_totake = data_alldsets[select_lightLevel_sta]['raw'].y[idx_totake,select_rgc_dataset]
        print('Num spikes in bin %d: %d'%(i,np.sum(spikes_totake>0)))
        avg_stim = np.mean(stim,axis=0)
        if np.sum(spikes_totake>0)>200:
            # Perform reverse correlation
            stim_ds = stim[:,::dsFac]
            sta_data = model.featureMaps.getSTA_spikeTrain_simple(stim_ds,spikes_totake)
            scaleFac = np.nanmean(resp_totake)/np.var(stim)  # scale the STA into response units
            sta_data = sta_data * scaleFac
            idx_tempPeak_ds = int(idx_tempPeak/dsFac)
            spatRF = sta_data[idx_tempPeak_ds,:,:]  # spatial RF slice, used to then extract the temporal RF
            try:
                cent_idx_min_max = np.array([np.unravel_index(spatRF.argmin(), spatRF.shape),np.unravel_index(spatRF.argmax(), spatRF.shape)])
                min_max_spatRF = np.argmax(np.abs([spatRF.min(),spatRF.max()]))
                cent_idx = cent_idx_min_max[min_max_spatRF]
                tempRF = sta_data[:,cent_idx[0],cent_idx[1]]
                sign = np.sign(tempRF[idx_tempPeak_ds])
                if sign<0:
                    tempRF = tempRF*sign
                # Normalize the spatial RF by its mean so variations in gain are reflected purely in the temporal part
                mean_rfCent = np.nanmean(np.abs(spatRF))
                spatRF = spatRF/mean_rfCent
                tempRF = tempRF*mean_rfCent
            except Exception:
                tempRF = np.zeros(sta_data.shape[0])
                tempRF[:] = np.nan
            if np.sum(np.isfinite(spatRF))>0:
                rf_params_real_binned_grand[i,0] = np.sqrt(rf_params['sigma_x']**2+rf_params['sigma_y']**2)*sig_fac*2  # spatial size
                rf_params_real_binned_grand[i,1] = 180-rf_params['theta']  # angle
                rf_params_real_binned_grand[i,2] = np.sqrt(rf_params['x0']**2 + rf_params['y0']**2)  # spatial RF location (distance from origin)
                rf_params_real_binned_grand[i,3] = rf_params['x0']
                rf_params_real_binned_grand[i,4] = rf_params['y0']
                sta_real_binned_grand[:,:,:,i] = sta_data
                rf_coords_real_binned_grand[:,:,i] = rf_coords
                spatRF_real_binned_grand[:,:,i] = spatRF
                tempRF_real_binned_grand[:,i] = tempRF
    tempRF_real_binned_grand_norm = tempRF_real_binned_grand/np.nanmax(tempRF_real_binned_grand,axis=(0,1),keepdims=True)
    _ = gc.collect()
    # vmin = np.nanmin(spatRF_real_binned_grand)
    # vmax = np.nanmax(spatRF_real_binned_grand)
    # b=0;plt.imshow(spatRF_real_binned_grand[:,:,b],cmap='gray',vmin=vmin,vmax=vmax);plt.plot(rf_coords_real_binned_grand[:,0,b],rf_coords_real_binned_grand[:,1,b],'r');plt.xlim(win_x);plt.ylim(win_y);plt.show()
    # idx=np.array([0,nbins-1]);plt.plot(rf_coords_real_binned_grand[:,0,idx],rf_coords_real_binned_grand[:,1,idx],'r');plt.xlim(win_x);plt.ylim(win_y);ax=plt.gca();ax.set_aspect('equal');plt.show()
    # idx = np.array([2,3,4,5,6,7,8,9]);plt.plot(tempRF_real_binned_grand_norm[:,idx]);plt.show()
    # Compare the gradient-derived gains against the data-derived gains per bin
    gain_grads_binned = np.max(tempRF_grads_binned_grand_norm,axis=0)
    gain_real_binned = np.max(tempRF_real_binned_grand_norm,axis=0)
    plt.plot(gain_grads_binned,gain_real_binned,'o'); plt.ylabel('real'); plt.xlabel('grads'); plt.show()
    idx = np.array([0,1,2,3,4,5,6,7,8,9])
    txt_suptitle = '%s | %s (FEV=%02d%%) | Training: %s | Testing: %s | STA: %s'%(select_mdl,uname,perf_datasets[select_mdl][select_lightLevel]['fev_allUnits'][select_rgc_dataset]*100,dataset_model,select_lightLevel,select_lightLevel_sta)
    fig,axs = plt.subplots(1,2,figsize=(20,5))
    fig.suptitle(txt_suptitle)
    axs = np.ravel(axs)
    axs[0].plot(tempRF_grads_binned_grand_norm[::dsFac,idx])
    axs[0].set_title('gradients'); axs[0].set_xlabel('frames')
    axs[1].plot(tempRF_real_binned_grand_norm[:,idx])
    axs[1].set_title('data'); axs[1].set_xlabel('frames')
    dict_perUnit = dict(tempRF_grads_binned_grand_norm=tempRF_grads_binned_grand_norm,
                        tempRF_real_binned_grand_norm=tempRF_real_binned_grand_norm,
                        tempRF_grads_binned_grand=tempRF_grads_binned_grand,
                        tempRF_real_binned_grand=tempRF_real_binned_grand,
                        gain_grads_binned=gain_grads_binned,
                        gain_real_binned=gain_real_binned,
                        fev=perf_datasets[select_mdl][select_lightLevel]['fev_allUnits'][select_rgc_dataset]*100,
                        uname=uname)
    fname_results = os.path.join(path_save,'gain_analysis_ds.h5')
    if 'f' in locals():
        try:
            f.close()
        except Exception:
            pass
    # Save dict_perUnit to the h5 file
    with h5py.File(fname_results,'a') as f:
        grp_name = '/'+select_mdl+'/'+select_lightLevel+'/'+uname
        if grp_name in f:
            del f[grp_name]
        grp = f.create_group(grp_name)
        for key in list(dict_perUnit.keys()):
            grp.create_dataset(key,data=dict_perUnit[key])
# %% Load gain file
fname_gainFile = '/home/saad/postdoc_db/analyses/data_kiersten/monkey01/gradient_analysis/gain_analysis_ds.h5'
f = h5py.File(fname_gainFile,'r')
select_mdl = 'PRFR_CNN2D_RODS'  # CNN_2D_NORM # PRFR_CNN2D_RODS
select_lightLevel = 'scot-0.3-Rstar'
uname_gainFile = list(f[select_mdl][select_lightLevel].keys())  # ['on_mid_003', 'on_mid_004', 'on_mid_005', 'on_mid_006', 'on_mid_009', 'on_mid_011', 'on_mid_015', 'on_mid_016', 'on_mid_017', 'on_mid_018', 'on_mid_020']
temp_win = 40
temp_win_ds = int(temp_win/4)
nbins = 10
gain_grads_cnn = np.zeros((nbins,len(uname_gainFile))); gain_grads_cnn[:] = np.nan
gain_real_cnn = np.zeros((nbins,len(uname_gainFile))); gain_real_cnn[:] = np.nan
tempRF_grads_cnn = np.zeros((temp_win,nbins,len(uname_gainFile))); tempRF_grads_cnn[:] = np.nan
tempRF_real_cnn = np.zeros((temp_win_ds,nbins,len(uname_gainFile))); tempRF_real_cnn[:] = np.nan
fevs_cnn = np.zeros((len(uname_gainFile))); fevs_cnn[:] = np.nan
gain_grads_pr = np.zeros((nbins,len(uname_gainFile))); gain_grads_pr[:] = np.nan
gain_real_pr = np.zeros((nbins,len(uname_gainFile))); gain_real_pr[:] = np.nan
tempRF_grads_pr = np.zeros((temp_win,nbins,len(uname_gainFile))); tempRF_grads_pr[:] = np.nan
tempRF_real_pr = np.zeros((temp_win_ds,nbins,len(uname_gainFile))); tempRF_real_pr[:] = np.nan
fevs_pr = np.zeros((len(uname_gainFile))); fevs_pr[:] = np.nan
for u in range(len(uname_gainFile)):
    uname = uname_gainFile[u]
    gain_grads_cnn[:,u] = np.array(f['CNN_2D_NORM'][select_lightLevel][uname]['gain_grads_binned'])
    gain_real_cnn[:,u] = np.array(f['CNN_2D_NORM'][select_lightLevel][uname]['gain_real_binned'])
    # tempRF_grads_cnn[:,:,u] = np.array(f['CNN_2D_NORM'][select_lightLevel][uname]['tempRF_grads_binned_grand'][-temp_win:])
    # tempRF_real_cnn[:,:,u] = np.array(f['CNN_2D_NORM'][select_lightLevel][uname]['tempRF_real_binned_grand'][-temp_win_ds:])
    fevs_cnn[u] = np.array(f['CNN_2D_NORM'][select_lightLevel][uname]['fev'])
    gain_grads_pr[:,u] = np.array(f['PRFR_CNN2D_RODS'][select_lightLevel][uname]['gain_grads_binned'])
    gain_real_pr[:,u] = np.array(f['PRFR_CNN2D_RODS'][select_lightLevel][uname]['gain_real_binned'])
    tempRF_grads_pr[:,:,u] = np.array(f['PRFR_CNN2D_RODS'][select_lightLevel][uname]['tempRF_grads_binned_grand'][-temp_win:])
    tempRF_real_pr[:,:,u] = np.array(f['PRFR_CNN2D_RODS'][select_lightLevel][uname]['tempRF_real_binned_grand'][-temp_win_ds:])
    fevs_pr[u] = np.array(f['PRFR_CNN2D_RODS'][select_lightLevel][uname]['fev'])
f.close()
# MSE between gradient-derived and data-derived gains, per unit, across bins
binsToTake = np.array([0,1,2,3,4,5,6,7,8,9])
mse_cnn = np.nanmean((gain_grads_cnn[binsToTake]-gain_real_cnn[binsToTake])**2,axis=0)
mse_pr = np.nanmean((gain_grads_pr[binsToTake]-gain_real_pr[binsToTake])**2,axis=0)
idx_fev_CNN_G_PR = fevs_cnn>=fevs_pr  # units where the CNN outperforms the PR model
idx_fev_PR_G_CNN = ~idx_fev_CNN_G_PR
max_axis = np.max((mse_cnn.max(),mse_pr.max()))+.02
txt_title = 'Training: %s | Testing: %s | N=%d RGCs'%(dataset_model,select_lightLevel,len(uname_gainFile))
fig,axs = plt.subplots(1,1,figsize=(5,5))
axs = np.ravel(axs)
axs[0].plot(mse_cnn[idx_fev_PR_G_CNN],mse_pr[idx_fev_PR_G_CNN],'ro',label='PR>CNN')
axs[0].plot(mse_cnn[idx_fev_CNN_G_PR],mse_pr[idx_fev_CNN_G_PR],'bo',label='CNN>PR')
axs[0].legend()
axs[0].plot([0,1],[0,1],'--k')
axs[0].set_xlim(0,max_axis)
axs[0].set_ylim(0,max_axis)
axs[0].set_xlabel('MSE_CNN'); axs[0].set_ylabel('MSE_PR')
axs[0].set_title(txt_title)
# %% BELOW sections are not relevant for the paper
# %% SPAT RF BINNING
"""
For each cell, bin the images by gradient strength / temporal filter strength,
and see whether we can do the same binning with the real data.
"""