-
Notifications
You must be signed in to change notification settings - Fork 3
/
data.py
30 lines (24 loc) · 859 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from os.path import join
from torch.utils.data import Dataset
import torch
class Im2LatexDataset(Dataset):
    """Map-style dataset of (image, formula) pairs for one im2latex split.

    Pairs are read from ``<data_dir>/<split>.pkl`` with ``torch.load``. Each
    formula is truncated to at most ``max_len`` whitespace-separated tokens.
    """

    # Accepted values for the ``split`` argument.
    _SPLITS = ("train", "validate", "test")

    def __init__(self, data_dir, split, max_len):
        """Load and preprocess one split of the dataset.

        Args:
            data_dir: root directory that stores the preprocessed data files.
            split: one of "train", "validate" or "test".
            max_len: maximum number of formula tokens to keep per sample;
                ``None`` keeps every token.

        Raises:
            ValueError: if ``split`` is not one of the accepted split names.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let a bad split reach torch.load.
        if split not in self._SPLITS:
            raise ValueError(
                "split must be one of {}, got {!r}".format(self._SPLITS, split)
            )
        self.data_dir = data_dir
        self.split = split
        self.max_len = max_len
        self.pairs = self._load_pairs()

    def _load_pairs(self):
        """Return the split's pairs with each formula truncated to ``max_len`` tokens.

        The image element of each pair is passed through unchanged. Splitting
        on whitespace and re-joining with single spaces also collapses runs of
        whitespace between tokens.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted data.
        raw_pairs = torch.load(join(self.data_dir, "{}.pkl".format(self.split)))
        # Build a new list instead of rewriting the loaded container in place,
        # so this also works when the file holds an immutable sequence.
        return [
            (img, " ".join(formula.split()[: self.max_len]))
            for img, formula in raw_pairs
        ]

    def __getitem__(self, index):
        """Return the ``(img, formula)`` pair at ``index``."""
        return self.pairs[index]

    def __len__(self):
        """Return the number of pairs in this split."""
        return len(self.pairs)