-
Notifications
You must be signed in to change notification settings - Fork 49
/
preprocess.py
54 lines (41 loc) · 1.62 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from os.path import join
import argparse
from PIL import Image
from torchvision import transforms
import torch
def preprocess(data_dir, split):
    """Build and save the (image tensor, formula) pairs for one dataset split.

    Reads the formula list ``im2latex_formulas.norm.lst`` and the split index
    ``im2latex_<split>_filter.lst`` from *data_dir*. Each non-blank line of the
    split index holds an image file name (relative to
    ``<data_dir>/formula_images_processed``) and an integer index into the
    formula list. Each image is converted with
    ``torchvision.transforms.ToTensor()`` and paired with its formula string.
    The pairs are sorted with ``img_size`` as the key and written to
    ``<data_dir>/<split>.pkl`` with ``torch.save``.

    Args:
        data_dir: Directory containing the im2latex list files and the
            ``formula_images_processed`` image directory.
        split: One of ``"train"``, ``"validate"`` or ``"test"``.

    Raises:
        ValueError: If *split* is not a supported split name, or if a split
            line does not contain exactly two whitespace-separated fields.
        IndexError: If a formula index is out of range for the formula list.
    """
    if split not in ("train", "validate", "test"):
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        raise ValueError(
            "Unknown split {!r}; expected 'train', 'validate' or 'test'".format(split)
        )
    print("Process {} dataset...".format(split))

    images_dir = join(data_dir, "formula_images_processed")
    formulas_file = join(data_dir, "im2latex_formulas.norm.lst")
    with open(formulas_file, 'r') as f:
        # One formula per line; the list index is the formula id used below.
        formulas = [formula.strip('\n') for formula in f]

    split_file = join(data_dir, "im2latex_{}_filter.lst".format(split))
    transform = transforms.ToTensor()
    pairs = []
    with open(split_file, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line at EOF).
                continue
            img_name, formula_id = fields
            # Load the image and its corresponding formula.
            img_path = join(images_dir, img_name)
            # The tensor is built inside the with-block, so the pixel data is
            # read before the file handle is closed deterministically.
            with Image.open(img_path) as img:
                img_tensor = transform(img)
            formula = formulas[int(formula_id)]
            pairs.append((img_tensor, formula))

    # Sort by tensor shape so that equally-sized images end up adjacent
    # (presumably to make batching of same-size images easy downstream).
    pairs.sort(key=img_size)
    out_file = join(data_dir, "{}.pkl".format(split))
    torch.save(pairs, out_file)
    print("Save {} dataset to {}".format(split, out_file))
def img_size(pair):
    """Return the shape of the image tensor in an ``(image, formula)`` pair.

    Serves as the sort key for the preprocessed pairs. The shape is returned
    as a plain tuple, so keys compare lexicographically.

    Args:
        pair: A 2-tuple ``(image_tensor, formula)``.

    Returns:
        ``tuple(image_tensor.size())``.
    """
    # Unpacking (rather than indexing) also requires exactly two items.
    image_tensor, _ = pair
    dims = image_tensor.size()
    return tuple(dims)
if __name__ == "__main__":
    # Command-line entry point: preprocess the validate, test and train splits
    # found under --data_path.
    arg_parser = argparse.ArgumentParser(
        description="Im2Latex Data Preprocess Program",
    )
    arg_parser.add_argument(
        "--data_path",
        type=str,
        default="./data/",
        help="The dataset's dir",
    )
    cli_args = arg_parser.parse_args()
    # Each call writes <data_path>/<split>.pkl; splits are processed in this order.
    for split_name in ("validate", "test", "train"):
        preprocess(cli_args.data_path, split_name)