-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathshrec_dataset.py
270 lines (235 loc) · 10 KB
/
shrec_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
import os
from pathlib import Path
import numpy as np
import potpourri3d as pp3d
import torch
from torch.utils.data import Dataset
import diffusion_net as dfn
from utils import auto_WKS, farthest_point_sample, square_distance
#
from tqdm import tqdm
from itertools import permutations
import Tools.mesh as qm
from Tools.utils import op_cpl
class ShrecDataset(Dataset):
    """
    Shape-matching Dataset WITH ground-truth correspondence files (between given pairs).

    It is called Shrec Dataset because, historically, SHREC'19 used this layout.
    Any dataset shipping ground-truth files for its pairs falls into this category
    and can therefore be used via this class.

    Files read under ``root_dir``:
        shapes_train/ or shapes_test/  one ``<name>.off`` mesh per shape
        groundtruth/                   one ``<i>_<j>.map`` file per pair
        test_pairs.txt                 one comma-separated index pair ``i,j`` per line

    ---Parameters:
    @ root_dir: root folder containing the shapes_train and shapes_test folders
    @ name: name of the dataset, e.g. scape-remeshed or scape-anisotropic (used in the cache file name)
    @ k_eig: number of Laplace-Beltrami eigenvectors loaded
    @ n_fmap: number of eigenvectors used for the functional map computation
    @ n_cfmap: number of complex eigenvectors used for the complex fmap computation
               (0 disables the complex operators; their entries are then None)
    @ with_wks: None if no WKS (C_in <= 3), else the number of WKS descriptors
    @ use_cache: cache the processed dataset on disk (True by default)
    @ op_cache_dir: cache for diffusion-net operators (from config['dataset']['cache_dir'])
    @ train: use the train split if True, else the test split

    ---At initialisation, loads:
    1) verts, faces and ground-truths
    2) geometric operators (Laplacian, Gradient)
    3) (optional if C_in = 3) WKS descriptors, which replace the vertex coordinates
    4) (optional if n_cfmap = 0) complex operators (for orientation-aware unsupervised learning)

    ---When delivering an element of the dataset, yields a dictionary with:
    1) shape1 containing all necessary info for the source shape (plus its 'gt' map)
    2) shape2 containing all necessary info for the target shape
    3) gt, the ground-truth vertex map of the pair
    4) C_gt, the ground-truth functional map (obtained from the gt files)
    """

    # Order of the fields stored in the dataset cache file. Saving and loading both use
    # this tuple so they cannot drift apart. Do not reorder it: previously written cache
    # files would otherwise be unpacked into the wrong attributes.
    _CACHE_FIELDS = (
        # main
        "verts_list", "faces_list", "frames_list",
        # diffusion net
        "massvec_list", "L_list", "evals_list", "evecs_list", "gradX_list", "gradY_list",
        # Q-maps (complex operators)
        "cevecs_list", "cevals_list", "spec_grad_list",
        # misc
        "used_shapes", "gt_list",
    )

    def __init__(self, root_dir, name="remeshed",
                 k_eig=128, n_fmap=30, n_cfmap=20,
                 with_wks=None,
                 use_cache=True, op_cache_dir=None,
                 train=False):
        self.k_eig = k_eig
        self.n_fmap = n_fmap
        self.n_cfmap = n_cfmap
        self.root_dir = root_dir
        self.cache_dir = root_dir
        self.op_cache_dir = op_cache_dir

        split = "train" if train else "test"
        wks_suf = "" if with_wks is None else "wks_"
        # NOTE(review): the cache file name does not encode k_eig, n_cfmap or the number of
        # WKS descriptors; delete the cache file by hand after changing any of them.
        load_cache = os.path.join(self.cache_dir, f"cache_{name}_{wks_suf}{split}.pt")

        # check the cache
        if use_cache:
            print("using dataset cache path: " + str(load_cache))
            if os.path.exists(load_cache):
                print(" --> loading dataset from cache")
                self._restore_from_cache(load_cache)
                self.combinations = self._load_pairs(root_dir)
                return
            print(" --> dataset not in cache, repopulating")

        shapes_split = "shapes_" + split
        mesh_dirpath = Path(root_dir) / shapes_split
        gt_dirpath = Path(root_dir) / "groundtruth"

        # The sorted order of the shape names defines the integer indices used by the pairs
        # (every per-shape list below is indexed by them). Names containing 'DS_' are skipped
        # (presumably macOS .DS_Store files).
        self.used_shapes = sorted(x.stem for x in mesh_dirpath.iterdir() if 'DS_' not in x.stem)
        self.combinations = self._load_pairs(root_dir)

        # Load the meshes
        self.verts_list = []
        self.faces_list = []
        for shape_name in tqdm(self.used_shapes):
            verts, faces = pp3d.read_mesh(str(mesh_dirpath / f"{shape_name}.off"))
            self.verts_list.append(torch.tensor(np.ascontiguousarray(verts)).float())
            self.faces_list.append(torch.tensor(np.ascontiguousarray(faces)))

        # Load ground-truth vertex maps, keyed by (source index, target index)
        self.gt_list = {}
        for i, j in tqdm(self.combinations):
            gt_map = np.loadtxt(str(gt_dirpath / f"{i}_{j}.map"), dtype=np.int32).astype(int)
            self.gt_list[(i, j)] = gt_map

        # Precompute diffusion-net operators
        (
            self.frames_list,
            self.massvec_list,
            self.L_list,
            self.evals_list,
            self.evecs_list,
            self.gradX_list,
            self.gradY_list,
        ) = dfn.geometry.get_all_operators(
            self.verts_list,
            self.faces_list,
            k_eig=self.k_eig,
            op_cache_dir=self.op_cache_dir,
        )

        # Compute WKS descriptors if required (they replace the vertex coordinates)
        if with_wks is not None:
            print("compute WKS descriptors")
            for i in tqdm(range(len(self.used_shapes))):
                self.verts_list[i] = auto_WKS(self.evals_list[i], self.evecs_list[i], with_wks).float()

        # Complex Laplacian eigenbasis and spectral gradients (for Q-maps)
        print("loading operators for Q...")
        self.cevecs_list = []
        self.cevals_list = []
        self.spec_grad_list = []
        for shape_name in tqdm(self.used_shapes):
            # n_cfmap == 0: complex spectral quantities are not computed
            # (e.g. non-manifoldness, borders, point clouds)
            if n_cfmap == 0:
                self.cevecs_list.append(None)
                self.cevals_list.append(None)
                self.spec_grad_list.append(None)
                continue
            mesh_for_Q = qm.mesh(str(mesh_dirpath / f"{shape_name}.off"),
                                 spectral=0, complex_spectral=n_cfmap, spectral_folder=root_dir)
            mesh_for_Q.grad_vert_op()
            mesh_for_Q.grad_vc = op_cpl(mesh_for_Q.gradv.T).T
            self.cevecs_list.append(mesh_for_Q.ceig)
            self.cevals_list.append(mesh_for_Q.cvals)
            # gradient operator projected onto the complex eigenbasis
            self.spec_grad_list.append(np.linalg.pinv(mesh_for_Q.ceig) @ mesh_for_Q.grad_vc)
        print('done')

        # save to cache (same tuple layout as before: the fields of _CACHE_FIELDS, in order)
        if use_cache:
            dfn.utils.ensure_dir_exists(self.cache_dir)
            torch.save(tuple(getattr(self, field) for field in self._CACHE_FIELDS), load_cache)

    @staticmethod
    def _load_pairs(root_dir):
        """Read ``root_dir/test_pairs.txt`` as an (n_pairs, 2) int array of shape indices.

        ``ndmin=2`` keeps a single-pair file two-dimensional; without it ``np.loadtxt``
        returns a flat array of length 2, which breaks ``len()`` and the pair unpacking.
        """
        pairs = np.loadtxt(os.path.join(Path(root_dir), 'test_pairs.txt'), delimiter=',', ndmin=2)
        return pairs.astype(int)

    @staticmethod
    def _load_cache_file(path):
        """Load a cache file previously written by this class.

        The cache holds numpy arrays and dicts, which ``torch.load`` rejects under
        ``weights_only=True`` (the default since torch 2.6). The file is produced
        locally by ``__init__``, so full unpickling is requested explicitly.
        """
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            # torch < 1.13 has no weights_only argument
            return torch.load(path)

    def _restore_from_cache(self, path):
        """Set every attribute listed in ``_CACHE_FIELDS`` from the cache file at ``path``.

        Raises:
            ValueError: if the cache does not contain exactly one entry per field.
        """
        cached = self._load_cache_file(path)
        if len(cached) != len(self._CACHE_FIELDS):
            raise ValueError(
                f"unexpected dataset cache layout in {path}: "
                f"expected {len(self._CACHE_FIELDS)} fields, got {len(cached)}"
            )
        for field, value in zip(self._CACHE_FIELDS, cached):
            setattr(self, field, value)

    def __len__(self):
        return len(self.combinations)

    def _get_shape(self, idx):
        """Gather every stored quantity of shape ``idx`` into one dict."""
        return {
            "xyz": self.verts_list[idx],
            "faces": self.faces_list[idx],
            "frames": self.frames_list[idx],
            # diffusion net
            "mass": self.massvec_list[idx],
            "L": self.L_list[idx],
            "evals": self.evals_list[idx],
            "evecs": self.evecs_list[idx],
            "gradX": self.gradX_list[idx],
            "gradY": self.gradY_list[idx],
            # Q-maps
            "cevecs": self.cevecs_list[idx],
            "cevals": self.cevals_list[idx],
            "spec_grad": self.spec_grad_list[idx],
            # misc
            "name": self.used_shapes[idx],
        }

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as {"shape1", "shape2", "gt", "C_gt"}."""
        idx1, idx2 = self.combinations[idx]
        shape1 = self._get_shape(idx1)
        shape2 = self._get_shape(idx2)

        # Compute the ground-truth fmap on the first n_fmap eigenvectors
        evec_1 = shape1["evecs"][:, :self.n_fmap]
        evec_2 = shape2["evecs"][:, :self.n_fmap]
        gt = self.gt_list[(idx1, idx2)]
        shape1['gt'] = gt  # also attached to shape 1 for evaluation
        # gt has one entry per row of evec_1 and indexes rows of evec_2; the pseudo-inverse
        # gives the least-squares C with evec_2[gt] @ C ~= evec_1.
        C_gt = torch.pinverse(evec_2[gt]) @ evec_1
        return {"shape1": shape1, "shape2": shape2, "gt": gt, "C_gt": C_gt}
def shape_to_device(dict_shape, device):
    """
    Move the tensors of a dataset item to ``device``, in place.

    ``dict_shape`` is an item as returned by ``ShrecDataset.__getitem__`` (possibly
    collated by a DataLoader). Entries whose key contains "shape" are per-shape dicts:
    only the fields in ``names_to_device`` are moved, and ``None`` fields are skipped.
    Every other entry (e.g. "gt", "C_gt") is moved as a whole.

    NumPy arrays are converted with ``torch.as_tensor`` before moving. This lets an item
    taken straight from the dataset, where e.g. "gt" is a NumPy array, be moved without
    going through DataLoader conversion first.

    Returns the same (mutated) dict.
    """
    names_to_device = ["xyz", "faces", "mass", "evals", "evecs", "gradX", "gradY",
                       "cevecs", "cevals", "spec_grad"]

    def _move(value):
        # ndarray has no .to(); wrap it as a tensor first
        if isinstance(value, np.ndarray):
            value = torch.as_tensor(value)
        return value.to(device)

    for key, value in dict_shape.items():
        if "shape" in key:
            for name in names_to_device:
                if value[name] is not None:
                    value[name] = _move(value[name])
            dict_shape[key] = value
        else:
            dict_shape[key] = _move(value)
    return dict_shape