forked from zhukaii/OS-TR
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlocalImageTest.py
72 lines (59 loc) · 2.39 KB
/
localImageTest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
import torch
import os
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import cv2
import time
# Input images and the trained checkpoint used by this smoke test.
ref_path_ = '/home/ros/OS_TR/test_images/ref_2.jpg'
query_path_ = '/home/ros/OS_TR/test_images/query_4.jpg'
# Alternative checkpoints from earlier training runs:
# model_path_ = '/home/ros/OS_TR/log/dtd_dtd_weighted_bce_banded_0.001/snapshot-epoch_2021-11-25-16:42:11_texture.pth'
# model_path_ = '/home/ros/OS_TR/log/tcd_ResNet50_frozen_weighted_bce_LR_0.001/snapshot-epoch_2021-11-29-17:34:44_texture.pth'
model_path_ = '/home/ros/OS_TR/log/tcd_alot_ResNet50_frozen_weighted_bce_LR_0.001/snapshot-epoch_2021-12-01-08:58:47_texture.pth'

# The checkpoint is a whole pickled nn.Module (it is called and .eval()'d
# directly), not a state_dict. Map its tensors onto the device the inputs will
# use, so a GPU-saved checkpoint loads on a CPU-only machine and a CPU-saved
# one does not end up on a different device than CUDA inputs.
_map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
try:
    # torch >= 2.6 defaults to weights_only=True, which refuses full modules.
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    model = torch.load(model_path_, map_location=_map_location, weights_only=False)
except TypeError:
    # torch < 1.13 does not accept the weights_only keyword.
    model = torch.load(model_path_, map_location=_map_location)
model.eval()  # inference mode for dropout / batch-norm layers
# Load both images as RGB, resize to 256x256 and scale to float64 in [0, 1].
# The `with` blocks close the underlying files; resize() returns a new,
# fully loaded image, so the arrays stay valid after the files are closed.
# To test with a grayscale reference, use .convert('L').convert('RGB') below.
with Image.open(ref_path_) as ref_file:
    ref_img = np.asarray(ref_file.convert('RGB').resize((256, 256))) / 255.0
with Image.open(query_path_) as query_file:
    query_img = np.asarray(query_file.convert('RGB').resize((256, 256))) / 255.0

transform_zk = transforms.Compose([
    # HxWxC ndarray -> CxHxW tensor. The arrays are already floats in [0, 1];
    # ToTensor only divides by 255 for uint8 input, so no rescaling happens here.
    transforms.ToTensor(),
    # Optional input normalization (disabled):
    # transforms.Normalize((0.5355, 0.4852, 0.4441), std=(0.2667, 0.2588, 0.2667))
])
# Add a batch dimension and cast float64 -> float32 for the model.
ref_tensor = transform_zk(ref_img).unsqueeze(0).float()
query_tensor = transform_zk(query_img).unsqueeze(0).float()
if torch.cuda.is_available():
    query_tensor = query_tensor.cuda()
    ref_tensor = ref_tensor.cuda()
# Run one forward pass and report its wall-clock time. no_grad() skips building
# the autograd graph, which inference does not need. CUDA kernels run
# asynchronously, so synchronize before reading the clock; otherwise only the
# launch overhead is measured. (The first call may also include warm-up cost.)
with torch.no_grad():
    t = time.perf_counter()
    scores = model(query_tensor, ref_tensor)
    if query_tensor.is_cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - t
print("inference time: " + str(elapsed))

# Optional hard threshold of the scores:
# scores[scores >= 0.5] = 1
# scores[scores < 0.5] = 0
# scores is indexed as a 4-D tensor; take item 0, channel 0
# (presumably (N, C, H, W) -- confirm against the model's output).
seg = scores[0, 0, :, :]
pred = seg.detach().cpu().numpy()
# Figure 0: query image, reference image and raw prediction side by side.
fig = plt.figure(0)
panels = (
    ('Query', query_tensor[0].permute(1, 2, 0).detach().cpu().numpy()),
    ('Reference', ref_tensor[0].permute(1, 2, 0).detach().cpu().numpy()),
    ('Prediction', pred),
)
for idx, (title, image) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 3, idx)
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')

# Figure 1: prediction drawn semi-transparently over the query image.
plt.figure(1)
plt.imshow(query_img)
plt.imshow(pred, alpha=0.5, cmap=plt.get_cmap("RdBu"))
plt.show()