forked from chenxin-dlut/TransT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
plotting.py
executable file
·155 lines (128 loc) · 5.1 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2
def draw_figure(fig):
    """Redraw a figure and let pending GUI events be processed.

    args:
        fig: Figure whose canvas is redrawn and flushed.
    """
    canvas = fig.canvas
    canvas.draw()
    canvas.flush_events()
    # Short pause so an interactive window gets a chance to refresh.
    plt.pause(0.001)
def show_tensor(a: torch.Tensor, fig_num = None, title = None, range=(None, None), ax=None):
    """Show a tensor as an image, either on a given axes or in a numbered figure.

    The tensor is squeezed and copied to a NumPy array on the CPU. If the
    result has three dimensions, its axes are permuted with (1, 2, 0) before
    it is handed to imshow.

    args:
        a: Tensor to display.
        fig_num: Figure number, used only when `ax` is None.
        title: Optional title; nothing is set when None.
        range: Pair forwarded to imshow as (vmin, vmax).
        ax: Optional axes to draw on. When None, the figure `fig_num` is used.
    """
    image = a.squeeze().cpu().clone().detach().numpy()
    if image.ndim == 3:
        image = np.transpose(image, (1, 2, 0))

    if ax is not None:
        # Draw on the caller-supplied axes, then refresh the current figure.
        ax.cla()
        ax.imshow(image, vmin=range[0], vmax=range[1])
        ax.set_axis_off()
        ax.axis('equal')
        if title is not None:
            ax.set_title(title)
        draw_figure(plt.gcf())
        return

    # No axes supplied: draw through the pyplot state machine on figure `fig_num`.
    figure = plt.figure(fig_num)
    plt.tight_layout()
    plt.cla()
    plt.imshow(image, vmin=range[0], vmax=range[1])
    plt.axis('off')
    plt.axis('equal')
    if title is not None:
        plt.title(title)
    draw_figure(figure)
def plot_graph(a: torch.Tensor, fig_num = None, title = None):
    """Plot the values of a tensor as a line graph.

    args:
        a: Tensor that must have at most one dimension after squeezing.
        fig_num: Figure number.
        title: Optional title; nothing is set when None.

    raises:
        ValueError: If the squeezed tensor has more than one dimension.
    """
    detached = a.squeeze().cpu().clone().detach()
    values = detached.numpy()
    if not values.ndim <= 1:
        raise ValueError

    figure = plt.figure(fig_num)
    plt.cla()
    plt.plot(values)
    if title is not None:
        plt.title(title)
    draw_figure(figure)
def show_image_with_boxes(im, boxes, iou_pred=None, disp_ids=None):
    """Draw boxes, and optionally a score label per box, onto an image tensor.

    args:
        im: Image tensor that squeezes to a channel-first (C, H, W) layout.
            Values are cast to uint8 for drawing.
        boxes: Tensor reshaped to (-1, 4). Each row is drawn as a rectangle
            from (row[0], row[1]) to (row[0] + row[2], row[1] + row[3]).
        iou_pred: Optional per-box values, indexed by box number and printed
            next to each drawn box.
        disp_ids: Optional per-box flags. When given, a box is drawn only if
            its entry is truthy.

    returns:
        Float tensor in channel-first layout holding the annotated image.
    """
    # detach() first: .numpy() refuses tensors that require grad. The sibling
    # plotting helpers in this file detach as well.
    im_np = im.detach().clone().cpu().squeeze().numpy()
    im_np = np.ascontiguousarray(im_np.transpose(1, 2, 0).astype(np.uint8))

    # reshape (rather than view) also accepts non-contiguous box tensors.
    boxes = boxes.detach().reshape(-1, 4).cpu().numpy().round().astype(int)

    # Draw proposals
    for i_, bb in enumerate(boxes):
        if disp_ids is not None and not disp_ids[i_]:
            continue

        # Plain Python ints keep the coordinate tuples independent of how the
        # OpenCV binding handles NumPy scalar types.
        x, y, w, h = bb.tolist()
        disp_color = (i_*38 % 256, (255 - i_*97) % 256, (123 + i_*66) % 256)
        cv2.rectangle(im_np, (x, y), (x + w, y + h), disp_color, 1)

        if iou_pred is not None:
            text_pos = (x, y - 5)
            cv2.putText(im_np, 'ID={} IOU = {:3.2f}'.format(i_, iou_pred[i_]), text_pos,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, bottomLeftOrigin=False)

    im_tensor = torch.from_numpy(im_np.transpose(2, 0, 1)).float()
    return im_tensor
def _pascal_color_map(N=256, normalized=False):
"""
Python implementation of the color map function for the PASCAL VOC data set.
Official Matlab version can be found in the PASCAL VOC devkit
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit
"""
def bitget(byteval, idx):
return (byteval & (1 << idx)) != 0
dtype = 'float32' if normalized else 'uint8'
cmap = np.zeros((N, 3), dtype=dtype)
for i in range(N):
r = g = b = 0
c = i
for j in range(8):
r = r | (bitget(c, 0) << 7 - j)
g = g | (bitget(c, 1) << 7 - j)
b = b | (bitget(c, 2) << 7 - j)
c = c >> 3
cmap[i] = np.array([r, g, b])
cmap = cmap / 255 if normalized else cmap
return cmap
def overlay_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None):
    """ Overlay mask over image.
    Source: https://github.com/albertomontesg/davis-interactive/blob/master/davisinteractive/utils/visualization.py
    This function allows you to overlay a mask over an image with some
    transparency.
    # Arguments
        im: Numpy Array. Array with the image. The shape must be (H, W, 3) and
            the pixels must be represented as `np.uint8` data type.
        ann: Numpy Array. Array with the mask. The shape must be (H, W) and the
            values must be integers.
        alpha: Float. Weight of the original image in the blend; the mask
            color is weighted by `1 - alpha`.
        colors: Numpy Array. Optional custom colormap. It must have shape (N, 3)
            being N the maximum number of colors to represent. When None (or
            empty) the PASCAL VOC color map is used.
        contour_thickness: Integer. Thickness of each object index contour draw
            over the overlay. This function requires to have installed the
            package `opencv-python`.
    # Returns
        Numpy Array: Image of the overlay with shape (H, W, 3) and data type
            `np.uint8`.
    # Raises
        ValueError: If the leading dimensions of `im` do not match `ann`, or
            if `im` does not have three channels in its last dimension.
    """
    # `np.int` was removed in NumPy 1.24; the builtin `int` is what it aliased.
    im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=int)
    if im.shape[:-1] != ann.shape:
        raise ValueError('First two dimensions of `im` and `ann` must match')
    if im.shape[-1] != 3:
        raise ValueError('im must have three channels at the 3 dimension')

    # `colors or default` raises on a multi-element ndarray (ambiguous truth
    # value), so test explicitly. Empty input still falls back to the default.
    if colors is None or len(colors) == 0:
        colors = _pascal_color_map()
    colors = np.asarray(colors, dtype=np.uint8)

    mask = colors[ann]
    fg = im * alpha + (1 - alpha) * mask

    # Only labelled pixels (ann > 0) receive the blended color.
    foreground = ann > 0
    img = im.copy()
    img[foreground] = fg[foreground]

    if contour_thickness:  # pragma: no cover
        import cv2
        for obj_id in np.unique(ann[foreground]):
            # [-2:] selects (contours, hierarchy) whether findContours returns
            # two or three values.
            contours = cv2.findContours((ann == obj_id).astype(
                np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
            cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(),
                             contour_thickness)
    return img