test_funcs_combined.py
import copy
import os
import time

import cv2
import numpy as np
import pyrender
import torch
import trimesh
from tqdm import tqdm


def test_autoencoder_dataloader(device, model, dataloader_test, shapedata, mm_constant=1000):
    """Evaluate the autoencoder on the test set, exporting and rendering each
    reconstructed mesh, and return the predictions with L1/L2 errors (L2 in mm)."""
    model.eval()
    l1_loss = 0
    l2_loss = 0
    shapedata_mean = torch.Tensor(shapedata.mean).to(device)
    shapedata_std = torch.Tensor(shapedata.std).to(device)
    template = shapedata.reference_mesh
    os.makedirs('meshes', exist_ok=True)  # export directory must exist
    with torch.no_grad():
        start_time = time.time()
        for i, sample_dict in enumerate(tqdm(dataloader_test)):
            tx = sample_dict['points'].to(device)
            prediction = model(tx)
            # Accumulate predictions across batches.
            if i == 0:
                predictions = copy.deepcopy(prediction)
            else:
                predictions = torch.cat([predictions, prediction], 0)
            # Drop the dummy vertex appended by the dataset, if present.
            if dataloader_test.dataset.dummy_node:
                x_recon = prediction[:, :-1]
                x = tx[:, :-1]
            else:
                x_recon = prediction
                x = tx
            l1_loss += torch.mean(torch.abs(x_recon - x)) * x.shape[0] / float(len(dataloader_test.dataset))
            # Undo the normalization before exporting/rendering the meshes.
            x_recon = x_recon * shapedata_std + shapedata_mean
            for j in range(x_recon.shape[0]):
                template.vertices = x_recon[j].cpu().numpy()
                # Rotate -90 degrees about the y-axis so the mesh faces the camera.
                rotate = trimesh.transformations.rotation_matrix(
                    angle=np.radians(-90.0),
                    direction=[0, 1, 0],
                    point=[0, 0, 0])
                template.vertices = template.vertices @ rotate[:3, :3]
                template.export('meshes/' + str(i).zfill(3) + str(j).zfill(3) + '.ply', 'ply')
                # Render an offscreen 400x400 screenshot of the mesh.
                mesh = pyrender.Mesh.from_trimesh(template)
                scene = pyrender.Scene()
                scene.add(mesh)
                camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0)
                s = 1  # np.sqrt(2)/2
                camera_pose = np.array([
                    [1.0, -s, s, 0.15],
                    [0.0, 1.0, 0.0, -0.1],
                    [0.0, s, 1.0, 0.25],
                    [0.0, 0.0, 0.0, 1.0],
                ])
                scene.add(camera, pose=camera_pose)
                light = pyrender.DirectionalLight(color=np.ones(3), intensity=3.0)
                scene.add(light, pose=camera_pose)
                r = pyrender.OffscreenRenderer(400, 400)
                color, depth = r.render(scene)
                r.delete()  # release the GL context before the next render
                # pyrender renders RGB; OpenCV writes BGR, so convert first.
                cv2.imwrite('meshes/' + str(i).zfill(3) + str(j).zfill(3) + '.png',
                            cv2.cvtColor(color, cv2.COLOR_RGB2BGR))
                # pyrender.Viewer(scene, use_raymond_lighting=True)
            # Per-vertex Euclidean error, reported in millimetres. x_recon was
            # denormalized (in metres) above, so scale the difference by mm_constant.
            x = x * shapedata_std + shapedata_mean
            l2_loss += torch.mean(torch.sqrt(torch.sum(((x_recon - x) * mm_constant) ** 2, dim=2))) * x.shape[0] / float(len(dataloader_test.dataset))
        print("--- %s seconds ---" % (time.time() - start_time))
        predictions = predictions.cpu()
    return predictions, l1_loss.cpu(), l2_loss.cpu()
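

# A minimal usage sketch, kept commented out so the module stays importable.
# The names `model`, `dataloader_test`, and `shapedata` are assumptions: they
# stand in for whatever trained autoencoder, test DataLoader, and shape-data
# object this project constructs elsewhere; they are not defined in this file.
#
# if __name__ == '__main__':
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     model = model.to(device)
#     predictions, l1, l2 = test_autoencoder_dataloader(
#         device, model, dataloader_test, shapedata, mm_constant=1000)
#     print('L1: %.6f  L2 (mm): %.4f' % (l1.item(), l2.item()))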