# data_helper.py (forked from WeChat-Big-Data-Challenge-2022/challenge)
import json
import random
import zipfile
from io import BytesIO
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from transformers import BertTokenizer
from category_id_map import category_id_to_lv2id


def create_dataloaders(args):
    dataset = MultiModalDataset(args, args.train_annotation, args.train_zip_feats)
    size = len(dataset)
    val_size = int(size * args.val_ratio)
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [size - val_size, val_size],
                                                               generator=torch.Generator().manual_seed(args.seed))
    train_sampler = RandomSampler(train_dataset)
    val_sampler = SequentialSampler(val_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  sampler=train_sampler,
                                  drop_last=True,
                                  pin_memory=True,
                                  num_workers=args.num_workers,
                                  prefetch_factor=args.prefetch)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.val_batch_size,
                                sampler=val_sampler,
                                drop_last=False,
                                pin_memory=True,
                                num_workers=args.num_workers,
                                prefetch_factor=args.prefetch)
    return train_dataloader, val_dataloader
class MultiModalDataset(Dataset):
    """ A simple class that supports multi-modal inputs.

    For the visual features, this dataset class will read the pre-extracted
    features from the .npy files. For the title information, it uses the BERT
    tokenizer to tokenize. We simply ignore the ASR & OCR text in this implementation.

    Args:
        ann_path (str): annotation file path, with the '.json' suffix.
        zip_feats (str): visual feature zip file path.
        test_mode (bool): if it's for testing.
    """

    def __init__(self,
                 args,
                 ann_path: str,
                 zip_feats: str,
                 test_mode: bool = False):
        self.max_frame = args.max_frames
        self.bert_seq_length = args.bert_seq_length
        self.test_mode = test_mode

        # lazy initialization for zip_handler to avoid multiprocessing-reading error
        self.zip_feat_path = zip_feats
        self.handles = [None for _ in range(args.num_workers)]

        # load annotations
        with open(ann_path, 'r', encoding='utf8') as f:
            self.anns = json.load(f)

        # initialize the text tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_dir, use_fast=True, cache_dir=args.bert_cache)

    def __len__(self) -> int:
        return len(self.anns)

    def get_visual_feats(self, worker_id, idx: int) -> tuple:
        # read data from zipfile
        vid = self.anns[idx]['id']
        if self.handles[worker_id] is None:
            self.handles[worker_id] = zipfile.ZipFile(self.zip_feat_path, 'r')
        raw_feats = np.load(BytesIO(self.handles[worker_id].read(name=f'{vid}.npy')), allow_pickle=True)
        raw_feats = raw_feats.astype(np.float32)  # float16 to float32
        num_frames, feat_dim = raw_feats.shape

        feat = np.zeros((self.max_frame, feat_dim), dtype=np.float32)
        mask = np.ones((self.max_frame,), dtype=np.int32)
        if num_frames <= self.max_frame:
            feat[:num_frames] = raw_feats
            mask[num_frames:] = 0
        else:
            # if the number of frames exceeds the limitation, we need to sample
            # the frames.
            if self.test_mode:
                # uniformly sample when test mode is True
                step = num_frames // self.max_frame
                select_inds = list(range(0, num_frames, step))
                select_inds = select_inds[:self.max_frame]
            else:
                # randomly sample when test mode is False
                select_inds = list(range(num_frames))
                random.shuffle(select_inds)
                select_inds = select_inds[:self.max_frame]
                select_inds = sorted(select_inds)
            for i, j in enumerate(select_inds):
                feat[i] = raw_feats[j]
        feat = torch.FloatTensor(feat)
        mask = torch.LongTensor(mask)
        return feat, mask

    def tokenize_text(self, text: str) -> tuple:
        encoded_inputs = self.tokenizer(text, max_length=self.bert_seq_length, padding='max_length', truncation=True)
        input_ids = np.array(encoded_inputs['input_ids'])
        mask = np.array(encoded_inputs['attention_mask'])
        return input_ids, mask

    def __getitem__(self, idx: int) -> dict:
        # Step 1, load visual features from zipfile.
        worker_info = torch.utils.data.get_worker_info()
        frame_input, frame_mask = self.get_visual_feats(worker_info.id, idx)

        # Step 2, load title tokens
        title_input, title_mask = self.tokenize_text(self.anns[idx]['title'])

        # Step 3, summarize into a dictionary
        data = dict(
            frame_input=frame_input,
            frame_mask=frame_mask,
            title_input=title_input,
            title_mask=title_mask
        )

        # Step 4, load label if not test mode
        if not self.test_mode:
            label = category_id_to_lv2id(self.anns[idx]['category_id'])
            data['label'] = torch.LongTensor([label])

        return data
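

# Minimal usage sketch. It assumes an argparse-style `args` namespace carrying the
# fields this module reads (train_annotation, train_zip_feats, val_ratio, seed,
# batch_size, val_batch_size, num_workers, prefetch, max_frames, bert_seq_length,
# bert_dir, bert_cache); the paths and values below are placeholders, not the
# official challenge configuration.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        train_annotation='data/annotations/labeled.json',  # placeholder path
        train_zip_feats='data/zip_feats/labeled.zip',      # placeholder path
        val_ratio=0.1,
        seed=2022,
        batch_size=32,
        val_batch_size=64,
        num_workers=4,   # must be > 0: __getitem__ relies on worker_info.id
        prefetch=2,
        max_frames=32,
        bert_seq_length=50,
        bert_dir='bert-base-chinese',  # placeholder pretrained model name
        bert_cache='cache/',
    )

    train_loader, val_loader = create_dataloaders(args)
    batch = next(iter(train_loader))
    print(batch['frame_input'].shape, batch['frame_mask'].shape,
          batch['title_input'].shape, batch['title_mask'].shape,
          batch['label'].shape)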