# evaluate_asteroid.py
import os
import torch
import argparse
import pickle
import numpy as np
import matplotlib.pyplot as plt
from graph import Graph
from model import GNet
from pool import FeaturePooling
from metrics import chamfer_loss, loss_function, f1_score
from data_asteroid import CustomDatasetFolder
# Args
parser = argparse.ArgumentParser(description='Pixel2Mesh evaluation script')
parser.add_argument('--data', type=str, metavar='D',
                    help="folder where the data is located.")
parser.add_argument('--log_step', type=int, default=100, metavar='L',
                    help='how many batches to wait before logging evaluation status')
parser.add_argument('--output', type=str, default=None, metavar='G',
                    help='if not None, generate meshes into this folder')
parser.add_argument('--show_img', action='store_true', default=False,
                    help='whether or not to show the input images')
parser.add_argument('--load', type=str, metavar='M',
                    help='model file to load for evaluation.')
args = parser.parse_args()
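# Example invocation (hypothetical paths, for illustration only):
#   python evaluate_asteroid.py --data ./data/asteroid_val --load ./models/model_best.pth \
#       --output ./meshes/ --show_img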
# Model
nIms = 25  # number of input views fed to the network for each sample
model_gcn = GNet(nIms)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
state_dict = torch.load(args.load, map_location=device)
model_gcn.load_state_dict(state_dict)
# Turn batch norm into eval mode
# for child in model_gcn.feat_extr.children():
#     for ii in range(len(child)):
#         if type(child[ii]) == torch.nn.BatchNorm2d:
#             child[ii].track_running_stats = False
model_gcn.eval()
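# Gradient tracking is not needed at evaluation time; disabling it saves memory.
torch.set_grad_enabled(False)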
# Cuda
use_cuda = torch.cuda.is_available()
if use_cuda:
    model_gcn.cuda()
    print('Using GPU')
else:
    print('Using CPU')
# Graph
graph = Graph("./ellipsoid/init_info.pickle")
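# init_info.pickle is assumed to hold the initial ellipsoid mesh (vertices, faces
# and unpooling information) that the network progressively deforms, as in Pixel2Mesh.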
# Data Loader
folder = CustomDatasetFolder(args.data, extensions=["png"], dimension=nIms, print_ref=False)
val_loader = torch.utils.data.DataLoader(folder, batch_size=1, shuffle=True)
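# Each batch is assumed to yield (images, ground-truth points, ground-truth normals);
# the running totals below are averaged over the number of processed batches, and
# tau is the matching threshold used for the F1 scores.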
tot_loss_norm = 0
tot_loss_unorm = 0
tot_f1_1 = 0
tot_f1_2 = 0
tau = 1e-4
log_step = args.log_step
show_img = args.show_img
for n, data in enumerate(val_loader):
    ims, gt_points, gt_normals = data
    # Reorder the image tensor to views-first: (views, batch, C, H, W)
    ims = ims.permute(1, 0, 2, 3, 4)
    m, b, *x_dims = ims.shape
    if use_cuda:
        ims = ims.cuda()
        gt_points = gt_points.cuda()
        gt_normals = gt_normals.cuda()

    # Show image
    if show_img:
        img = ims[0][0].cpu().float().numpy()
        # Undo ImageNet normalization for display
        mean = np.array([0.485, 0.456, 0.406]).reshape((-1, 1, 1))
        std = np.array([0.229, 0.224, 0.225]).reshape((-1, 1, 1))
        img = ((img * std + mean).transpose(1, 2, 0) * 255.0).round().astype(np.uint8)
        plt.imshow(img)
        plt.show()

    # Forward
    graph.reset()
    pools = []
    for i in range(m):
        pools.append(FeaturePooling(ims[i]))
    pred_points = model_gcn(graph, pools)

    # Compute eval metrics
    _, loss_norm = chamfer_loss(pred_points[-1], gt_points.squeeze(), normalized=True)
    _, loss_unorm = chamfer_loss(pred_points[-1], gt_points.squeeze(), normalized=False)
    tot_loss_norm += loss_norm.item()
    tot_loss_unorm += loss_unorm.item()
    tot_f1_1 += f1_score(pred_points[-1], gt_points.squeeze(), threshold=tau)
    tot_f1_2 += f1_score(pred_points[-1], gt_points.squeeze(), threshold=2 * tau)

    # Logs
    if n % log_step == 0:
        print("Batch", n)
        print("Normalized Chamfer loss so far", tot_loss_norm / (n + 1))
        print("Unnormalized Chamfer loss so far", tot_loss_unorm / (n + 1))
        print("F1 score (tau=1e-4)", tot_f1_1 / (n + 1))
        print("F1 score (tau=2e-4)", tot_f1_2 / (n + 1))

    # Generate meshes
    if args.output is not None:
        # Export the predicted mesh
        graph.vertices = pred_points[5]
        graph.faces = graph.info[3][2]
        graph.to_obj(args.output + "plane_pred_block3_"
                     + str(n) + "_" + str(loss_norm.item()) + ".obj")
        # Export the ground-truth point cloud (no faces)
        graph.vertices = gt_points[0, :, :]
        graph.faces = []
        graph.to_obj(args.output + "plane_gt"
                     + str(n) + "_" + str(loss_norm.item()) + ".obj")
        print("Mesh plane_pred" + str(n) + "_" + str(loss_norm.item()) + " generated")
print("Final Normalized Chamfer loss:", tot_loss_norm/(n+1))
print("Final Unormalized Chamfer loss:", tot_loss_norm/(n+1))
print("Final F1 score (tau=1e-4) :", tot_f1_1/(n+1))
print("Final F1 score (tau=2e-4) :", tot_f1_2/(n+1))