-
Notifications
You must be signed in to change notification settings - Fork 2
/
inference.py
122 lines (103 loc) · 3.87 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import os
import pandas as pd
from PIL import Image
from tqdm import tqdm
import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize
class TestDataset(Dataset):
    """Dataset yielding (optionally transformed) images from a list of file paths.

    No labels are produced -- this is intended for test-time inference only.
    """

    def __init__(self, img_paths, transform):
        # img_paths: sequence of image file paths; transform: callable or None.
        self.img_paths = img_paths
        self.transform = transform

    def __getitem__(self, index):
        # Images are loaded lazily, one per access.
        img = Image.open(self.img_paths[index])
        return self.transform(img) if self.transform else img

    def __len__(self):
        return len(self.img_paths)
def main(config):
    """Run test-time inference and write predictions to a submission file.

    Reads image IDs from <test_dir>/info.csv, runs the configured model over
    the corresponding images, and writes the predicted class (0-17) into the
    'ans' column of <test_dir>/submission.csv.

    Args:
        config: parsed argparse namespace with attributes test_dir, resize,
            batch_size, model, model_path and multi_head.
    """
    # Fall back to CPU so the script also runs on machines without a GPU
    # (the original hard-coded 'cuda' and crashed on CPU-only hosts).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the metadata and build the list of image paths.
    submission = pd.read_csv(os.path.join(config.test_dir, 'info.csv'))
    image_dir = os.path.join(config.test_dir, 'images')
    image_paths = [os.path.join(image_dir, img_id) for img_id in submission.ImageID]

    # Build the test dataset and loader. Normalization statistics are assumed
    # to match the ones used during training -- TODO confirm against trainer.
    transform = transforms.Compose([
        Resize(config.resize, Image.BILINEAR),
        ToTensor(),
        Normalize(mean=(0.548, 0.504, 0.497), std=(0.237, 0.247, 0.246))
    ])
    dataset = TestDataset(image_paths, transform)
    loader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=False
    )

    # Instantiate the architecture by name and load the trained weights.
    model_module = getattr(module_arch, config.model)
    model = model_module(num_classes=18).to(device)
    model.load_state_dict(torch.load(config.model_path, map_location=device))
    model.eval()

    # Predict over the test set. no_grad() is hoisted around the whole loop
    # instead of being re-entered per batch.
    all_predictions = []
    with torch.no_grad():
        for images in tqdm(loader):
            images = images.to(device)
            pred = model(images)
            if config.multi_head:
                # Three heads (mask, gender, age) are combined into a single
                # 18-class label as mask * 6 + gender * 3 + age.
                pred_mask, pred_gender, pred_age = pred
                pred = (torch.argmax(pred_mask, dim=-1) * 6
                        + torch.argmax(pred_gender, dim=-1) * 3
                        + torch.argmax(pred_age, dim=-1))
            else:
                pred = pred.argmax(dim=-1)
            all_predictions.extend(pred.cpu().numpy())
    submission['ans'] = all_predictions

    # Save the file to submit.
    submission.to_csv(os.path.join(config.test_dir, 'submission.csv'), index=False)
    print('test inference is done!')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='PyTorch Template')
parser.add_argument(
"--model", type=str, default="EfficientNetB0MultiHead", help="model type (default: EfficientNetB0MultiHead)"
)
parser.add_argument(
"--test_dir",
type=str,
default=os.environ.get("SM_CAHNNEL_EVAL", "/data/ephemeral/maskdata/eval")
)
parser.add_argument(
"--resize",
nargs=2,
type=int,
default=[128, 96],
help="resize size for image when training",
)
parser.add_argument(
"--multi_head",
type=bool,
default=True,
help="모델의 head가 1개(num_classes=18)인 경우 False, 3개인 경우 True"
)
parser.add_argument(
"--model_path",
type=str,
default="/data/ephemeral/home/model/exp/best.pth",
help="사용할 모델의 weight 경로를 입력해주세요 (예: /data/ephemeral/home/model/exp/best.pth)"
)
parser.add_argument(
"--batch_size",
type=int,
default=100,
help="input batch size for validing (default: 1000)",
)
args = parser.parse_args()
print(args)
main(args)