# vis.py -- forked from pytorch/ignite.
import numpy as np
import torch
from PIL import Image
# `render_datapoint` comes from the optional third-party ImageDatasetViz
# package; fail at import time with installation instructions when it is
# missing, keeping the original ImportError attached as the cause.
try:
    from image_dataset_viz import render_datapoint
except ImportError as err:
    raise RuntimeError(
        "Install it via pip install --upgrade git+https://github.com/vfdev-5/ImageDatasetViz.git"
    ) from err
def _getvocpallete(num_cls):
n = num_cls
pallete = [0] * (n * 3)
for j in range(0, n):
lab = j
pallete[j * 3 + 0] = 0
pallete[j * 3 + 1] = 0
pallete[j * 3 + 2] = 0
i = 0
while lab > 0:
pallete[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
pallete[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
pallete[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
i = i + 1
lab >>= 3
return pallete
# Module-level palette for 256 labels (a flat list of 768 ints), built once at
# import time and applied to every mask in render_mask.
vocpallete = _getvocpallete(256)
def render_mask(mask):
    """Colourise a label mask with the module palette and return an RGB image.

    Args:
        mask: either a ``numpy.ndarray`` of label indices or an object with the
            PIL ``Image`` interface (``copy``/``putpalette``/``convert``).
            NOTE(review): a 2-D (H, W) label map is presumably expected --
            confirm against callers.

    Returns:
        The mask converted to mode ``"RGB"`` after ``vocpallete`` is applied.
    """
    if isinstance(mask, np.ndarray):
        # Only 8-bit arrays become a palette-capable PIL mode; integer or
        # boolean label maps of other widths (e.g. int64 produced by argmax)
        # are narrowed first. Values outside 0..255 wrap around in this cast.
        if mask.dtype != np.uint8 and mask.dtype.kind in "biu":
            mask = mask.astype(np.uint8)
        mask = Image.fromarray(mask)
    else:
        # putpalette modifies the image in place; work on a copy so the
        # caller's image is left untouched.
        mask = mask.copy()
    mask.putpalette(vocpallete)
    return mask.convert(mode="RGB")
def tensor_to_rgb(t):
    """Convert a channel-first tensor into a channel-last uint8 array.

    Args:
        t (torch.Tensor): 3-D tensor whose axes are permuted with
            ``(1, 2, 0)``. NOTE(review): presumably (C, H, W) with values
            already scaled to 0..255 -- confirm against the denormalize
            function used by callers.

    Returns:
        numpy.ndarray: array of dtype ``uint8`` with the first axis moved last.
    """
    # detach() lets tensors that still require grad be converted to numpy.
    img = t.detach().cpu().numpy().transpose((1, 2, 0))
    # Clip before narrowing so out-of-range values saturate instead of
    # wrapping around in the uint8 cast.
    return np.clip(img, 0, 255).astype(np.uint8)
def make_grid(batch_img, batch_mask, img_denormalize_fn, batch_gt_mask=None):
    """Create a grid from a batch of images and masks, laid out as::

        img1  | img2  | img3  | img4  | ...
        i+m1  | i+m2  | i+m3  | i+m4  | ...
        mask1 | mask2 | mask3 | mask4 | ...
        i+M1  | i+M2  | i+M3  | i+M4  | ...
        Mask1 | Mask2 | Mask3 | Mask4 | ...

    i+m = image + mask blended with alpha=0.4

    - maskN is predicted mask
    - MaskN is ground-truth mask if given (the last two rows are omitted
      otherwise)

    Args:
        batch_img (torch.Tensor): batch of images of any type. Height and
            width are read from ``shape[2:]``, so a 4-D batch is required.
        batch_mask (torch.Tensor): batch of masks.
        img_denormalize_fn (Callable): function applied to each single image
            of the batch before it is converted to RGB.
        batch_gt_mask (torch.Tensor, optional): batch of ground truth masks.

    Returns:
        numpy.ndarray: uint8 array of shape ``(rows * H, B * W, 3)`` where
        ``rows`` is 3 without ground truth and 5 with it.
    """
    assert isinstance(batch_img, torch.Tensor) and isinstance(
        batch_mask, torch.Tensor
    ), "batch_img and batch_mask must be torch.Tensor"
    assert len(batch_img) == len(batch_mask), "batch_img and batch_mask differ in length"

    has_gt = batch_gt_mask is not None
    if has_gt:
        assert isinstance(batch_gt_mask, torch.Tensor), "batch_gt_mask must be torch.Tensor"
        assert len(batch_mask) == len(batch_gt_mask), "batch_mask and batch_gt_mask differ in length"

    b = batch_img.shape[0]
    h, w = batch_img.shape[2:]

    # Three rows (image, blend, predicted mask) plus two more for ground truth.
    n_rows = 5 if has_gt else 3
    out_image = np.zeros((h * n_rows, w * b, 3), dtype="uint8")

    for i in range(b):
        # Column band of the grid owned by the i-th sample.
        cols = slice(i * w, (i + 1) * w)

        img = tensor_to_rgb(img_denormalize_fn(batch_img[i]))
        mask = render_mask(batch_mask[i].cpu().numpy())

        out_image[0:h, cols, :] = img
        out_image[1 * h : 2 * h, cols, :] = render_datapoint(img, mask, blend_alpha=0.4)
        out_image[2 * h : 3 * h, cols, :] = mask

        if has_gt:
            gt_mask = render_mask(batch_gt_mask[i].cpu().numpy())
            out_image[3 * h : 4 * h, cols, :] = render_datapoint(img, gt_mask, blend_alpha=0.4)
            out_image[4 * h : 5 * h, cols, :] = gt_mask

    return out_image
def predictions_gt_images_handler(img_denormalize_fn, n_images=None, another_engine=None, prefix_tag=None):
    """Build a handler that logs a predictions-vs-ground-truth image grid.

    Args:
        img_denormalize_fn (Callable): forwarded to ``make_grid`` to
            denormalize each image.
        n_images (int, optional): if given, only the first ``n_images``
            samples of the batch are rendered.
        another_engine (optional): if given, the global step is read from this
            engine's state instead of the engine the handler is called with.
        prefix_tag (str, optional): if given, prepended to the image tag as
            ``"<prefix_tag>: predictions_with_gt"``.

    Returns:
        Callable: ``wrapper(engine, logger, event_name)`` which reads
        ``engine.state.batch["image"]``, ``engine.state.batch["mask"]`` and
        ``engine.state.output[0]`` and writes the grid through
        ``logger.writer.add_image`` in ``"HWC"`` format.
        NOTE(review): ``logger.writer`` looks like a TensorBoard
        ``SummaryWriter`` -- confirm against the logger in use.
    """

    def wrapper(engine, logger, event_name):
        batch = engine.state.batch
        output = engine.state.output
        x = batch["image"]
        y = batch["mask"]
        y_pred = output[0]

        if y.shape == y_pred.shape and y.ndim == 4:
            # Case of y of shape (B, C, H, W)
            # argmax yields int64; cast to uint8 like the prediction below so
            # the mask can be rendered as a palette image.
            y = torch.argmax(y, dim=1).byte()

        y_pred = torch.argmax(y_pred, dim=1).byte()

        if n_images is not None:
            x = x[:n_images, ...]
            y = y[:n_images, ...]
            y_pred = y_pred[:n_images, ...]

        grid_pred_gt = make_grid(x, y_pred, img_denormalize_fn, batch_gt_mask=y)

        # The step may come from a different engine than the one producing
        # the batch.
        state = engine.state if another_engine is None else another_engine.state
        global_step = state.get_event_attrib_value(event_name)

        tag = "predictions_with_gt"
        if prefix_tag is not None:
            tag = f"{prefix_tag}: {tag}"
        logger.writer.add_image(tag=tag, img_tensor=grid_pred_gt, global_step=global_step, dataformats="HWC")

    return wrapper