-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvisualize.py
110 lines (88 loc) · 3.81 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
__all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']
# functions to show an image
def make_image(img, mean=(0, 0, 0), std=(1, 1, 1)):
    """Convert a normalized channels-first image tensor to an (H, W, C) array.

    Channels 0-2 are un-normalized as ``img[c] * std[c] + mean[c]`` and the
    channel axis is moved last, so the result can go straight to
    ``plt.imshow``.

    Args:
        img: image tensor laid out as (C, H, W) with at least three channels.
            It is NOT modified; a detached CPU copy is used instead.
        mean: per-channel means that were subtracted during normalization.
        std: per-channel standard deviations used during normalization.

    Returns:
        numpy.ndarray of shape (H, W, C).
    """
    # Detach + move to CPU so ``.numpy()`` also works for CUDA tensors and for
    # tensors that require grad; clone so the caller's tensor is not mutated
    # (``.cpu()``/``.detach()`` alone may share storage with the input).
    img = img.detach().cpu().clone()
    for c in range(3):
        img[c] = img[c] * std[c] + mean[c]  # unnormalize
    return np.transpose(img.numpy(), (1, 2, 0))
def gauss(x, a, b, c):
    """Evaluate ``a * exp(-(x - b)**2 / (2 * c * c))`` elementwise on tensor ``x``.

    Args:
        x: input tensor.
        a: peak height of the bump.
        b: position of the peak.
        c: width of the bump.

    Returns:
        Tensor with the same shape as ``x``.
    """
    offset = x - b
    exponent = -(offset ** 2) / (2 * c * c)
    return a * torch.exp(exponent)
def colorize(x):
    '''Convert a one-channel grayscale image (or a batch of them) to an RGB heatmap.

    Each output channel is a sum of Gaussian bumps (``gauss``) evaluated on the
    grayscale intensities; every output value is clipped to at most 1.

    Args:
        x: grayscale tensor of shape (H, W), (1, H, W) or (N, 1, H, W).
            Only the first channel is read.

    Returns:
        Float tensor of shape (3, H, W) for 2-D/3-D input, or (N, 3, H, W)
        for 4-D input, on the same device as ``x``.

    Raises:
        ValueError: if ``x`` does not have 2, 3 or 4 dimensions.
    '''
    if x.dim() == 2:
        # Add a leading channel axis out of place; torch.unsqueeze has no out=.
        x = x.unsqueeze(0)
    if x.dim() == 3:
        gray, channel_axis = x[0], 0
    elif x.dim() == 4:
        # Drop the channel axis so each map is (N, H, W) and stacks cleanly.
        gray, channel_axis = x[:, 0], 1
    else:
        raise ValueError(
            'colorize expects a 2-, 3- or 4-D tensor, got {}-D'.format(x.dim()))
    red = gauss(gray, .5, .6, .2) + gauss(gray, 1, .8, .3)
    green = gauss(gray, 1, .5, .3)
    blue = gauss(gray, 1, .2, .3)
    cl = torch.stack([red, green, blue], dim=channel_axis).float()
    # The red channel sums two bumps and can exceed 1; clip for every input
    # rank so single images and batches produce the same value range.
    return cl.clamp(max=1)
def show_batch(images, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Display a batch of normalized images as one grid and show the figure.

    Args:
        images: image batch accepted by ``torchvision.utils.make_grid``.
        Mean: per-channel means used to un-normalize the grid.
        Std: per-channel standard deviations used to un-normalize the grid.
    """
    grid = torchvision.utils.make_grid(images)
    picture = make_image(grid, Mean, Std)
    plt.imshow(picture)
    plt.show()
def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot a batch of images above the same images blended with a mask.

    Draws two stacked subplots into the current matplotlib figure (it does not
    call ``plt.show()``). The top one shows ``images`` as a grid. The bottom one
    shows ``0.3 * image + 0.7 * mask`` with the mask resized to the image size.

    Args:
        images: normalized image batch indexed as ``[:, c, :, :]``; channels 0-2
            are un-normalized with ``Mean``/``Std`` (presumably (N, 3, H, W) —
            confirm with callers).
        mask: mask batch with batch and channel axes before the spatial ones.
            After resizing it must broadcast to ``images`` via ``expand_as``
            (e.g. a single channel).
        Mean: per-channel means used to un-normalize ``images``.
        Std: per-channel standard deviations used to un-normalize ``images``.
    """
    # make_image converts to numpy, so keep every tensor on the CPU.
    images = images.detach().cpu()
    mask = mask.detach().cpu()

    # Un-normalized copy of the batch that the mask is blended into.
    im_data = images.clone()
    for i in range(3):
        im_data[:, i, :, :] = im_data[:, i, :, :] * Std[i] + Mean[i]  # unnormalize

    plt.subplot(2, 1, 1)
    plt.imshow(make_image(torchvision.utils.make_grid(images), Mean, Std))
    plt.axis('off')

    # Resize to the exact spatial size of the images; an explicit size avoids
    # the rounding mismatch a fractional scale factor can cause in expand_as.
    mask = torch.nn.functional.interpolate(mask, size=tuple(im_data.shape[-2:]))
    blended = 0.3 * im_data + 0.7 * mask.expand_as(im_data)
    plt.subplot(2, 1, 2)
    plt.imshow(make_image(torchvision.utils.make_grid(blended)))
    plt.axis('off')
def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot a batch of images above one mask-blended copy per mask.

    Draws ``1 + len(masklist)`` stacked subplots into the current matplotlib
    figure (it does not call ``plt.show()``). The first shows ``images`` as a
    grid. Each following one shows ``0.3 * image + 0.7 * mask`` for the next
    mask, resized to the image size.

    Args:
        images: normalized image batch indexed as ``[:, c, :, :]``; channels 0-2
            are un-normalized with ``Mean``/``Std`` (presumably (N, 3, H, W) —
            confirm with callers).
        masklist: sequence of mask batches, each with batch and channel axes
            before the spatial ones. After resizing, each must broadcast to
            ``images`` via ``expand_as`` (e.g. a single channel).
        Mean: per-channel means used to un-normalize ``images``.
        Std: per-channel standard deviations used to un-normalize ``images``.
    """
    # make_image converts to numpy and masks are moved to the CPU below, so
    # the images must live on the CPU as well for the blend to work.
    images = images.detach().cpu()

    # Un-normalized copy of the batch that each mask is blended into.
    im_data = images.clone()
    for i in range(3):
        im_data[:, i, :, :] = im_data[:, i, :, :] * Std[i] + Mean[i]  # unnormalize

    rows = 1 + len(masklist)
    plt.subplot(rows, 1, 1)
    plt.imshow(make_image(torchvision.utils.make_grid(images), Mean, Std))
    plt.axis('off')

    # Resize every mask to the exact spatial size of the images; an explicit
    # size avoids the rounding mismatch a fractional scale factor can cause.
    target_size = tuple(im_data.shape[-2:])
    for row, mask in enumerate(masklist, start=2):
        mask = torch.nn.functional.interpolate(mask.detach().cpu(), size=target_size)
        blended = 0.3 * im_data + 0.7 * mask.expand_as(im_data)
        plt.subplot(rows, 1, row)
        plt.imshow(make_image(torchvision.utils.make_grid(blended)))
        plt.axis('off')
# x = torch.zeros(1, 3, 3)
# out = colorize(x)
# out_im = make_image(out)
# plt.imshow(out_im)
# plt.show()