From e5b1d6d3a3614749d2ddcae9b42e50869c8266a2 Mon Sep 17 00:00:00 2001 From: David Novotny Date: Thu, 2 Apr 2020 14:44:43 -0700 Subject: [PATCH] Umeyama MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Umeyama estimates a rigid motion between two sets of corresponding points. Benchmark output for `bm_points_alignment` ``` Arguments key: [_____] Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- CorrespodingPointsAlignment_True_1_3_True_100_False 7382 9833 68 CorrespodingPointsAlignment_True_1_3_True_10000_False 8183 10500 62 CorrespodingPointsAlignment_True_1_3_False_100_False 7301 9263 69 CorrespodingPointsAlignment_True_1_3_False_10000_False 7945 9746 64 CorrespodingPointsAlignment_True_1_20_True_100_False 13706 41623 37 CorrespodingPointsAlignment_True_1_20_True_10000_False 11044 33766 46 CorrespodingPointsAlignment_True_1_20_False_100_False 9908 28791 51 CorrespodingPointsAlignment_True_1_20_False_10000_False 9523 18680 53 CorrespodingPointsAlignment_True_10_3_True_100_False 29585 32026 17 CorrespodingPointsAlignment_True_10_3_True_10000_False 29626 36324 18 CorrespodingPointsAlignment_True_10_3_False_100_False 26013 29253 20 CorrespodingPointsAlignment_True_10_3_False_10000_False 25000 33820 20 CorrespodingPointsAlignment_True_10_20_True_100_False 40955 41592 13 CorrespodingPointsAlignment_True_10_20_True_10000_False 42087 42393 12 CorrespodingPointsAlignment_True_10_20_False_100_False 39863 40381 13 CorrespodingPointsAlignment_True_10_20_False_10000_False 40813 41699 13 CorrespodingPointsAlignment_True_100_3_True_100_False 183146 194745 3 CorrespodingPointsAlignment_True_100_3_True_10000_False 213789 231466 3 CorrespodingPointsAlignment_True_100_3_False_100_False 177805 180796 3 CorrespodingPointsAlignment_True_100_3_False_10000_False 184963 185695 3 CorrespodingPointsAlignment_True_100_20_True_100_False 347181 347325 2 CorrespodingPointsAlignment_True_100_20_True_10000_False 363259 363613 2 CorrespodingPointsAlignment_True_100_20_False_100_False 351769 352496 2 CorrespodingPointsAlignment_True_100_20_False_10000_False 375629 379818 2 CorrespodingPointsAlignment_False_1_3_True_100_False 11155 13770 45 CorrespodingPointsAlignment_False_1_3_True_10000_False 10743 13938 47 CorrespodingPointsAlignment_False_1_3_False_100_False 9578 11511 53 CorrespodingPointsAlignment_False_1_3_False_10000_False 9549 11984 53 CorrespodingPointsAlignment_False_1_20_True_100_False 13809 14183 37 CorrespodingPointsAlignment_False_1_20_True_10000_False 14084 15082 36 CorrespodingPointsAlignment_False_1_20_False_100_False 12765 14177 40 CorrespodingPointsAlignment_False_1_20_False_10000_False 12811 13096 40 CorrespodingPointsAlignment_False_10_3_True_100_False 28823 39384 18 CorrespodingPointsAlignment_False_10_3_True_10000_False 27135 27525 19 CorrespodingPointsAlignment_False_10_3_False_100_False 26236 28980 20 CorrespodingPointsAlignment_False_10_3_False_10000_False 42324 45123 12 CorrespodingPointsAlignment_False_10_20_True_100_False 723902 723902 1 CorrespodingPointsAlignment_False_10_20_True_10000_False 220007 252886 3 CorrespodingPointsAlignment_False_10_20_False_100_False 55593 71636 9 CorrespodingPointsAlignment_False_10_20_False_10000_False 44419 71861 12 CorrespodingPointsAlignment_False_100_3_True_100_False 184768 185199 3 CorrespodingPointsAlignment_False_100_3_True_10000_False 198657 213868 3 CorrespodingPointsAlignment_False_100_3_False_100_False 224598 309645 3 CorrespodingPointsAlignment_False_100_3_False_10000_False 197863 202002 3 CorrespodingPointsAlignment_False_100_20_True_100_False 293484 309459 2 CorrespodingPointsAlignment_False_100_20_True_10000_False 327253 366644 2 CorrespodingPointsAlignment_False_100_20_False_100_False 420793 422194 2 CorrespodingPointsAlignment_False_100_20_False_10000_False 462634 485542 2 CorrespodingPointsAlignment_True_1_3_True_100_True 7664 9909 66 CorrespodingPointsAlignment_True_1_3_True_10000_True 7190 8366 70 CorrespodingPointsAlignment_True_1_3_False_100_True 6549 8316 77 CorrespodingPointsAlignment_True_1_3_False_10000_True 6534 7710 77 CorrespodingPointsAlignment_True_10_3_True_100_True 29052 32940 18 CorrespodingPointsAlignment_True_10_3_True_10000_True 30526 33453 17 CorrespodingPointsAlignment_True_10_3_False_100_True 28708 32993 18 CorrespodingPointsAlignment_True_10_3_False_10000_True 30630 35973 17 CorrespodingPointsAlignment_True_100_3_True_100_True 264909 320820 3 CorrespodingPointsAlignment_True_100_3_True_10000_True 310902 322604 2 CorrespodingPointsAlignment_True_100_3_False_100_True 246832 250634 3 CorrespodingPointsAlignment_True_100_3_False_10000_True 276006 289061 2 CorrespodingPointsAlignment_False_1_3_True_100_True 11421 13757 44 CorrespodingPointsAlignment_False_1_3_True_10000_True 11199 12532 45 CorrespodingPointsAlignment_False_1_3_False_100_True 11474 15841 44 CorrespodingPointsAlignment_False_1_3_False_10000_True 10384 13188 49 CorrespodingPointsAlignment_False_10_3_True_100_True 36599 47340 14 CorrespodingPointsAlignment_False_10_3_True_10000_True 40702 50754 13 CorrespodingPointsAlignment_False_10_3_False_100_True 41277 52149 13 CorrespodingPointsAlignment_False_10_3_False_10000_True 34286 37091 15 CorrespodingPointsAlignment_False_100_3_True_100_True 254991 258578 2 CorrespodingPointsAlignment_False_100_3_True_10000_True 257999 261285 2 CorrespodingPointsAlignment_False_100_3_False_100_True 247511 248693 3 CorrespodingPointsAlignment_False_100_3_False_10000_True 251807 263865 3 ``` Reviewed By: gkioxari Differential Revision: D19808389 fbshipit-source-id: 83305a58627d2fc5dcaf3c3015132d8148f28c29 --- pytorch3d/ops/__init__.py | 1 + pytorch3d/ops/points_alignment.py | 151 +++++++++++++ tests/bm_points_alignment.py | 40 ++++ tests/test_points_alignment.py | 358 ++++++++++++++++++++++++++++++ 4 files changed, 550 insertions(+) create mode 100644 pytorch3d/ops/points_alignment.py create mode 100644 tests/bm_points_alignment.py create mode 100644 tests/test_points_alignment.py diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py index de2eb61ba..7d84f6d65 100644 --- a/pytorch3d/ops/__init__.py +++ b/pytorch3d/ops/__init__.py @@ -6,6 +6,7 @@ from .mesh_face_areas_normals import mesh_face_areas_normals from .nearest_neighbor_points import nn_points_idx from .packed_to_padded import packed_to_padded, padded_to_packed +from .points_alignment import corresponding_points_alignment from .sample_points_from_meshes import sample_points_from_meshes from .subdivide_meshes import SubdivideMeshes from .vert_align import vert_align diff --git a/pytorch3d/ops/points_alignment.py b/pytorch3d/ops/points_alignment.py new file mode 100644 index 000000000..60d0afd0d --- /dev/null +++ b/pytorch3d/ops/points_alignment.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import warnings +from typing import Tuple, Union +import torch + +from pytorch3d.structures.pointclouds import Pointclouds + + +def corresponding_points_alignment( + X: Union[torch.Tensor, Pointclouds], + Y: Union[torch.Tensor, Pointclouds], + estimate_scale: bool = False, + allow_reflection: bool = False, + eps: float = 1e-8, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Finds a similarity transformation (rotation `R`, translation `T` + and optionally scale `s`) between two given sets of corresponding + `d`-dimensional points `X` and `Y` such that: + + `s[i] X[i] R[i] + T[i] = Y[i]`, + + for all batch indexes `i` in the least squares sense. + + The algorithm is also known as Umeyama [1]. + + Args: + X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)` + or a `Pointclouds` object. + Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)` + or a `Pointclouds` object. + estimate_scale: If `True`, also estimates a scaling component `s` + of the transformation. Otherwise assumes an identity + scale and returns a tensor of ones. + allow_reflection: If `True`, allows the algorithm to return `R` + which is orthonormal but has determinant==-1. + eps: A scalar for clamping to avoid dividing by zero. Active for the + code that estimates the output scale `s`. + + Returns: + 3-element tuple containing + - **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`. + - **T**: Batch of translations of shape `(minibatch, d)`. + - **s**: batch of scaling factors of shape `(minibatch, )`. + + References: + [1] Shinji Umeyama: Least-Suqares Estimation of + Transformation Parameters Between Two Point Patterns + """ + + # make sure we convert input Pointclouds structures to tensors + Xt, num_points = _convert_point_cloud_to_tensor(X) + Yt, num_points_Y = _convert_point_cloud_to_tensor(Y) + + if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any(): + raise ValueError( + "Point sets X and Y have to have the same \ + number of batches, points and dimensions." + ) + + b, n, dim = Xt.shape + + # compute the centroids of the point sets + Xmu = Xt.sum(1) / torch.clamp(num_points[:, None], 1) + Ymu = Yt.sum(1) / torch.clamp(num_points[:, None], 1) + + # mean-center the point sets + Xc = Xt - Xmu[:, None] + Yc = Yt - Ymu[:, None] + + if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any(): + # in case we got Pointclouds as input, mask the unused entries in Xc, Yc + mask = ( + torch.arange(n, dtype=torch.int64, device=Xc.device)[None] + < num_points[:, None] + ).type_as(Xc) + Xc *= mask[:, :, None] + Yc *= mask[:, :, None] + + if (num_points < (dim + 1)).any(): + warnings.warn( + "The size of one of the point clouds is <= dim+1. " + + "corresponding_points_alignment can't return a unique solution." + ) + + # compute the covariance XYcov between the point sets Xc, Yc + XYcov = torch.bmm(Xc.transpose(2, 1), Yc) + XYcov = XYcov / torch.clamp(num_points[:, None, None], 1) + + # decompose the covariance matrix XYcov + U, S, V = torch.svd(XYcov) + + # identity matrix used for fixing reflections + E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat( + b, 1, 1 + ) + + if not allow_reflection: + # reflection test: + # checks whether the estimated rotation has det==1, + # if not, finds the nearest rotation s.t. det==1 by + # flipping the sign of the last singular vector U + R_test = torch.bmm(U, V.transpose(2, 1)) + E[:, -1, -1] = torch.det(R_test) + + # find the rotation matrix by composing U and V again + R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1)) + + if estimate_scale: + # estimate the scaling component of the transformation + trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1) + Xcov = (Xc * Xc).sum((1, 2)) / torch.clamp(num_points, 1) + + # the scaling component + s = trace_ES / torch.clamp(Xcov, eps) + + # translation component + T = Ymu - s[:, None] * torch.bmm(Xmu[:, None], R)[:, 0, :] + + else: + # translation component + T = Ymu - torch.bmm(Xmu[:, None], R)[:, 0] + + # unit scaling since we do not estimate scale + s = T.new_ones(b) + + return R, T, s + + +def _convert_point_cloud_to_tensor(pcl: Union[torch.Tensor, Pointclouds]): + """ + If `type(pcl)==Pointclouds`, converts a `pcl` object to a + padded representation and returns it together with the number of points + per batch. Otherwise, returns the input itself with the number of points + set to the size of the second dimension of `pcl`. + """ + if isinstance(pcl, Pointclouds): + X = pcl.points_padded() + num_points = pcl.num_points_per_cloud() + elif torch.is_tensor(pcl): + X = pcl + num_points = X.shape[1] * torch.ones( + X.shape[0], device=X.device, dtype=torch.int64 + ) + else: + raise ValueError( + "The inputs X, Y should be either Pointclouds objects or tensors." + ) + return X, num_points diff --git a/tests/bm_points_alignment.py b/tests/bm_points_alignment.py new file mode 100644 index 000000000..f823b0cca --- /dev/null +++ b/tests/bm_points_alignment.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from copy import deepcopy +from itertools import product +from fvcore.common.benchmark import benchmark + +from test_points_alignment import TestCorrespondingPointsAlignment + + +def bm_corresponding_points_alignment() -> None: + + case_grid = { + "allow_reflection": [True, False], + "batch_size": [1, 10, 100], + "dim": [3, 20], + "estimate_scale": [True, False], + "n_points": [100, 10000], + "use_pointclouds": [False], + } + + test_args = sorted(case_grid.keys()) + test_cases = product(*[case_grid[k] for k in test_args]) + kwargs_list = [dict(zip(test_args, case)) for case in test_cases] + + # add the use_pointclouds=True test cases whenever we have dim==3 + kwargs_to_add = [] + for entry in kwargs_list: + if entry["dim"] == 3: + entry_add = deepcopy(entry) + entry_add["use_pointclouds"] = True + kwargs_to_add.append(entry_add) + kwargs_list.extend(kwargs_to_add) + + benchmark( + TestCorrespondingPointsAlignment.corresponding_points_alignment, + "CorrespodingPointsAlignment", + kwargs_list, + warmup_iters=1, + ) diff --git a/tests/test_points_alignment.py b/tests/test_points_alignment.py new file mode 100644 index 000000000..fc2a6d9f9 --- /dev/null +++ b/tests/test_points_alignment.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + + +import numpy as np +import unittest +import torch + +from pytorch3d.ops import points_alignment +from pytorch3d.structures.pointclouds import Pointclouds +from pytorch3d.transforms import rotation_conversions + + +def _apply_pcl_transformation(X, R, T, s=None): + """ + Apply a batch of similarity/rigid transformations, parametrized with + rotation `R`, translation `T` and scale `s`, to an input batch of + point clouds `X`. + """ + if isinstance(X, Pointclouds): + num_points = X.num_points_per_cloud() + X_t = X.points_padded() + else: + X_t = X + + if s is not None: + X_t = s[:, None, None] * X_t + + X_t = torch.bmm(X_t, R) + T[:, None, :] + + if isinstance(X, Pointclouds): + X_list = [x[:n_p] for x, n_p in zip(X_t, num_points)] + X_t = Pointclouds(X_list) + + return X_t + + +class TestCorrespondingPointsAlignment(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + np.random.seed(42) + + @staticmethod + def random_rotation(batch_size, dim, device=None): + """ + Generates a batch of random `dim`-dimensional rotation matrices. + """ + if dim == 3: + R = rotation_conversions.random_rotations(batch_size, device=device) + else: + # generate random rotation matrices with orthogonalization of + # random normal square matrices, followed by a transformation + # that ensures determinant(R)==1 + H = torch.randn( + batch_size, dim, dim, dtype=torch.float32, device=device + ) + U, _, V = torch.svd(H) + E = torch.eye(dim, dtype=torch.float32, device=device)[None].repeat( + batch_size, 1, 1 + ) + E[:, -1, -1] = torch.det(torch.bmm(U, V.transpose(2, 1))) + R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1)) + assert torch.allclose( + torch.det(R), R.new_ones(batch_size), atol=1e-4 + ) + + return R + + @staticmethod + def init_point_cloud( + batch_size=10, + n_points=1000, + dim=3, + device=None, + use_pointclouds=False, + random_pcl_size=True, + ): + """ + Generate a batch of normally distributed point clouds. + """ + if use_pointclouds: + assert dim == 3, "Pointclouds support only 3-dim points." + # generate a `batch_size` point clouds with number of points + # between 4 and `n_points` + if random_pcl_size: + n_points_per_batch = torch.randint( + low=4, + high=n_points, + size=(batch_size,), + device=device, + dtype=torch.int64, + ) + X_list = [ + torch.randn( + int(n_pt), dim, device=device, dtype=torch.float32 + ) + for n_pt in n_points_per_batch + ] + X = Pointclouds(X_list) + else: + X = torch.randn( + batch_size, + n_points, + dim, + device=device, + dtype=torch.float32, + ) + X = Pointclouds(list(X)) + else: + X = torch.randn( + batch_size, n_points, dim, device=device, dtype=torch.float32 + ) + return X + + @staticmethod + def generate_pcl_transformation( + batch_size=10, scale=False, reflect=False, dim=3, device=None + ): + """ + Generate a batch of random rigid/similarity transformations. + """ + R = TestCorrespondingPointsAlignment.random_rotation( + batch_size, dim, device=device + ) + T = torch.randn(batch_size, dim, dtype=torch.float32, device=device) + if scale: + s = torch.rand(batch_size, dtype=torch.float32, device=device) + 0.1 + else: + s = torch.ones(batch_size, dtype=torch.float32, device=device) + + return R, T, s + + @staticmethod + def generate_random_reflection(batch_size=10, dim=3, device=None): + """ + Generate a batch of reflection matrices of shape (batch_size, dim, dim), + where M_i is an identity matrix with one random entry on the + diagonal equal to -1. + """ + # randomly select one of the dimensions to reflect for each + # element in the batch + dim_to_reflect = torch.randint( + low=0, + high=dim, + size=(batch_size,), + device=device, + dtype=torch.int64, + ) + + # convert dim_to_reflect to a batch of reflection matrices M + M = torch.diag_embed( + ( + dim_to_reflect[:, None] + != torch.arange(dim, device=device, dtype=torch.float32) + ).float() + * 2 + - 1, + dim1=1, + dim2=2, + ) + + return M + + @staticmethod + def corresponding_points_alignment( + batch_size=10, + n_points=100, + dim=3, + use_pointclouds=False, + estimate_scale=False, + allow_reflection=False, + reflect=False, + ): + + device = torch.device("cuda:0") + + # initialize a ground truth point cloud + X = TestCorrespondingPointsAlignment.init_point_cloud( + batch_size=batch_size, + n_points=n_points, + dim=dim, + device=device, + use_pointclouds=use_pointclouds, + random_pcl_size=True, + ) + + # generate the true transformation + R, T, s = TestCorrespondingPointsAlignment.generate_pcl_transformation( + batch_size=batch_size, + scale=estimate_scale, + reflect=reflect, + dim=dim, + device=device, + ) + + # apply the generated transformation to the generated + # point cloud X + X_t = _apply_pcl_transformation(X, R, T, s=s) + + torch.cuda.synchronize() + + def run_corresponding_points_alignment(): + points_alignment.corresponding_points_alignment( + X, + X_t, + allow_reflection=allow_reflection, + estimate_scale=estimate_scale, + ) + torch.cuda.synchronize() + + return run_corresponding_points_alignment + + def test_corresponding_points_alignment(self, batch_size=10): + """ + Tests whether we can estimate a rigid/similarity motion between + a randomly initialized point cloud and its randomly transformed version. + + The tests are done for all possible combinations + of the following boolean flags: + - estimate_scale ... Estimate also a scaling component of + the transformation. + - reflect ... The ground truth orthonormal part of the generated + transformation is a reflection (det==-1). + - allow_reflection ... If True, the orthonormal matrix of the + estimated transformation is allowed to be + a reflection (det==-1). + - use_pointclouds ... If True, passes the Pointclouds objects + to corresponding_points_alignment. + """ + + # run this for several different point cloud sizes + for n_points in (100, 3, 2, 1, 0): + # run this for several different dimensionalities + for dim in torch.arange(2, 10): + # switches whether we should use the Pointclouds inputs + use_point_clouds_cases = ( + (True, False) if dim == 3 and n_points > 3 else (False,) + ) + for use_pointclouds in use_point_clouds_cases: + for estimate_scale in (False, True): + for reflect in (False, True): + for allow_reflection in (False, True): + self._test_single_corresponding_points_alignment( + batch_size=10, + n_points=n_points, + dim=int(dim), + use_pointclouds=use_pointclouds, + estimate_scale=estimate_scale, + reflect=reflect, + allow_reflection=allow_reflection, + ) + + def _test_single_corresponding_points_alignment( + self, + batch_size=10, + n_points=100, + dim=3, + use_pointclouds=False, + estimate_scale=False, + reflect=False, + allow_reflection=False, + ): + """ + Executes a single test for `corresponding_points_alignment` for a + specific setting of the inputs / outputs. + """ + + device = torch.device("cuda:0") + + # initialize the a ground truth point cloud + X = TestCorrespondingPointsAlignment.init_point_cloud( + batch_size=batch_size, + n_points=n_points, + dim=dim, + device=device, + use_pointclouds=use_pointclouds, + random_pcl_size=True, + ) + + # generate the true transformation + R, T, s = TestCorrespondingPointsAlignment.generate_pcl_transformation( + batch_size=batch_size, + scale=estimate_scale, + reflect=reflect, + dim=dim, + device=device, + ) + + if reflect: + # generate random reflection M and apply to the rotations + M = TestCorrespondingPointsAlignment.generate_random_reflection( + batch_size=batch_size, dim=dim, device=device + ) + R = torch.bmm(M, R) + + # apply the generated transformation to the generated + # point cloud X + X_t = _apply_pcl_transformation(X, R, T, s=s) + + # run the CorrespondingPointsAlignment algorithm + R_est, T_est, s_est = points_alignment.corresponding_points_alignment( + X, + X_t, + allow_reflection=allow_reflection, + estimate_scale=estimate_scale, + ) + + assert_error_message = ( + f"Corresponding_points_alignment assertion failure for " + f"n_points={n_points}, " + f"dim={dim}, " + f"use_pointclouds={use_pointclouds}, " + f"estimate_scale={estimate_scale}, " + f"reflect={reflect}, " + f"allow_reflection={allow_reflection}." + ) + + if reflect and not allow_reflection: + # check that all rotations have det=1 + self._assert_all_close( + torch.det(R_est), + R_est.new_ones(batch_size), + assert_error_message, + ) + + else: + # check that the estimated tranformation is the same + # as the ground truth + if n_points >= (dim + 1): + # the checks on transforms apply only when + # the problem setup is unambiguous + self._assert_all_close(R_est, R, assert_error_message) + self._assert_all_close(T_est, T, assert_error_message) + self._assert_all_close(s_est, s, assert_error_message) + + # check that the orthonormal part of the + # transformation has a correct determinant (+1/-1) + desired_det = R_est.new_ones(batch_size) + if reflect: + desired_det *= -1.0 + self._assert_all_close( + torch.det(R_est), desired_det, assert_error_message + ) + + # check that the transformed point cloud + # X matches X_t + X_t_est = _apply_pcl_transformation(X, R_est, T_est, s=s_est) + self._assert_all_close( + X_t, X_t_est, assert_error_message, atol=1e-5 + ) + + def _assert_all_close(self, a_, b_, err_message, atol=1e-6): + if isinstance(a_, Pointclouds): + a_ = a_.points_packed() + if isinstance(b_, Pointclouds): + b_ = b_.points_packed() + self.assertTrue(torch.allclose(a_, b_, atol=atol), err_message)