-
Notifications
You must be signed in to change notification settings - Fork 7
/
eval_sirs_4000
75 lines (60 loc) · 2.86 KB
/
eval_sirs_4000
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
from os.path import join
import torch.backends.cudnn as cudnn
import data.sirs_dataset as datasets
from data.image_folder import read_fns
from engine import Engine
from options.net_options.train_options import TrainOptions
from tools import mutils
# Evaluate a trained model on the SIRS test sets: real20 (real20_420), the
# three SIR2 subsets (Postcard, SolidObject, WildScene) and Nature. Results
# from each Engine.eval call are printed; outputs are written under
# ./checkpoints/<opt.name>/<timestamp>/<tag>/.
opt = TrainOptions().parse()
# Force evaluation-only settings regardless of the training CLI defaults.
opt.isTrain = False
# Let cuDNN auto-tune convolution algorithms.
# NOTE(review): if test images differ in size, tuning re-runs per new shape —
# confirm this is still a net win for evaluation.
cudnn.benchmark = True
opt.no_log = True
opt.display_id = 0
opt.verbose = False

datadir = join(opt.base_dir, 'test')

# Define evaluation/test datasets.
eval_dataset_real = datasets.DSRTestDataset(
    join(datadir, 'real20_420'),
    fns=read_fns('data/real_test.txt'),
    if_align=opt.if_align)
eval_dataset_postcard = datasets.DSRTestDataset(
    join(datadir, 'SIR2/PostcardDataset'), if_align=opt.if_align)
eval_dataset_solidobject = datasets.DSRTestDataset(
    join(datadir, 'SIR2/SolidObjectDataset'), if_align=opt.if_align)
eval_dataset_wild = datasets.DSRTestDataset(
    join(datadir, 'SIR2/WildSceneDataset'), if_align=opt.if_align)
eval_dataset_nature = datasets.DSRTestDataset(
    join(datadir, 'Nature'), if_align=opt.if_align)


def _make_eval_loader(dataset):
    """Return a batch-size-1, non-shuffled DataLoader over *dataset*.

    Shuffling is disabled for every test set so each run visits (and saves)
    samples in the same, reproducible order.
    """
    return datasets.DataLoader(
        dataset, batch_size=1, shuffle=False,
        num_workers=opt.nThreads, pin_memory=True)


eval_dataloader_real = _make_eval_loader(eval_dataset_real)
eval_dataloader_solidobject = _make_eval_loader(eval_dataset_solidobject)
eval_dataloader_postcard = _make_eval_loader(eval_dataset_postcard)
eval_dataloader_wild = _make_eval_loader(eval_dataset_wild)
eval_dataloader_nature = _make_eval_loader(eval_dataset_nature)

engine = Engine(opt)

# Main loop: one timestamped output directory per run.
result_dir = join('./checkpoints', opt.name, mutils.get_formatted_time())

# (dataloader, dataset_name passed to Engine.eval, output sub-dir / suffix).
# Keeping loader and name on one row prevents evaluating the wrong loader
# under another set's name.
eval_targets = [
    (eval_dataloader_real, 'testdata_real', 'real20'),
    (eval_dataloader_solidobject, 'testdata_solidobject', 'solidobject'),
    (eval_dataloader_postcard, 'testdata_postcard', 'postcard'),
    (eval_dataloader_wild, 'testdata_wild', 'wild'),
    (eval_dataloader_nature, 'testdata_nature', 'nature'),
]

for loader, dataset_name, tag in eval_targets:
    res = engine.eval(loader, dataset_name=dataset_name,
                      savedir=join(result_dir, tag), suffix=tag)
    print(res)