dataset_test.py
import os

from PIL import Image
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader


class Dataset(object):
    """Evaluation dataset yielding (query, support) image/mask pairs for one fold."""

    def __init__(self, data_dir, mask_dir, fold, input_size=[224, 224],
                 normalize_mean=[0.485, 0.456, 0.406], normalize_std=[0.229, 0.224, 0.225]):
        self.data_dir = data_dir
        self.mask_dir = mask_dir
        self.fold = fold
        self.pair_list = self.get_pair_list(fold=self.fold)
        self.input_size = input_size
        self.normalize_mean = normalize_mean
        self.normalize_std = normalize_std

    def get_pair_list(self, fold):
        # Each line of split<fold>_val.txt reads: <support_name> <query_name> <class_id>
        pair_list = []
        with open(os.path.join(self.mask_dir, 'val', 'split%1d_val.txt' % fold)) as f:
            for line in f:
                if not line.strip():
                    continue
                sup_name, query_name, cat = line.split()
                pair_list.append([query_name, sup_name, int(cat)])
        return pair_list

    def __getitem__(self, index):
        query_name = self.pair_list[index][0]
        support_name = self.pair_list[index][1]
        class_name = self.pair_list[index][2]  # class id shared by this query/support pair
        support_mask = Image.open(os.path.join(self.mask_dir, 'val', str(class_name), support_name + '.png')).convert('1')
        support_img = Image.open(os.path.join(self.data_dir, 'JPEGImages', support_name + '.jpg')).convert('RGB')
        query_mask = Image.open(os.path.join(self.mask_dir, 'val', str(class_name), query_name + '.png')).convert('1')
        query_img = Image.open(os.path.join(self.data_dir, 'JPEGImages', query_name + '.jpg')).convert('RGB')
        # Support: keep a single scale (the 305x305 resize) and its mask.
        # Query: keep all three scales for multi-scale evaluation.
        _, support_img, _, support_mask = self.image_process(self.input_size, support_img, support_mask)
        query_img0, query_img1, query_img2, query_mask = self.image_process(self.input_size, query_img, query_mask)
        return query_img0, query_img1, query_img2, query_mask, support_img, support_mask, class_name

    def image_process(self, input_size, image, mask):
        h, w = input_size
        # Nearest-neighbour interpolation keeps the mask binary.
        resize = transforms.Resize(size=(h, w), interpolation=Image.NEAREST)
        mask = resize(mask)
        resize = transforms.Resize(size=(h, w), interpolation=Image.BILINEAR)
        image0 = resize(image)
        # Multi-scale evaluation ([305, 305], [353, 353], [473, 473])
        resize = transforms.Resize(size=(305, 305), interpolation=Image.BILINEAR)
        image1 = resize(image)
        resize = transforms.Resize(size=(473, 473), interpolation=Image.BILINEAR)
        image2 = resize(image)
        image0 = TF.to_tensor(image0)
        image0 = TF.normalize(image0, self.normalize_mean, self.normalize_std)
        image1 = TF.to_tensor(image1)
        image1 = TF.normalize(image1, self.normalize_mean, self.normalize_std)
        image2 = TF.to_tensor(image2)
        image2 = TF.normalize(image2, self.normalize_mean, self.normalize_std)
        mask = TF.to_tensor(mask)
        return image0, image1, image2, mask

def __len__(self):
return len(self.pair_list)
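

# Minimal usage sketch, assuming a PASCAL-VOC style data_dir (with a JPEGImages
# folder) and a mask_dir containing val/split<fold>_val.txt plus per-class mask
# folders, as the class above expects. The concrete paths, fold number and batch
# size below are placeholders, not values taken from this repository.
if __name__ == '__main__':
    dataset = Dataset(data_dir='/path/to/VOCdevkit/VOC2012',  # placeholder path
                      mask_dir='/path/to/binary_masks',       # placeholder path
                      fold=0)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
    # Each batch holds the query image at three scales, the query mask, one
    # support image/mask pair, and the class id.
    q0, q1, q2, q_mask, s_img, s_mask, cls = next(iter(loader))
    print(q0.shape, q1.shape, q2.shape, q_mask.shape, s_img.shape, s_mask.shape, int(cls[0]))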