-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathvisualize_datasets.py
96 lines (77 loc) · 3.17 KB
/
visualize_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import argparse
import os
import sys
from torch.utils import data
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from datasets.coco import CocoDetection
from transforms import presets
from transforms import v2 as T
from util.collate_fn import collate_fn
from util.logger import setup_logger
from util.misc import fixed_generator, seed_worker
from util.visualize import visualize_coco_bounding_boxes
def parse_args(argv=None):
    """Parse the command-line options for dataset visualization.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original behavior
            (read the real command line). Passing a list lets tests and other
            scripts reuse the parser.

    Returns:
        argparse.Namespace holding one attribute per option defined below.

    Raises:
        SystemExit: raised by argparse when a required option is missing or a
            value fails type conversion (and after printing ``--help``).
    """
    parser = argparse.ArgumentParser(description="Visualize a dataset")
    # dataset parameters
    parser.add_argument("--coco-img", type=str, required=True, help="image folder passed to CocoDetection")
    parser.add_argument("--coco-ann", type=str, required=True, help="annotation file passed to CocoDetection")
    parser.add_argument(
        "--transform",
        type=str,
        default=None,
        help="name of a preset in transforms.presets (ConvertDtype/Normalize steps are removed)",
    )
    parser.add_argument("--workers", type=int, default=2, help="number of DataLoader worker processes")
    # visualize parameters
    # required=True already guarantees a value, so no default is needed here.
    parser.add_argument("--show-dir", type=str, required=True, help="output directory for the visualizer")
    parser.add_argument("--show-conf", type=float, default=0.5, help="value forwarded as show_conf to the visualizer")
    # plot parameters
    parser.add_argument("--font-scale", type=float, default=1.0, help="label font scale")
    parser.add_argument("--box-thick", type=int, default=1, help="bounding-box line thickness")
    parser.add_argument("--fill-alpha", type=float, default=0.2, help="opacity of the box fill")
    # nargs="+" yields a list when given on the command line; the default stays a tuple.
    parser.add_argument(
        "--text-box-color", type=int, nargs="+", default=(255, 255, 255), help="label box color as integers"
    )
    parser.add_argument(
        "--text-font-color", type=int, nargs="+", default=None, help="label text color as integers"
    )
    parser.add_argument("--text-alpha", type=float, default=1.0, help="opacity of the label text")
    # engine parameters
    # NOTE(review): --seed is parsed but not read anywhere in this script.
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    return parser.parse_args(argv)
def visualize_datasets():
    """Draw the ground-truth boxes of a COCO-format dataset.

    Reads options via ``parse_args``, configures the loggers, and builds an
    optional preset transform with its ConvertDtype/Normalize steps stripped
    so images stay displayable. It then feeds the dataset one image at a time
    to ``visualize_coco_bounding_boxes``. Nothing is rendered when
    ``--show-dir`` is empty.
    """
    cli = parse_args()

    # setup logger
    for name in ("py.warnings", "accelerate", os.path.basename(os.getcwd())):
        setup_logger(name=name)

    # remove the ConvertDtype and Normalize for visualization
    preset = None
    if args_transform := cli.transform:
        preset = remove_cvtdtype_normalize(getattr(presets, args_transform))

    # plot annotations for each image
    if not cli.show_dir:
        return

    # NOTE(review): cli.seed is never consumed; loader randomness comes from
    # fixed_generator()/seed_worker alone. Confirm whether --seed should feed in.
    coco = CocoDetection(img_folder=cli.coco_img, ann_file=cli.coco_ann, transforms=preset)
    loader = data.DataLoader(
        coco,
        batch_size=1,
        shuffle=False,
        num_workers=cli.workers,
        worker_init_fn=seed_worker,
        generator=fixed_generator(),
        collate_fn=collate_fn,
    )
    visualize_coco_bounding_boxes(
        data_loader=loader,
        show_conf=cli.show_conf,
        show_dir=cli.show_dir,
        font_scale=cli.font_scale,
        box_thick=cli.box_thick,
        fill_alpha=cli.fill_alpha,
        text_box_color=cli.text_box_color,
        text_font_color=cli.text_font_color,
        text_alpha=cli.text_alpha,
    )
def remove_cvtdtype_normalize(transform):
    """Return ``transform`` with every ConvertDtype/Normalize step removed.

    A ``T.Compose`` is rebuilt as a new ``T.Compose`` from its cleaned
    children; nested Composes are handled recursively. A bare ConvertDtype or
    Normalize becomes ``None``. Any other object is returned unchanged.
    """
    if isinstance(transform, T.Compose):
        kept = []
        for child in transform.transforms:
            cleaned = remove_cvtdtype_normalize(child)
            # ``None`` marks a dropped ConvertDtype/Normalize step.
            if cleaned is not None:
                kept.append(cleaned)
        return T.Compose(kept)
    return None if isinstance(transform, (T.ConvertDtype, T.Normalize)) else transform
if __name__ == "__main__":
    # Script entry point: parse CLI options and render the dataset annotations.
    visualize_datasets()