-
Notifications
You must be signed in to change notification settings - Fork 8
/
test.py
146 lines (119 loc) · 4.99 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# -*- coding: utf-8 -*-
import argparse
import json
import os
import os.path
import numpy as np
import torch
from tqdm import tqdm
from utils import builder, configurator, io, misc, ops, pipeline, recorder
def _prepare_save_path(save_path):
    """Ensure ``save_path`` is usable as an empty output directory.

    Creates the directory when it does not exist yet.

    Raises:
        NotADirectoryError: ``save_path`` exists but is not a directory.
        ValueError: ``save_path`` is an existing, non-empty directory.
    """
    if not os.path.exists(save_path):
        print(f"{save_path} does not exist, create it.")
        os.makedirs(save_path)
        return
    # Without this check os.listdir would fail on a regular file with an opaque error.
    if not os.path.isdir(save_path):
        raise NotADirectoryError(f"--save-path {save_path} exists but is not a directory.")
    if os.listdir(save_path):
        raise ValueError("--save-path is not an empty folder.")


def parse_config():
    """Parse command-line arguments and build the evaluation config.

    Loads the config file given by ``--config``, applies CLI overrides, prepares the
    optional output directory, and resolves every test dataset name listed in
    ``config.datasets.test.path`` to its entry in the ``--datasets-info`` JSON file.

    Returns:
        The ``configurator.Configurator`` instance with the overrides applied.

    Raises:
        KeyError: a test dataset name is missing from the datasets-info JSON.
        ValueError: ``--save-path`` is an existing, non-empty directory.
        NotADirectoryError: ``--save-path`` exists but is not a directory.
    """
    parser = argparse.ArgumentParser("Training and evaluation script")
    parser.add_argument("--config", default="./configs/MFFN/MFFN_R50.py", type=str)
    parser.add_argument("--datasets-info", default="./configs/_base_/dataset/dataset_configs.json", type=str)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--batch-size", type=int)
    parser.add_argument("--load-from", type=str)
    parser.add_argument("--save-path", type=str)
    parser.add_argument("--minmax-results", action="store_true")
    parser.add_argument("--info", type=str)
    args = parser.parse_args()

    config = configurator.Configurator.fromfile(args.config)
    config.use_ddp = False

    # These options override the config file only when given on the command line.
    if args.model_name is not None:
        config.model_name = args.model_name
    if args.batch_size is not None:
        config.test.batch_size = args.batch_size
    if args.load_from is not None:
        config.load_from = args.load_from
    if args.info is not None:
        config.experiment_tag = args.info

    if args.save_path is not None:
        _prepare_save_path(args.save_path)
    # Always assigned; None means "do not save predictions".
    config.save_path = args.save_path
    # NOTE(review): unlike the overrides above, this always replaces the config-file
    # value (False when --minmax-results is absent) — confirm that is intended.
    config.test.to_minmax = args.minmax_results

    with open(args.datasets_info, encoding="utf-8", mode="r") as f:
        datasets_info = json.load(f)

    # Map each requested test dataset name to its info entry, failing fast on unknown names.
    te_paths = {}
    for te_dataset in config.datasets.test.path:
        if te_dataset not in datasets_info:
            raise KeyError(f"{te_dataset} not in {args.datasets_info}!!!")
        te_paths[te_dataset] = datasets_info[te_dataset]
    config.datasets.test.path = te_paths

    config.proj_root = os.path.dirname(os.path.abspath(__file__))
    config.exp_name = misc.construct_exp_name(model_name=config.model_name, cfg=config)
    return config
def test_once(
    model,
    data_loader,
    save_path,
    tta_setting,
    clip_range=None,
    show_bar=False,
    desc="[TE]",
    to_minmax=False,
):
    """Run inference over one test loader and accumulate segmentation metrics.

    For each batch, the images in ``batch["data"]`` are moved to ``model.device`` and
    fed either through ``pipeline.test_aug`` (when ``tta_setting.enable`` is truthy) or
    directly through ``model(data=...)``. Each sigmoid probability map is then:

    1. resized (linear) to the height/width of its ground-truth mask, which is read
       from ``batch["info"]["mask_path"][i]`` as uint8;
    2. optionally passed through ``ops.clip_to_normalize`` (when ``clip_range`` is set)
       and ``ops.minmax`` (when ``to_minmax`` is true);
    3. optionally saved under ``save_path`` using the mask's file name;
    4. scaled by 255, cast to uint8 and fed to the metric recorder.

    Args:
        model: network; must expose a ``device`` attribute. ``is_training`` is set to False.
        data_loader: iterable of batch dicts; must support ``len()`` when ``show_bar``.
        save_path: output directory for predictions, or a falsy value to skip saving.
        tta_setting: object with ``enable``, ``strategy`` and ``reduction`` attributes.
        clip_range: optional range forwarded to ``ops.clip_to_normalize``.
        show_bar: wrap the loader in a tqdm progress bar.
        desc: progress-bar label.
        to_minmax: min-max normalize each prediction before saving/scoring.

    Returns:
        The result of ``recorder.CalTotalMetric().get_results()`` after all batches.
    """
    model.is_training = False
    metric_tracker = recorder.CalTotalMetric()

    batches = enumerate(data_loader)
    if show_bar:
        batches = tqdm(batches, total=len(data_loader), ncols=79, desc=desc)

    for _, batch in batches:
        images = misc.to_device(batch["data"], device=model.device)
        # NOTE(review): the keyword is spelled "reducation"; presumably it matches
        # pipeline.test_aug's signature — verify before renaming.
        logits = (
            pipeline.test_aug(
                model=model, data=images, strategy=tta_setting.strategy, reducation=tta_setting.reduction
            )
            if tta_setting.enable
            else model(data=images)
        )
        # Assumes logits carry a singleton channel at dim 1 (e.g. B x 1 x H x W) — TODO confirm.
        prob_maps = logits.sigmoid().squeeze(1).cpu().detach().numpy()
        mask_paths = batch["info"]["mask_path"]

        for idx, prob in enumerate(prob_maps):
            gt_path = mask_paths[idx]
            gt_mask = io.read_gray_array(gt_path, dtype=np.uint8)
            gt_h, gt_w = gt_mask.shape
            # Bring the prediction to the ground-truth resolution before scoring.
            prob = ops.imresize(prob, target_h=gt_h, target_w=gt_w, interp="linear")
            if clip_range is not None:
                prob = ops.clip_to_normalize(prob, clip_range=clip_range)
            if to_minmax:
                prob = ops.minmax(prob)
            if save_path:
                # save_path here already includes the dataset name.
                ops.save_array_as_image(data_array=prob, save_name=os.path.basename(gt_path), save_dir=save_path)
            metric_tracker.step((prob * 255).astype(np.uint8), gt_mask, gt_path)

    return metric_tracker.get_results()
@torch.no_grad()
def testing(model, cfg):
    """Evaluate ``model`` on every test dataset yielded by ``pipeline.get_te_loader(cfg)``.

    Runs under ``torch.no_grad()``. When ``cfg.save_path`` is truthy, predictions for each
    dataset go to ``<cfg.save_path>/<dataset name>``. Per-dataset metrics are printed;
    nothing is returned.
    """
    out_dir = None
    for dataset_name, dataset_path, te_loader in pipeline.get_te_loader(cfg):
        if cfg.save_path:
            out_dir = os.path.join(cfg.save_path, dataset_name)
            print(f"Results will be saved into {out_dir}")
        metrics = test_once(
            model=model,
            save_path=out_dir,
            data_loader=te_loader,
            tta_setting=cfg.test.tta,
            clip_range=cfg.test.clip_range,
            show_bar=cfg.test.get("show_bar", False),
            to_minmax=cfg.test.get("to_minmax", False),
        )
        print(f"Results on the testset({dataset_name}): {misc.mapping_to_str(dataset_path)}\n{metrics}")
def main():
    """Script entry point: build the configured model, load its weights and evaluate it.

    Uses the first CUDA device when available and falls back to the CPU otherwise,
    instead of failing on CPU-only hosts.
    """
    cfg = parse_config()
    # The registry call also returns a second value (``model_code``) that evaluation does not use.
    model, _ = builder.build_obj_from_registry(
        registry_name="MODELS", obj_name=cfg.model_name, return_code=True
    )
    io.load_weight(model=model, load_path=cfg.load_from)
    model.device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(model.device)
    model.eval()
    testing(model=model, cfg=cfg)


if __name__ == "__main__":
    main()