# dat_loader_simple.py
"""
Simplified Data Loading
"""
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler
from torch.utils.data.distributed import DistributedSampler
import torch
from torch.nn import functional as F
from pathlib import Path
from _init_stuff import Fpath, Arr, yaml, DF, ForkedPdb
from yacs.config import CfgNode as CN
import pandas as pd
import h5py
import numpy as np
import json
import copy
from typing import Dict
from munch import Munch
from trn_utils import DataWrap
import ast
import pickle
from contrastive_sampling import create_similar_list, create_random_list
from mdl_srl_utils import combine_first_ax
from trn_utils import get_dataloader
torch.multiprocessing.set_sharing_strategy('file_system')
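# Using the file-system sharing strategy avoids "too many open files"
# errors when many DataLoader workers pass tensors between processes.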
class AnetEntDataset(Dataset):
"""
Dataset class adapted from
https://github.com/facebookresearch/grounded-video-description/blob/master/misc/dataloader_anet.py#L27
This is basically the AE loader.
"""
def __init__(self, cfg: CN, ann_file: Fpath, split_type: str = 'train',
comm: Dict = None):
self.cfg = cfg
# Common stuff that needs to be passed around
if comm is not None:
assert isinstance(comm, (dict, Munch))
self.comm = Munch(comm)
else:
self.comm = Munch()
self.split_type = split_type
self.ann_file = Path(ann_file)
assert self.ann_file.suffix == '.csv'
self.set_args()
self.load_annotations()
# self.create_glove_stuff()
h5_proposal_file = h5py.File(
self.proposal_h5, 'r', driver='core')
self.num_proposals = h5_proposal_file['dets_num'][:]
self.label_proposals = h5_proposal_file['dets_labels'][:]
self.itemgetter = getattr(self, 'simple_item_getter')
self.test_mode = (split_type != 'test')
self.after_init()
def after_init(self):
pass
def set_args(self):
"""
Define the arguments to be used from the cfg
"""
# NOTE: These are changed at extended_config/post_proc_config
dct = self.cfg.ds[f'{self.cfg.ds.exp_setting}']
self.proposal_h5 = Path(dct['proposal_h5'])
self.feature_root = Path(dct['feature_root'])
# Max proposals to be considered
# By default it is 10 * 100
self.num_frms = self.cfg.ds.num_sampled_frm
self.num_prop_per_frm = dct['num_prop_per_frm']
self.comm.num_prop_per_frm = self.num_prop_per_frm
self.max_proposals = self.num_prop_per_frm * self.num_frms
# Assert h5 file to read from exists
assert self.proposal_h5.exists()
# Assert region features exists
assert self.feature_root.exists()
# Assert rgb, motion features exists
self.seg_feature_root = Path(self.cfg.ds.seg_feature_root)
assert self.seg_feature_root.exists()
# Which proposals to be included
self.prop_thresh = self.cfg.misc.prop_thresh
self.exclude_bgd_det = self.cfg.misc.exclude_bgd_det
# Assert raw caption file (from activity net captions) exists
self.raw_caption_file = Path(self.cfg.ds.anet_cap_file)
assert self.raw_caption_file.exists()
# Assert act ent caption file with bbox exists
self.anet_ent_annot_file = Path(self.cfg.ds.anet_ent_annot_file)
assert self.anet_ent_annot_file.exists()
# Assert word vocab files exist
self.dic_anet_file = Path(self.cfg.ds.anet_ent_split_file)
assert self.dic_anet_file.exists()
# Max gt box to consider
# should consider all, set high
self.max_gt_box = self.cfg.ds.max_gt_box
# temporal attention size
self.t_attn_size = self.cfg.ds.t_attn_size
# Sequence length
self.seq_length = self.cfg.ds.max_seq_length
def load_annotations(self):
"""
Process the annotation file.
"""
# Load annotation files
self.annots = pd.read_csv(self.ann_file)
# Load raw captions
with open(self.raw_caption_file) as f:
self.raw_caption = json.load(f)
# Load anet bbox
with open(self.anet_ent_annot_file) as f:
self.anet_ent_captions = json.load(f)
# Needs to be exported as well
# Load dictionaries
with open(self.dic_anet_file) as f:
self.comm.dic_anet = json.load(f)
# Get detections to index
self.comm.dtoi = {w: i + 1 for w, i in self.comm.dic_anet['wtod'].items()}
self.comm.itod = {i: w for w, i in self.comm.dtoi.items()}
self.comm.itow = self.comm.dic_anet['ix_to_word']
self.comm.wtoi = {w: i for i, w in self.comm.itow.items()}
self.comm.vocab_size = len(self.comm.itow) + 1
self.comm.detect_size = len(self.comm.itod)
def __len__(self):
return len(self.annots)
# return 50
def __getitem__(self, idx: int):
return self.itemgetter(idx)
def pad_words_with_vocab(
self, out_list,
voc=None, pad_len=-1, defm=[1]):
"""
Input is a list.
If curr_len < pad_len: pad the remaining entries with the default value.
If instead curr_len > pad_len: trim the input.
"""
curr_len = len(out_list)
if pad_len == -1 or curr_len == pad_len:
return out_list
else:
if curr_len > pad_len:
return out_list[:pad_len]
else:
if voc is not None and hasattr(voc, 'itos'):
assert voc.itos[1] == '<pad>'
out_list += defm * (pad_len - curr_len)
return out_list
def get_props(self, index: int):
"""
Returns the padded proposals, padded mask, number of proposals
by reading the h5 files
"""
num_proposals = int(self.num_proposals[index])
label_proposals = self.label_proposals[index]
proposals = copy.deepcopy(label_proposals[:num_proposals, :])
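# Each proposal row has 7 values. Inferred from the usage here and in
# get_frm_mask: cols 0-3 are the box coordinates, col 4 the frame index,
# col 5 the detected class label (0 = background), col 6 the confidence.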
# proposal mask to filter out low-confidence proposals or backgrounds
# mask is 1 if proposal is included
pnt_mask = (proposals[:, 6] >= self.prop_thresh)
if self.exclude_bgd_det:
pnt_mask &= (proposals[:, 5] != 0)
num_props = min(proposals.shape[0], self.max_proposals)
padded_props = self.pad_words_with_vocab(
proposals.tolist(), pad_len=self.max_proposals, defm=[[0]*7])
padded_mask = self.pad_words_with_vocab(
pnt_mask.tolist(), pad_len=self.max_proposals, defm=[0])
return np.array(padded_props), np.array(padded_mask), num_props
def get_features(self, vid_seg_id: str, num_proposals: int, props):
"""
Returns the region features, rgb-motion features
"""
vid_id_ix, seg_id_ix = vid_seg_id.split('_segment_')
seg_id_ix = str(int(seg_id_ix))
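# e.g. a hypothetical vid_seg_id 'v_abc_segment_03' splits into
# vid_id_ix = 'v_abc' and seg_id_ix = '3'; the segment feature files
# below are keyed by vid_id_ix[2:], i.e. with the leading 'v_' stripped.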
region_feature_file = self.feature_root / f'{vid_seg_id}.npy'
region_feature = np.load(region_feature_file)
region_feature = region_feature.reshape(
-1,
region_feature.shape[2]
).copy()
assert num_proposals == region_feature.shape[0]
if self.cfg.misc.add_prop_to_region:
region_feature = np.concatenate(
[region_feature, props[:num_proposals, :5]],
axis=1
)
# load the frame-wise segment feature
seg_rgb_file = self.seg_feature_root / f'{vid_id_ix[2:]}_resnet.npy'
seg_motion_file = self.seg_feature_root / f'{vid_id_ix[2:]}_bn.npy'
assert seg_rgb_file.exists() and seg_motion_file.exists()
seg_rgb_feature = np.load(seg_rgb_file)
seg_motion_feature = np.load(seg_motion_file)
seg_feature_raw = np.concatenate(
(seg_rgb_feature, seg_motion_feature), axis=1)
return region_feature, seg_feature_raw
def get_frm_mask(self, proposals, gt_bboxs):
"""
1 where a proposal and a gt box do NOT match (different frames),
0 where they match (same frame).
We are basically matching the frame indices,
that is, 1 where they belong to different frames,
0 where they belong to the same frame.
In mdl_bbox_utils.py -> bbox_overlaps_batch the inverted
mask (~frm_mask) is used.
(We have been tricked, we have been backstabbed,
quite possibly bamboozled)
"""
# proposals: num_pps
# gt_bboxs: num_box
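# e.g. proposal frame ids [0, 0, 1] and gt frame ids [0, 1] give
# [[0, 1], [0, 1], [1, 0]] (1 = different frame, 0 = same frame)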
num_pps = proposals.shape[0]
num_box = gt_bboxs.shape[0]
return (np.tile(proposals.reshape(-1, 1), (1, num_box)) != np.tile(
gt_bboxs, (num_pps, 1)))
def get_seg_feat_for_frms(self, seg_feats, timestamps, duration, idx=None):
"""
Given seg features of shape num_frms x 3072
converts to 10 x 3072
Here 10 is the number of frames used by the mdl
timestamps contains the start and end time of the clip
duration is the total length of the video
note that end - st != dur, since one refers to the clip
and the other to the whole video.
Additionally returns average over the timestamps
"""
# ctx is the context of the optical flow used
# 10 means 5 seconds previous, to 5 seconds after
# This is because optical flow is calculated at
# 2fps
ctx = self.cfg.misc.ctx_for_seg_feats
if timestamps[0] > timestamps[1]:
# the timestamps are swapped in the AnetCaptions dataset
# since only 2 segments have this problem, just swap them back
# print(idx, 'why')
timestamps = timestamps[1], timestamps[0]
st_time = timestamps[0]
end_time = timestamps[1]
duration_clip = end_time - st_time
num_frms = seg_feats.shape[0]
frm_ind = np.arange(0, 10)
frm_time = st_time + (duration_clip / 10) * (frm_ind + 0.5)
# *2 because of sampling at 2fps
frm_index_in_seg_feat = np.minimum(np.maximum(
(frm_time*2).astype(np.int_)-1, 0), num_frms-1)
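# Illustration (hypothetical numbers): a clip from 4.0s to 9.0s gives
# frm_time = 4.25, 4.75, ..., 8.75 (mid-points of 10 equal sub-segments),
# each mapped to a 2 fps feature index via frm_time*2 - 1 and then
# clamped to [0, num_frms-1].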
st_indices = np.maximum(frm_index_in_seg_feat - ctx - 1, 0)
end_indices = np.minimum(frm_index_in_seg_feat + ctx + 1, num_frms)
if st_indices[0] != end_indices[-1]:
try:
seg_feats_frms_glob = seg_feats[st_indices[0]:end_indices[-1]].mean(
axis=0)
except RuntimeWarning:
import pdb
pdb.set_trace()
else:
print(f'clip duration: {duration_clip}')
seg_feats_frms_glob = seg_feats[st_indices[0]]
assert np.all(end_indices - st_indices > 0)
try:
if ctx != 0:
seg_feats_frms = np.vstack([
seg_feats[sti:endi, :].mean(axis=0)
for sti, endi in zip(st_indices, end_indices)])
else:
seg_feats_frms = seg_feats[frm_index_in_seg_feat]
except RuntimeWarning:
import pdb
pdb.set_trace()
pass
return seg_feats_frms, seg_feats_frms_glob
def get_gt_annots(self, caption_dct: Dict, idx: int):
gt_bboxs = torch.tensor(caption_dct['bbox']).float()
gt_frms = torch.tensor(caption_dct['frm_idx']).unsqueeze(-1).float()
assert len(gt_bboxs) == len(gt_frms)
num_box = len(gt_bboxs)
gt_bboxs_t = torch.cat([gt_bboxs, gt_frms], dim=-1)
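# each gt box is now a 5-dim row: 4 box coordinates followed by frm_idx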
padded_gt_bboxs = self.pad_words_with_vocab(
gt_bboxs_t.tolist(),
pad_len=self.max_gt_box,
defm=[[0]*5]
)
padded_gt_bboxs_mask_list = [1] * num_box
padded_gt_box_mask = self.pad_words_with_vocab(
padded_gt_bboxs_mask_list,
pad_len=self.max_gt_box,
defm=[0]
)
return {
'padded_gt_bboxs': np.array(padded_gt_bboxs),
'padded_gt_box_mask': np.array(padded_gt_box_mask),
'num_box': num_box
}
def simple_item_getter(self, idx: int):
"""
Basically, this returns stuff for the
vid_seg_id obtained from the idx
"""
row = self.annots.iloc[idx]
vid_id = row['vid_id']
seg_id = str(row['seg_id'])
vid_seg_id = row['id']
ix = row['Index']
# Get the padded proposals, proposal masks and the number of proposals
padded_props, pad_pnt_mask, num_props = self.get_props(ix)
# Get the region features and the segment features
# Region features are for spatial stuff
# Segment features are for temporal stuff
region_feature, seg_feature_raw = self.get_features(
vid_seg_id, num_proposals=num_props, props=padded_props
)
# not accurate, with minor misalignments
# Get the time stamp information for each segment
timestamps = self.raw_caption[vid_id]['timestamps'][int(seg_id)]
# Get the durations for each time stamp
dur = self.raw_caption[vid_id]['duration']
# Get the number of frames in the segment
num_frm = seg_feature_raw.shape[0]
# basically time stamps.
# Not really used, kept for legacy reasons
sample_idx = np.array(
[
np.round(num_frm*timestamps[0]*1./dur),
np.round(num_frm*timestamps[1]*1./dur)
]
)
sample_idx = np.clip(np.round(sample_idx), 0,
self.t_attn_size).astype(int)
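# e.g. (hypothetical) num_frm=200, timestamps=[10, 30], dur=100 gives
# sample_idx = [20, 60] before clipping to [0, t_attn_size]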
# Get segment features based on the number of frames used
seg_feature = np.zeros((self.t_attn_size, seg_feature_raw.shape[1]))
seg_feature[:min(self.t_attn_size, num_frm)
] = seg_feature_raw[:self.t_attn_size]
# gives both local and global features.
# In model can choose either one
seg_feature_for_frms, seg_feature_for_frms_glob = (
self.get_seg_feat_for_frms(
seg_feature_raw, timestamps, dur, idx)
)
# get gt annotations
# Get the AE annotations
caption_dct = self.anet_ent_captions[vid_id]['segments'][seg_id]
# get the groundtruth_box annotations
gt_annot_dict = self.get_gt_annots(caption_dct, idx)
# extract the padded gt boxes
pad_gt_bboxs = gt_annot_dict['padded_gt_bboxs']
# store the number of gt boxes
num_box = gt_annot_dict['num_box']
# frame mask is an NxM matrix indicating which proposals
# lie in the same frame as the groundtruth boxes
frm_mask = self.get_frm_mask(
padded_props[:num_props, 4], pad_gt_bboxs[:num_box, 4]
)
# pad it
pad_frm_mask = np.ones((self.max_proposals, self.max_gt_box))
pad_frm_mask[:num_props, :num_box] = frm_mask
pad_pnt_mask = torch.tensor(pad_pnt_mask).long()
# pad region features
pad_region_feature = np.zeros(
(self.max_proposals, region_feature.shape[1]))
pad_region_feature[:num_props] = region_feature[:num_props]
out_dict = {
# segment features
'seg_feature': torch.from_numpy(seg_feature).float(),
# local segment features
'seg_feature_for_frms': torch.from_numpy(
seg_feature_for_frms).float(),
# global segment features
'seg_feature_for_frms_glob': torch.from_numpy(
seg_feature_for_frms_glob).float(),
# number of proposals
'num_props': torch.tensor(num_props).long(),
# number of groundtruth boxes
'num_box': torch.tensor(num_box).long(),
# padded proposals
'pad_proposals': torch.tensor(padded_props).float(),
# padded groundtruth boxes
'pad_gt_bboxs': torch.tensor(pad_gt_bboxs).float(),
# padded groundtruth mask, not used, kept for legacy
'pad_gt_box_mask': torch.tensor(
gt_annot_dict['padded_gt_box_mask']).byte(),
# segment id, not used, kept for legacy
'seg_id': torch.tensor(int(seg_id)).long(),
# idx and ann_idx are the same;
# both are the index of the vid_seg in the ann_file
'idx': torch.tensor(idx).long(),
'ann_idx': torch.tensor(idx).long(),
# padded region features
'pad_region_feature': torch.tensor(pad_region_feature).float(),
# padded frame mask
'pad_frm_mask': torch.tensor(pad_frm_mask).byte(),
# padded proposal mask
'pad_pnt_mask': pad_pnt_mask.byte(),
# sample number, not used, legacy
'sample_idx': torch.tensor(sample_idx).long(),
}
return out_dict
class AnetVerbDataset(AnetEntDataset):
"""
The basic ASRL dataset.
All outputs for one query
"""
def fix_via_ast(self, df: DF):
"""
The ASRL csv has columns containing lists,
which are read in as strings,
so [1,2] is read as "[1,2]".
This is fixed using ast.literal_eval,
which converts the string to a list/dict
depending on the input.
"""
for k in df.columns:
first_word = df.iloc[0][k]
if isinstance(first_word, str) and (first_word[0] in '[{'):
df[k] = df[k].apply(
lambda x: ast.literal_eval(x))
return df
def __len__(self):
return len(self.srl_annots)
def pidx2list(self, pidx):
"""
Converts process_idx2 to single list
Just a convenience function required
because in some places it is a list
and in others it isn't
"""
lst = []
for p1 in pidx:
if not isinstance(p1, list):
p1 = [p1]
for p2 in p1:
if not isinstance(p2, list):
p2 = [p2]
for p3 in p2:
assert not isinstance(p3, list)
lst.append(p3)
return lst
def get_srl_anns(self, srl_row, out=None):
"""
Outputs a dictionary of everything the SRL needs:
1. tags
2. args with st, end ixs
3. box ind matching
This is a pretty detailed function, and
really requires patience to understand.
I know I know. Forgive me.
"""
srl_row = copy.deepcopy(srl_row)
def word_to_int_vocab(words, voc, pad_len=-1):
"""
A convenience function to convert words
into their indices given a vocab.
Using Anet Vocab only.
Optionally, pad answers as well
"""
out_list = []
if hasattr(voc, 'stoi'):
vocs = voc.stoi
else:
vocs = voc
for w in words:
if w in vocs:
out_list.append(int(vocs[w]))
else:
if hasattr(voc, 'UNK'):
unk_word = voc.UNK
else:
unk_word = 'UNK'
out_list.append(int(vocs[unk_word]))
curr_len = len(out_list)
return self.pad_words_with_vocab(out_list,
voc, pad_len=pad_len), curr_len
# srl args to worry about
vis_set = self.cfg.ds.include_srl_args
# want to get the arguments and the word indices
# req_pat_ix: [['ARG0', [0,1,2,3]], ...]
# srl_args = ['ARG0', 'V', ...]
# srl_words_inds = [[0,1,2,3], ...]
srl_args, srl_words_inds = [list(t) for t in zip(*srl_row.req_pat_ix)]
# simple mask to keep only those args present in vis_set
# also pad them
srl_args_visual_msk = self.pad_words_with_vocab(
[s in vis_set for s in srl_args],
pad_len=self.srl_arg_len, defm=[0]
)
# get the words from their indices
# convert to words
# if original sentence is 'A child playing tennis'
# [[0,1], ...] -> [['A', 'child'],...]
srl_arg_words = [[srl_row.words[ix]
for ix in y] for y in srl_words_inds]
# Tags are converted via tag vocab
# not used, kept for legacy
tag_seq = [srl_row.tags[ix] for y in srl_words_inds for ix in y]
tag_word_ind, _ = word_to_int_vocab(
# srl_row.tags,
tag_seq,
self.arg_vocab['arg_tag_vocab'],
pad_len=self.seq_length
)
# Argument Names (ARG0/V/) are converted to indices
# Max num of arguments is kept to be self.srl_arg_len
# the fallback below is needed in very few cases
assert 'V' in srl_args
verb_ind_in_srl = srl_args.index('V')
if not verb_ind_in_srl <= self.srl_arg_len - 1:
verb_ind_in_srl = 0
# Use the argument vocab created earlier
# convert the arguments to indices using the vocab
srl_arg_inds, srl_arg_len = word_to_int_vocab(
srl_args, self.arg_vocab['arg_vocab'],
pad_len=self.srl_arg_len
)
if srl_arg_len > self.srl_arg_len:
srl_arg_len = self.srl_arg_len
# defm is the default matrix to be used
defm = tuple([[1] * self.seq_length, 0])
# convert the words to their indices using the vocab
# for every argument
# the vocab here is self.comm.wtoi obtained from AE
srl_arg_words_ind_length = self.pad_words_with_vocab(
[word_to_int_vocab(
srl_arg_w, self.comm.wtoi, pad_len=self.seq_length) for
srl_arg_w in srl_arg_words],
pad_len=self.srl_arg_len, defm=[defm]
)
# Unzip to get the word indices and their lengths for
# each argument separately
srl_arg_words_ind, srl_arg_words_length = [
list(t) for t in zip(*srl_arg_words_ind_length)]
# This is used to convert
# [[ARG0: w1,w2], [ARG1: w5,..]] ->
# [w1,w2,w5]
# Basically, convert
# [0] 0,1 -> 0,1
# [1] 0,1 -> 40, 41
# and so on
# Finally, use this with index_select
# in the mdl part
srl_arg_word_list = [
torch.arange(0+st, 0+st+wlen)
for st, wlen in zip(
range(
0,
self.seq_length*self.srl_arg_len,
self.seq_length), srl_arg_words_length)
]
# Concat above list
srl_arg_words_list = torch.cat(srl_arg_word_list, dim=0).tolist()
# Create the mask to be used with index select
srl_arg_words_mask = self.pad_words_with_vocab(
srl_arg_words_list, pad_len=self.seq_length, defm=[-1]
)
# Get the start and end positions
# these are used to retrieve
# LSTM outputs of the sentence
# to the argument vectors
srl_arg_word_list_tmp = [
0] + torch.cumsum(
torch.tensor(srl_arg_words_length),
dim=0).tolist()
srl_arg_words_capture = [
(min(x, self.seq_length-1), min(y-1, self.seq_length-1))
if wlen > 0 else (0, 0)
for x, y, wlen in zip(
srl_arg_word_list_tmp[:-1],
srl_arg_word_list_tmp[1:],
srl_arg_words_length
)
]
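# e.g. argument word lengths [2, 3, 0] give capture points
# [(0, 1), (2, 4), (0, 0)] (start/end offsets into the flattened sentence)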
# This is used to retrieve in argument form from
# the sentence form
# Basically, [w1,w2,w5] -> [[ARG0: w1,w2], [ARG1: w5]]
# Restrict to max len because scatter is used later
srl_arg_words_map_inv = [
y_ix for y_ix, y in enumerate(
srl_words_inds[:self.srl_arg_len]) for ix in y]
# Also, pad it
srl_arg_words_map_inv = self.pad_words_with_vocab(
srl_arg_words_map_inv,
pad_len=self.seq_length,
defm=[0]
)
# The following creates a binary mask for the sequence_length
# [1] * seq_cnt for every ARG row
# This is applied to the scatter output
defm = [[0] * self.seq_length]
seq_cnt = sum(srl_arg_words_length)
srl_arg_words_binary_mask = self.pad_words_with_vocab(
[self.pad_words_with_vocab(
[1]*seq_cnt, pad_len=self.seq_length, defm=[0])
for srl_arg_w in srl_arg_words],
pad_len=self.srl_arg_len, defm=defm)
# Get the set of visual words
vis_idxs_set = set(self.pidx2list(srl_row.process_idx2))
# Create a map for getting which are the visual words
srl_arg_words_conc_ix = [ix for y in srl_words_inds for ix in y]
# Create the binary mask
srl_vis_words_binary_mask = self.pad_words_with_vocab(
[1 if srl_vw1 in vis_idxs_set else 0
for srl_vw1 in srl_arg_words_conc_ix],
pad_len=self.seq_length, defm=[0])
# The following are used to map the gt boxes
# The first is the srl argument, followed by an
# indicator of whether the box is valid or not;
# the third is, if valid, which boxes to look at
srl_args, srl_arg_box_indicator, srl_arg_box_inds = [
list(t) for t in zip(*srl_row.req_cls_pats_mask)
]
# srl boxes, and their lengths are stored in a list
srl_boxes = []
srl_boxes_lens = []
for s1_ind, s1 in enumerate(srl_arg_box_inds):
mult = min(
len(s1),
self.box_per_srl_arg
) if srl_arg_box_indicator[s1_ind] == 1 else 0
s11 = [x if x_ix < self.box_per_srl_arg else 0 for x_ix,
x in enumerate(s1)]
srl_boxes.append(self.pad_words_with_vocab(
s11, pad_len=self.box_per_srl_arg, defm=[0]))
srl_boxes_lens.append(self.pad_words_with_vocab(
[1]*mult, pad_len=self.box_per_srl_arg, defm=[0]))
# They are then padded
srl_boxes = self.pad_words_with_vocab(
srl_boxes,
pad_len=self.srl_arg_len,
defm=[[0]*self.box_per_srl_arg]
)
srl_boxes_lens = self.pad_words_with_vocab(
srl_boxes_lens,
pad_len=self.srl_arg_len,
defm=[[0]*self.box_per_srl_arg]
)
# An indicator of whether the boxes are valid
srl_arg_boxes_indicator = self.pad_words_with_vocab(
srl_arg_box_indicator, pad_len=self.srl_arg_len, defm=[0])
out_dict = {
# Tags are indexed (B-V -> 4)
'srl_tag_word_ind': torch.tensor(tag_word_ind).long(),
# Tag word len available elsewhere, hence removed
# 'tag_word_len': torch.tensor(tag_word_len).long(),
# 1 if arg is in ARG1-2/LOC else 0
'srl_args_visual_msk': torch.tensor(srl_args_visual_msk).long(),
# ARGs are indexed (ARG0 -> 4, V -> 2)
'srl_arg_inds': torch.tensor(srl_arg_inds).long(),
# How many args are considered (ARG0, V,ARG1, ARGM), would be 4
'srl_arg_len': torch.tensor(srl_arg_len).long(),
# the above but in mask format
'srl_arg_inds_msk': torch.tensor(
[1] * srl_arg_len + [0]*(self.srl_arg_len - srl_arg_len)
).long(),
# Where the verb is located, in prev eg, it would be 1
'verb_ind_in_srl': torch.tensor(verb_ind_in_srl).long(),
# num_srl_args x seq_len: for each srl_arg, what is the seq
# so ARG0: The woman -> [[1946, 4307, ...],...]
'srl_arg_words_ind': torch.tensor(srl_arg_words_ind).long(),
# The corresponding lengths of each num_srl
'srl_arg_words_length': torch.tensor(srl_arg_words_length).long(),
# num_srl_args x seq_len, 1s upto the seq_len of the whole
# srl_sent: This is used in scatter operation
'srl_arg_words_binary_mask': torch.tensor(
srl_arg_words_binary_mask).long(),
# Similar to previous, but 1s only at places
# which are visual words. Used for scatter + GVD
'srl_vis_words_binary_mask': torch.tensor(
srl_vis_words_binary_mask).long(),
# seq_len, but contains in the indices to be gathered
# from num_srl_args x seq_len -> num_srl_args*seq_len
# via index_select
'srl_arg_word_mask': torch.tensor(srl_arg_words_mask).long(),
# seq_len basically
'srl_arg_word_mask_len': torch.tensor(min(sum(
srl_arg_words_length), self.seq_length)).long(),
# containing start and end points of the words to be collected
'srl_arg_words_capture': torch.tensor(srl_arg_words_capture).long(),
# used scatter + GVD
'srl_arg_words_map_inv': torch.tensor(srl_arg_words_map_inv).long(),
# box indices in gt boxes
'srl_boxes': torch.tensor(srl_boxes).long(),
# mask on which of them to choose
'srl_boxes_lens': torch.tensor(srl_boxes_lens).long(),
'srl_arg_boxes_mask': torch.tensor(srl_arg_boxes_indicator).long()
}
return out_dict
def collate_dict_list(self, dict_list, pad_len=None):
"""
Convert List[Dict[key, val]] -> Dict[key, List[val]]
Also, pad so that we can obtain Dict[key, tensor]
"""
out_dict = {}
keys = list(dict_list[0].keys())
num_dl = len(dict_list)
if pad_len is None:
pad_len = self.max_srl_in_sent
for k in keys:
dl_list = [dl[k] for dl in dict_list]
dl_list_pad = self.pad_words_with_vocab(
dl_list,
pad_len=pad_len, defm=[dl_list[0]])
out_dict[k] = torch.stack(dl_list_pad)
return out_dict, num_dl
def sent_item_getter(self, idx):
"""
Get one vid_seg at a time, with possibly multiple verbs.
Basically, input is a vid_seg, which may contain
multiple verbs.
No longer used, kept for legacy
"""
ann_ind, srl_rows = self.srl_annots[idx]
out = self.simple_item_getter(ann_ind)
out_dict_list = [self.get_srl_anns(srl_rows.iloc[ix], out)
for ix in range(len(srl_rows))]
srl_row_indices = self.pad_words_with_vocab(
srl_rows.index.tolist(),
pad_len=self.max_srl_in_sent)
out_dict, num_verbs = self.collate_dict_list(out_dict_list)
out_dict['num_verbs'] = torch.tensor(num_verbs).long()
out_dict['ann_idx'] = torch.tensor(ann_ind).long()
out_dict['sent_idx'] = torch.tensor(idx).long()
out_dict['srl_verb_idxs'] = torch.tensor(srl_row_indices).long()
out.update(out_dict)
return out
def get_for_one_verb(self, srl_row, idx, out=None):
"""
One ASRL index, not used, kept for legacy
"""
out_dict_list = [self.get_srl_anns(srl_row, out)]
out_dict, num_verbs = self.collate_dict_list(out_dict_list)
out_dict['num_verbs'] = torch.tensor(num_verbs).long()
out_dict['ann_idx'] = torch.tensor(srl_row.ann_ind).long()
out_dict['sent_idx'] = torch.tensor(idx).long()
out_dict['srl_verb_idxs'] = torch.tensor([idx]).long()
return out_dict
def verb_item_getter(self, idx):
"""
get verb items, one at a time
kept for legacy
"""
srl_row = self.srl_annots.loc[idx]
out = self.simple_item_getter(srl_row.ann_ind)
out_dict = self.get_for_one_verb(srl_row, idx, out)
out.update(out_dict)
return out
class AV_CS:
"""
Basically performs CS with SEP/TEMP/SPAT
It is kept as a separate class
This allows for modularity, and one could swap in
a different parent dataset class for a different dataset
"""
def __len__(self):
return len(self.srl_annots)
def after_init(self):
"""
Select the SRL annotation file to choose
As well as the dictionary for CS
"""
if self.split_type == 'train':
srl_annot_file = self.cfg.ds.trn_ds4_inds
arg_dict_file = self.cfg.ds.trn_ds4_dicts
elif self.split_type == 'valid' or self.split_type == 'test':
srl_annot_file = self.cfg.ds.val_ds4_inds
arg_dict_file = self.cfg.ds.val_ds4_dicts
else:
raise NotImplementedError
# Read the file
self.srl_annots = pd.read_csv(srl_annot_file)
assert hasattr(self, 'srl_annots')
# Convert columns to List/Dict
self.srl_annots = self.fix_via_ast(self.srl_annots)
# Open the arg dict for CS
with open(arg_dict_file) as f:
self.arg_dicts = json.load(f)
# for now, we only consider the case
# with one verb at a time
self.max_srl_in_sent = 1
# In training, allow CS, Random,
# or CS+Random sampling
# The last one doesn't make sense in Val/Test
if self.split_type == 'train':
self.sample_type = self.cfg.ds.trn_sample
assert self.sample_type in set(['ds4', 'random', 'ds4_random'])
elif self.split_type == 'valid' or self.split_type == 'test':
self.sample_type = self.cfg.ds.val_sample
assert self.sample_type in set(['ds4', 'random'])
else:
raise NotImplementedError
# Use sample type to decide which functions to use
if self.sample_type == 'random':
self.more_idx_collector = getattr(self, 'get_random_more_idx')
elif self.sample_type == 'ds4':
self.more_idx_collector = getattr(self, 'get_cs_more_idxs')
elif self.sample_type == 'ds4_random':
self.more_idx_collector = getattr(
self, 'get_cs_and_random_more_idx')
else:
raise NotImplementedError
# Number of Videos to Use for CS
if self.split_type == 'train':
nvids_sample = self.cfg.ds.trn_num_vid_sample
elif self.split_type in set(['valid', 'test']):
nvids_sample = self.cfg.ds.val_num_vid_sample
else:
raise NotImplementedError
# set number of videos to use
self.cs_nvids_sample = nvids_sample
# itemcollector basically collects
# nvid samples
self.itemcollector = getattr(
self, 'verb_item_getter_nvid'
)
# depending on conc_type choose the
# __getitem__ function
# append_everywhere is only used for SEP
# which appends the lang stuff to each sample
# this makes the code cleaner
# if svsq, then do the same as sep,
# but set nvids_sample = 1
if self.cfg.ds.conc_type == 'spat':
self.itemgetter = getattr(
self, 'verb_item_getter_SPAT')
self.append_everywhere = False
elif self.cfg.ds.conc_type == 'temp':
self.itemgetter = getattr(
self, 'verb_item_getter_TEMP')
self.append_everywhere = False
elif self.cfg.ds.conc_type == 'sep':
self.itemgetter = getattr(
self, 'verb_item_getter_SEP')
self.append_everywhere = True
elif self.cfg.ds.conc_type == 'svsq':
self.itemgetter = getattr(
self, 'verb_item_getter_SEP')
self.append_everywhere = True
self.cs_nvids_sample = 1
else:
raise NotImplementedError
# Whether to shuffle among the four screens
# Has to be True. Keep false only for debugging
self.ds4_shuffle = self.cfg.ds.cs_shuffle
# open the vocab files for args
with open(self.cfg.ds.arg_vocab_file, 'rb') as f:
self.arg_vocab = pickle.load(f)
# set the max number of SRLs
# ARG0, V, ARG1 => 3 SRLs
self.srl_arg_len = self.cfg.misc.srl_arg_length
# set the max number of boxes for each SRL
# ARG0: four people => 4 boxes
self.box_per_srl_arg = self.cfg.misc.box_per_srl_arg
def get_cs_and_random_more_idx(self, idx):
"""
Either choose at random or
choose via CS with uniform probability
"""
if np.random.random() < 0.5:
return self.get_random_more_idx(idx)
else:
return self.get_cs_more_idxs(idx)
def get_random_more_idx(self, idx):
"""
Returns set of random idxs
"""
if self.split_type == 'train':
# for train, generate this list at runtime
more_idxs, _ = create_random_list(
self.cfg,
self.srl_annots,
idx
)
if len(more_idxs) > self.cs_nvids_sample - 1:
more_idxs_new_keys = np.random.choice(
list(more_idxs.keys()),
min(len(more_idxs), self.cs_nvids_sample-1),
replace=False
)
more_idxs = {k: more_idxs[k] for k in more_idxs_new_keys}
elif self.split_type == 'valid' or self.split_type == 'test':
# for valid/test use pre-generated ones
# obtain predefined idxs