-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcollector.py
34 lines (27 loc) · 1.08 KB
/
collector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import numpy as np
import pandas as pd
import torch
import easydict
from typing import Any, Dict, List
from dataclasses import dataclass
@dataclass
class MKGLDataCollector:
    """Collate a list of example dicts into a single batch for the model.

    Every example's ``input_text`` and ``inv_input_text`` are tokenized
    together in one call (all forward texts first, then all inverse texts).
    The remaining per-example fields are converted with ``torch.tensor``, and
    the result is wrapped in an ``easydict.EasyDict``.
    """

    # Object whose ``tokenizer`` attribute is called on the combined texts.
    # Defaults to None, so it must be supplied before the collator is invoked.
    dataset: object = None

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build one batch from ``features``.

        Args:
            features: Non-empty list of example dicts. Keys are taken from the
                first example, and every example must carry them. The keys
                ``input_text``, ``inv_input_text`` and ``split`` are required.
                All other values (and the tokenizer outputs) must be
                convertible by ``torch.tensor``.

        Returns:
            ``{'batch': EasyDict}`` when the split is ``'train'``. Otherwise
            ``{'batch': EasyDict, 'label': Tensor}``, where ``label`` is a
            bfloat16 tensor of ones with ``len(features)`` entries. The
            ``EasyDict`` holds the converted feature tensors, the tokenizer
            outputs, ``input_length`` and the ``split`` value.
        """
        head = features[0]
        n_examples = len(features)

        # Transpose list-of-dicts -> dict-of-lists, using the first example's keys.
        columns = {key: [example[key] for example in features] for key in head}

        # Tokenize forward and inverse texts as one padded list (forward rows first).
        columns['input_text'] += columns['inv_input_text']
        columns.update(self.dataset.tokenizer(columns['input_text'], padding=True))

        # Row-wise sum of the attention mask: presumably the number of
        # non-padding tokens per text (assumes 1 = real token, 0 = pad).
        columns['input_length'] = np.sum(columns['attention_mask'], axis=1)

        # NOTE(review): only the first example's split is used; this assumes
        # every example in a batch shares one split — confirm with the sampler.
        split = columns['split'][0]
        for raw_key in ('input_text', 'inv_input_text', 'split'):
            del columns[raw_key]

        tensors = {key: torch.tensor(values) for key, values in columns.items()}
        tensors['split'] = split

        result = {'batch': easydict.EasyDict(tensors)}
        if split != 'train':
            # NOTE(review): the label has len(features) entries, while the
            # tokenized fields cover both forward and inverse texts — confirm
            # this size is what the consumer expects.
            result['label'] = torch.ones(n_examples, dtype=torch.bfloat16)
        return result