-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathexpectation_maximization_registration.py
86 lines (73 loc) · 3.2 KB
/
expectation_maximization_registration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy as np
def initialize_sigma2(X, Y):
    """Return the mean squared coordinate difference over all point pairs.

    Every row of ``X`` is compared with every row of ``Y``; the squared
    per-coordinate differences are summed and divided by ``D * M * N``.

    Args:
        X: 2D array of shape (N, D).
        Y: 2D array of shape (M, D).

    Returns:
        Scalar ``sum((X[n] - Y[m]) ** 2) / (D * M * N)``.
    """
    (N, D) = X.shape
    (M, _) = Y.shape
    # Broadcasting (1, N, D) against (M, 1, D) yields the full (M, N, D)
    # difference tensor directly, without first materialising tiled copies
    # of both inputs.
    diff = X[np.newaxis, :, :] - Y[:, np.newaxis, :]
    return np.sum(diff * diff) / (D * M * N)
class expectation_maximization_registration(object):
    """Base class for an expectation-maximization point-cloud registration.

    The class owns the EM loop (``register`` / ``iterate``) and the
    expectation step. The transformation itself is supplied by a subclass,
    which must implement ``transform_point_cloud``, ``update_transform``,
    ``update_variance`` and ``get_registration_parameters``.

    Attributes set by ``__init__``:
        X, Y: the two input point clouds, shapes (N, D) and (M, D).
        sigma2: variance used by the Gaussian kernel; ``None`` means
            "compute it in ``register``".
        tolerance, max_iterations: stopping criteria of the loop.
        w: weight of the uniform term in the expectation denominator
            (presumably the outlier weight; confirm against subclasses).
        iteration: number of completed iterations.
        err: value compared with ``tolerance``; starts above it so the loop
            runs at least once. Nothing in this class updates it, so a
            subclass is expected to.
        P, Pt1, P1, Np: (M, N) weight matrix, its column sums, its row sums
            and its total sum.
    """

    def __init__(self, X, Y, sigma2=None, max_iterations=100, tolerance=0.001, w=0, *args, **kwargs):
        """Validate the inputs and initialise the EM state.

        Args:
            X: target point cloud, 2D ``numpy.ndarray`` of shape (N, D).
            Y: source point cloud, 2D ``numpy.ndarray`` of shape (M, D).
            sigma2: optional positive starting variance.
            max_iterations: upper bound on the number of EM iterations.
            tolerance: the loop stops once ``err`` is not above this value.
            w: uniform-term weight, must satisfy ``0 <= w < 1``.
            *args, **kwargs: accepted and ignored here.

        Raises:
            ValueError: if ``X`` or ``Y`` is not a 2D ndarray, if their
                second dimensions differ, if ``sigma2`` is given but not
                positive, or if ``w`` is outside ``[0, 1)``.
        """
        # isinstance (rather than an exact type match) also admits ndarray
        # subclasses.
        if not isinstance(X, np.ndarray) or X.ndim != 2:
            raise ValueError("The target point cloud (X) must be a 2D numpy array.")
        if not isinstance(Y, np.ndarray) or Y.ndim != 2:
            raise ValueError("The source point cloud (Y) must be a 2D numpy array.")
        if X.shape[1] != Y.shape[1]:
            raise ValueError("Both point clouds need to have the same number of dimensions.")
        # A non-positive variance would later feed log() and a division.
        if sigma2 is not None and sigma2 <= 0:
            raise ValueError("The variance (sigma2) must be positive, got {}.".format(sigma2))
        # expectation() divides by (1 - w); w outside [0, 1) would either
        # divide by zero or make the uniform term negative.
        if w < 0 or w >= 1:
            raise ValueError("The weight (w) must satisfy 0 <= w < 1, got {}.".format(w))

        self.X = X
        self.Y = Y
        self.sigma2 = sigma2
        (self.N, self.D) = self.X.shape
        (self.M, _) = self.Y.shape
        self.tolerance = tolerance
        self.w = w
        self.max_iterations = max_iterations
        self.iteration = 0
        # Start strictly above the tolerance so register() enters its loop.
        self.err = self.tolerance + 1
        self.P = np.zeros((self.M, self.N))
        self.Pt1 = np.zeros((self.N, ))
        self.P1 = np.zeros((self.M, ))
        self.Np = 0

    def register(self, callback=lambda **kwargs: None):
        """Run EM iterations until convergence or the iteration limit.

        Args:
            callback: optional callable invoked after every iteration with
                keyword arguments ``iteration``, ``error``, ``X`` and ``Y``
                (``Y`` being the current transformed cloud ``TY``).
                Non-callable values are ignored.

        Returns:
            Tuple ``(TY, parameters)`` where ``parameters`` is whatever
            ``get_registration_parameters`` returns.
        """
        # The subclass hook is expected to set self.TY.
        self.transform_point_cloud()
        if self.sigma2 is None:
            self.sigma2 = initialize_sigma2(self.X, self.TY)
        self.q = -self.err - self.N * self.D/2 * np.log(self.sigma2)
        while self.iteration < self.max_iterations and self.err > self.tolerance:
            print('---- Coherent Point Drift iteration: {}, Aggregated Error: {}'.format(self.iteration, self.err))
            self.iterate()
            if callable(callback):
                kwargs = {'iteration': self.iteration, 'error': self.err, 'X': self.X, 'Y': self.TY}
                callback(**kwargs)
        return self.TY, self.get_registration_parameters()

    def get_registration_parameters(self):
        """Return the fitted transformation parameters (subclass hook)."""
        raise NotImplementedError("Registration parameters should be defined in child classes.")

    def transform_point_cloud(self):
        """Apply the current transformation and store it in ``self.TY`` (subclass hook)."""
        raise NotImplementedError("Point cloud transformation should be defined in child classes.")

    def update_transform(self):
        """Re-estimate the transformation during the M-step (subclass hook)."""
        raise NotImplementedError("Transform updates should be defined in child classes.")

    def update_variance(self):
        """Re-estimate the variance during the M-step (subclass hook)."""
        raise NotImplementedError("Variance updates should be defined in child classes.")

    def iterate(self):
        """Perform one expectation step, one maximization step, and count it."""
        self.expectation()
        self.maximization()
        self.iteration += 1

    def expectation(self):
        """Recompute ``P``, ``Pt1``, ``P1`` and ``Np`` from ``X``, ``TY`` and ``sigma2``.

        ``P[m, n]`` is ``exp(-||X[n] - TY[m]||^2 / (2 * sigma2))`` divided by
        the sum of that kernel over ``m`` plus the uniform term ``c``.
        """
        # (1, N, D) - (M, 1, D) -> (M, N, D); summing the squares over the
        # last axis gives the squared distance between every TY[m] and X[n].
        diff = self.X[np.newaxis, :, :] - self.TY[:, np.newaxis, :]
        sq_dist = np.sum(diff * diff, axis=2)
        P = np.exp(-sq_dist / (2 * self.sigma2))

        c = (2 * np.pi * self.sigma2) ** (self.D / 2)
        c = c * self.w / (1 - self.w)
        c = c * self.M / self.N

        # Keep the denominator as a (1, N) row and let it broadcast over the
        # M rows of P instead of tiling it to (M, N).
        den = np.sum(P, axis=0, keepdims=True)
        # A column whose kernel sum is exactly 0 is bumped to machine epsilon
        # so the division cannot be 0/0 when c is also 0.
        den[den == 0] = np.finfo(float).eps
        den += c

        self.P = np.divide(P, den)
        self.Pt1 = np.sum(self.P, axis=0)
        self.P1 = np.sum(self.P, axis=1)
        self.Np = np.sum(self.P1)

    def maximization(self):
        """Update the transform, re-apply it to the source cloud, then update the variance."""
        self.update_transform()
        self.transform_point_cloud()
        self.update_variance()