diff --git a/docs/zh/api/geometry.md b/docs/zh/api/geometry.md index fdd5b2af1..ab85f87c6 100644 --- a/docs/zh/api/geometry.md +++ b/docs/zh/api/geometry.md @@ -11,6 +11,7 @@ - Hypersphere - Interval - Mesh + - SDFMesh - PointCloud - Polygon - Rectangle diff --git a/docs/zh/install_setup.md b/docs/zh/install_setup.md index ad1bdbff6..d79ea887a 100644 --- a/docs/zh/install_setup.md +++ b/docs/zh/install_setup.md @@ -111,10 +111,17 @@ #### 1.4.2 安装额外功能[可选] -如需使用 `.obj`, `.ply`, `.off`, `.stl`, `.mesh`, `.node`, `.poly` and `.msh` 等复杂几何文件构建几何(计算域),以及使用加密采样等功能,则需按照下方给出的命令,安装 open3d、 -pybind11、pysdf、PyMesh 四个依赖库(上述**1.1 从 docker 镜像启动**中已安装上述依赖库)。 +PaddleScience 提供了两种复杂几何类型,如下所示: -否则无法使用 `ppsci.geometry.Mesh` 等基于复杂几何文件的 API,因此也无法运行如 [Aneurysm](./examples/aneurysm.md) 等依赖 `ppsci.geometry.Mesh` API 的复杂案例。 +| API 名称 | 支持文件类型 | 安装方式 | 使用方式 | +| -- | -- | -- | -- | +|[`ppsci.geometry.Mesh`](./api/geometry.md#ppsci.geometry.Mesh) | `.obj`, `.ply`, `.off`, `.stl`, `.mesh`, `.node`, `.poly` and `.msh`| 参考下方的 "PyMesh 安装命令"| `ppsci.geometry.Mesh(mesh_path)`| +|[`ppsci.geometry.SDFMesh`](./api/geometry.md#ppsci.geometry.SDFMesh "实验性功能") | `.stl` | `pip install warp-lang 'numpy-stl>=2.16,<2.17'` | `ppsci.geometry.SDFMesh.from_stl(stl_path)` | + +!!! warning "相关案例运行说明" + + [Bracket](./examples/aneurysm.md)、[Aneurysm](./examples/aneurysm.md) 等个别案例使用了 `ppsci.geometry.Mesh` 接口构建复杂几何,因此这些案例运行前需要按照下方给出的命令,安装 open3d、 + pybind11、pysdf、PyMesh 四个依赖库(上述**1.1 从 docker 镜像启动**中已安装上述依赖库)。如使用 `ppsci.geometry.SDFMesh` 接口构建复杂几何,则只需要安装 `warp-lang` 即可。 === "open3d 安装命令" diff --git a/mkdocs.yml b/mkdocs.yml index 7bcf39ce0..4d5d753b0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -169,6 +169,7 @@ theme: - content.tabs.link - content.action.edit - content.action.view + - content.tooltips icon: edit: material/pencil view: material/eye @@ -209,6 +210,7 @@ plugins: # Extensions markdown_extensions: + - abbr - pymdownx.critic - pymdownx.highlight: anchor_linenums: true diff --git a/ppsci/geometry/__init__.py b/ppsci/geometry/__init__.py index 768ed0581..30b4ad085 100644 --- a/ppsci/geometry/__init__.py +++ b/ppsci/geometry/__init__.py @@ -25,6 +25,7 @@ from ppsci.geometry.geometry_nd import Hypercube from ppsci.geometry.geometry_nd import Hypersphere from ppsci.geometry.mesh import Mesh +from ppsci.geometry.mesh import SDFMesh from ppsci.geometry.pointcloud import PointCloud from ppsci.geometry.timedomain import TimeDomain from ppsci.geometry.timedomain import TimeXGeometry @@ -40,6 +41,7 @@ "Hypersphere", "Interval", "Mesh", + "SDFMesh", "Polygon", "Rectangle", "Sphere", diff --git a/ppsci/geometry/mesh.py b/ppsci/geometry/mesh.py index 5a5d1d02d..d2227cf15 100644 --- a/ppsci/geometry/mesh.py +++ b/ppsci/geometry/mesh.py @@ -23,11 +23,20 @@ import numpy as np import paddle + +try: + from stl import mesh as np_mesh_module +except ModuleNotFoundError: + pass +except ImportError: + pass + from typing_extensions import Literal from ppsci.geometry import geometry from ppsci.geometry import geometry_3d from ppsci.geometry import sampler +from ppsci.geometry import sdf as sdf_module from ppsci.utils import checker from ppsci.utils import misc @@ -184,7 +193,7 @@ def is_inside(self, x): return self.pysdf.contains(x) def on_boundary(self, x): - return np.isclose(self.sdf_func(x), 0.0).flatten() + return np.isclose(self.sdf_func(x), 0.0).ravel() def translate(self, translation: np.ndarray, relative: bool = True) -> "Mesh": """Translate by given offsets. @@ -365,7 +374,7 @@ def _approximate_area( n_appr (int): Number of points for approximating area. Defaults to 10000. Returns: - np.ndarray: Approximated areas with shape of [n_faces, ]. + float: Approximation area with given criteria. """ triangle_areas = area_of_triangles(self.v0, self.v1, self.v2) triangle_probabilities = triangle_areas / np.linalg.norm(triangle_areas, ord=1) @@ -498,7 +507,7 @@ def sample_boundary( if criteria is not None: criteria_mask = criteria( *np.split(points, self.ndim, axis=1) - ).flatten() + ).ravel() points = points[criteria_mask] normal = normal[criteria_mask] @@ -547,7 +556,7 @@ def random_points(self, n, random="pseudo", criteria=None): if criteria: valid_mask &= criteria( *np.split(random_points, self.ndim, axis=1) - ).flatten() + ).ravel() valid_points = random_points[valid_mask] _nvalid += len(valid_points) @@ -662,6 +671,594 @@ def __str__(self) -> str: ) +class SDFMesh(geometry.Geometry): + """Class for SDF geometry, a kind of implicit surface mesh. + + Args: + vectors (np.ndarray): Vectors of triangles of mesh with shape [M, 3, 3]. + normals (np.ndarray): Unit normals of each triangle face with shape [M, 3]. + sdf_func (Callable[[np.ndarray, bool], np.ndarray]): Signed distance function + of the triangle mesh. + + Examples: + >>> import ppsci + >>> geom = ppsci.geometry.SDFMesh.from_stl("/path/to/mesh.stl") # doctest: +SKIP + """ + + eps = 1e-6 + + def __init__( + self, + vectors: np.ndarray, + normals: np.ndarray, + sdf_func: Callable[[np.ndarray, bool], np.ndarray], + ): + if vectors.shape[1:] != (3, 3): + raise ValueError( + f"The shape of `vectors` must be [M, 3, 3], but got {vectors.shape}" + ) + if normals.shape[1] != 3: + raise ValueError( + f"The shape of `normals` must be [M, 3], but got {normals.shape}" + ) + self.vectors = vectors + self.face_normal = normals + self.sdf_func = sdf_func # overwrite sdf_func + self.bounds = ( + ((np.min(self.vectors[:, :, 0])), np.max(self.vectors[:, :, 0])), + ((np.min(self.vectors[:, :, 1])), np.max(self.vectors[:, :, 1])), + ((np.min(self.vectors[:, :, 2])), np.max(self.vectors[:, :, 2])), + ) + self.ndim = 3 + super().__init__( + self.vectors.shape[-1], + (np.amin(self.vectors, axis=(0, 1)), np.amax(self.vectors, axis=(0, 1))), + np.inf, + ) + + @property + def v0(self) -> np.ndarray: + return self.vectors[:, 0] + + @property + def v1(self) -> np.ndarray: + return self.vectors[:, 1] + + @property + def v2(self) -> np.ndarray: + return self.vectors[:, 2] + + @classmethod + def from_stl(cls, mesh_file: str) -> "SDFMesh": + """Instantiate SDFMesh from given mesh file. + + Args: + mesh_file (str): Path to triangle mesh file. + + Returns: + SDFMesh: Instantiated ppsci.geometry.SDFMesh object. + + Examples: + >>> import ppsci + >>> import pymesh # doctest: +SKIP + >>> import numpy as np # doctest: +SKIP + >>> box = pymesh.generate_box_mesh(np.array([0, 0, 0]), np.array([1, 1, 1])) # doctest: +SKIP + >>> pymesh.save_mesh("box.stl", box) # doctest: +SKIP + >>> mesh = ppsci.geometry.SDFMesh.from_stl("box.stl") # doctest: +SKIP + >>> print(sdfmesh.vectors.shape) # doctest: +SKIP + (12, 3, 3) + """ + # check if pymesh is installed when using Mesh Class + if not checker.dynamic_import_to_globals(["stl"]): + raise ImportError( + "Could not import stl python package. " + "Please install numpy-stl with: pip install 'numpy-stl>=2.16,<2.17'" + ) + + np_mesh_obj = np_mesh_module.Mesh.from_file(mesh_file) + return cls( + np_mesh_obj.vectors, + np_mesh_obj.get_unit_normals(), + make_sdf(np_mesh_obj.vectors), + ) + + def sdf_func( + self, points: np.ndarray, compute_sdf_derivatives: bool = False + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """Compute signed distance field. + + Args: + points (np.ndarray): The coordinate points used to calculate the SDF value, + the shape is [N, 3] + compute_sdf_derivatives (bool): Whether to compute SDF derivatives. + Defaults to False. + + Returns: + Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + If compute_sdf_derivatives is True, then return both SDF values([N, 1]) + and their derivatives([N, 3]); otherwise only return SDF values([N, 1]). + + NOTE: This function usually returns ndarray with negative values, because + according to the definition of SDF, the SDF value of the coordinate point inside + the object(interior points) is negative, the outside is positive, and the edge + is 0. Therefore, when used for weighting, a negative sign is often added before + the result of this function. + """ + # normalize triangles + x_min, y_min, z_min = np.min(points, axis=0) + x_max, y_max, z_max = np.max(points, axis=0) + max_dis = max(max((x_max - x_min), (y_max - y_min)), (z_max - z_min)) + store_triangles = np.array(self.vectors, dtype=np.float64) + store_triangles[:, :, 0] -= x_min + store_triangles[:, :, 1] -= y_min + store_triangles[:, :, 2] -= z_min + store_triangles *= 1 / max_dis + store_triangles = store_triangles.reshape([-1, 3]) + + # normalize query points + points = points.copy() + points[:, 0] -= x_min + points[:, 1] -= y_min + points[:, 2] -= z_min + points *= 1 / max_dis + points = points.astype(np.float64).ravel() + + # compute sdf values for query points + sdf = sdf_module.signed_distance_field( + store_triangles, + np.arange((store_triangles.shape[0])), + points, + include_hit_points=compute_sdf_derivatives, + ) + if compute_sdf_derivatives: + sdf, hit_points = sdf + + sdf = sdf.numpy() # [N] + sdf = np.expand_dims(max_dis * sdf, axis=1) # [N, 1] + + if compute_sdf_derivatives: + hit_points = hit_points.numpy() # [N, 3] + # Gradient of SDF is the unit vector from the query point to the hit point. + sdf_derives = hit_points - points + sdf_derives /= np.linalg.norm(sdf_derives, axis=1, keepdims=True) + return sdf, sdf_derives + + return sdf + + def is_inside(self, x): + # NOTE: point on boundary is included + return np.less(self.sdf_func(x), 0.0).ravel() + + def on_boundary(self, x: np.ndarray, normal: np.ndarray) -> np.ndarray: + x_plus = x + self.eps * normal + x_minus = x - self.eps * normal + + sdf_x_plus = self.sdf_func(x_plus) + sdf_x_minus = self.sdf_func(x_minus) + mask_on_boundary = np.less_equal(sdf_x_plus * sdf_x_minus, 0) + return mask_on_boundary.ravel() + + def translate(self, translation: np.ndarray) -> "SDFMesh": + """Translate by given offsets. + + NOTE: This API generate a completely new Mesh object with translated geometry, + without modifying original Mesh object inplace. + + Args: + translation (np.ndarray): Translation offsets, numpy array of shape (3,): + [offset_x, offset_y, offset_z]. + + Returns: + Mesh: Translated Mesh object. + + Examples: + >>> import ppsci + >>> import pymesh # doctest: +SKIP + >>> mesh = ppsci.geometry.SDFMesh.from_stl('/path/to/mesh.stl') # doctest: +SKIP + >>> mesh = mesh.translate(np.array([1, -1, 2])) # doctest: +SKIP + """ + new_vectors = self.vectors + translation.reshape([1, 1, 3]) + + return SDFMesh( + new_vectors, + self.face_normal, + make_sdf(new_vectors), + ) + + def scale(self, scale: float) -> "SDFMesh": + """Scale by given scale coefficient and center coordinate. + + NOTE: This API generate a completely new Mesh object with scaled geometry, + without modifying original Mesh object inplace. + + Args: + scale (float): Scale coefficient. + + Returns: + Mesh: Scaled Mesh object. + + Examples: + >>> import ppsci + >>> import pymesh # doctest: +SKIP + >>> mesh = ppsci.geometry.SDFMesh.from_stl('/path/to/mesh.stl') # doctest: +SKIP + >>> mesh = mesh.scale(np.array([1.3, 1.5, 2.0])) # doctest: +SKIP + """ + new_vectors = self.vectors * scale + return SDFMesh( + new_vectors, + self.face_normal, + make_sdf(new_vectors), + ) + + def uniform_boundary_points(self, n: int): + """Compute the equi-spaced points on the boundary.""" + raise NotImplementedError( + "'uniform_boundary_points' is not available in SDFMesh." + ) + + def inflated_random_points(self, n, distance, random="pseudo", criteria=None): + raise NotImplementedError( + "'inflated_random_points' is not available in SDFMesh." + ) + + def _approximate_area( + self, + random: Literal["pseudo"] = "pseudo", + criteria: Optional[Callable] = None, + n_appr: int = 10000, + ) -> float: + """Approximate area with given `criteria` and `n_appr` points by Monte Carlo + algorithm. + + Args: + random (str, optional): Random method. Defaults to "pseudo". + criteria (Optional[Callable]): Criteria function. Defaults to None. + n_appr (int): Number of points for approximating area. Defaults to 10000. + + Returns: + float: Approximation area with given criteria. + """ + triangle_areas = area_of_triangles(self.v0, self.v1, self.v2) + triangle_probabilities = triangle_areas / np.linalg.norm(triangle_areas, ord=1) + triangle_index = np.arange(triangle_probabilities.shape[0]) + npoint_per_triangle = np.random.choice( + triangle_index, n_appr, p=triangle_probabilities + ) + npoint_per_triangle, _ = np.histogram( + npoint_per_triangle, + np.arange(triangle_probabilities.shape[0] + 1) - 0.5, + ) + + aux_points = [] + aux_normals = [] + appr_areas = [] + + for i, npoint in enumerate(npoint_per_triangle): + if npoint == 0: + continue + # sample points for computing criteria mask if criteria is given + points_at_triangle_i = sample_in_triangle( + self.v0[i], self.v1[i], self.v2[i], npoint, random + ) + normal_at_triangle_i = np.tile( + self.face_normal[i].reshape(1, 3), (npoint, 1) + ) + aux_points.append(points_at_triangle_i) + aux_normals.append(normal_at_triangle_i) + appr_areas.append( + np.full( + (npoint, 1), triangle_areas[i] / npoint, paddle.get_default_dtype() + ) + ) + + aux_points = np.concatenate(aux_points, axis=0) # [n_appr, 3] + aux_normals = np.concatenate(aux_normals, axis=0) # [n_appr, 3] + appr_areas = np.concatenate(appr_areas, axis=0) # [n_appr, 1] + valid_mask = self.on_boundary(aux_points, aux_normals)[:, None] + # set invalid area to 0 by computing criteria mask with auxiliary points + if criteria is not None: + criteria_mask = criteria(*np.split(aux_points, self.ndim, 1)) + assert valid_mask.shape == criteria_mask.shape + valid_mask = np.logical_and(valid_mask, criteria_mask) + + appr_areas *= valid_mask + + return appr_areas.sum() + + def random_boundary_points( + self, n, random="pseudo" + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + triangle_area = area_of_triangles(self.v0, self.v1, self.v2) + triangle_prob = triangle_area / np.linalg.norm(triangle_area, ord=1) + npoint_per_triangle = np.random.choice( + np.arange(len(triangle_prob)), n, p=triangle_prob + ) + npoint_per_triangle, _ = np.histogram( + npoint_per_triangle, np.arange(len(triangle_prob) + 1) - 0.5 + ) + + points = [] + normal = [] + areas = [] + for i, npoint in enumerate(npoint_per_triangle): + if npoint == 0: + continue + points_at_triangle_i = sample_in_triangle( + self.v0[i], self.v1[i], self.v2[i], npoint, random + ) + normal_at_triangle_i = np.tile(self.face_normal[i], (npoint, 1)).astype( + paddle.get_default_dtype() + ) + areas_at_triangle_i = np.full( + (npoint, 1), + triangle_area[i] / npoint, + dtype=paddle.get_default_dtype(), + ) + + points.append(points_at_triangle_i) + normal.append(normal_at_triangle_i) + areas.append(areas_at_triangle_i) + + points = np.concatenate(points, axis=0) + normal = np.concatenate(normal, axis=0) + areas = np.concatenate(areas, axis=0) + + return points, normal, areas + + def sample_boundary( + self, + n: int, + random: Literal["pseudo"] = "pseudo", + criteria: Optional[Callable[..., np.ndarray]] = None, + evenly: bool = False, + inflation_dist: Union[float, Tuple[float, ...]] = None, + ) -> Dict[str, np.ndarray]: + # TODO(sensen): Support for time-dependent points(repeat data in time) + if inflation_dist is not None: + raise NotImplementedError("Not implemented yet") + else: + if evenly: + raise ValueError( + "Can't sample evenly on mesh now, please set evenly=False." + ) + _size, _ntry, _nsuc = 0, 0, 0 + all_points = [] + all_normal = [] + while _size < n: + points, normal, _ = self.random_boundary_points(n, random) + valid_mask = self.on_boundary(points, normal) + + if criteria is not None: + criteria_mask = criteria( + *np.split(points, self.ndim, axis=1) + ).ravel() + assert valid_mask.shape == criteria_mask.shape + valid_mask = np.logical_and(valid_mask, criteria_mask) + + points = points[valid_mask] + normal = normal[valid_mask] + + if len(points) > n - _size: + points = points[: n - _size] + normal = normal[: n - _size] + + all_points.append(points) + all_normal.append(normal) + + _size += len(points) + _ntry += 1 + if len(points) > 0: + _nsuc += 1 + + if _ntry >= 1000 and _nsuc == 0: + raise ValueError( + "Sample boundary points failed, " + "please check correctness of geometry and given criteria." + ) + + all_points = np.concatenate(all_points, axis=0) + all_normal = np.concatenate(all_normal, axis=0) + _appr_area = self._approximate_area(random, criteria) + all_areas = np.full((n, 1), _appr_area / n, paddle.get_default_dtype()) + + x_dict = misc.convert_to_dict(all_points, self.dim_keys) + normal_dict = misc.convert_to_dict( + all_normal, [f"normal_{key}" for key in self.dim_keys if key != "t"] + ) + area_dict = misc.convert_to_dict(all_areas, ["area"]) + return {**x_dict, **normal_dict, **area_dict} + + def random_points(self, n, random="pseudo", criteria=None): + _size = 0 + all_points = [] + cuboid = geometry_3d.Cuboid( + [bound[0] for bound in self.bounds], + [bound[1] for bound in self.bounds], + ) + _nsample, _nvalid = 0, 0 + while _size < n: + random_points = cuboid.random_points(n, random) + valid_mask = self.is_inside(random_points) + + if criteria: + criteria_mask = criteria( + *np.split(random_points, self.ndim, axis=1) + ).ravel() + assert valid_mask.shape == criteria_mask.shape + valid_mask = np.logical_and(valid_mask, criteria_mask) + + valid_points = random_points[valid_mask] + _nvalid += len(valid_points) + + if len(valid_points) > n - _size: + valid_points = valid_points[: n - _size] + + all_points.append(valid_points) + _size += len(valid_points) + _nsample += n + + all_points = np.concatenate(all_points, axis=0) + cuboid_volume = np.prod([b[1] - b[0] for b in self.bounds]) + all_areas = np.full( + (n, 1), cuboid_volume * (_nvalid / _nsample) / n, paddle.get_default_dtype() + ) + return all_points, all_areas + + def sample_interior( + self, + n: int, + random: Literal["pseudo"] = "pseudo", + criteria: Optional[Callable[..., np.ndarray]] = None, + compute_sdf_derivatives: bool = False, + ): + """Sample random points in the geometry and return those meet criteria.""" + points, areas = self.random_points(n, random, criteria) + + x_dict = misc.convert_to_dict(points, self.dim_keys) + area_dict = misc.convert_to_dict(areas, ("area",)) + + sdf = self.sdf_func(points, compute_sdf_derivatives) + if compute_sdf_derivatives: + sdf, sdf_derives = sdf + + # NOTE: Negate sdf because weight should be positive. + sdf_dict = misc.convert_to_dict(-sdf, ("sdf",)) + + sdf_derives_dict = {} + if compute_sdf_derivatives: + # NOTE: Negate sdf derivatives + sdf_derives_dict = misc.convert_to_dict( + -sdf_derives, tuple(f"sdf__{key}" for key in self.dim_keys) + ) + + return {**x_dict, **area_dict, **sdf_dict, **sdf_derives_dict} + + def union(self, other: "SDFMesh"): + new_vectors = np.concatenate([self.vectors, other.vectors], axis=0) + new_normals = np.concatenate([self.face_normal, other.face_normal], axis=0) + + def make_union_new_sdf(sdf_func1, sdf_func2): + def new_sdf_func(points: np.ndarray, compute_sdf_derivatives: bool = False): + # Invert definition of sdf to make boolean operation accurate + # see: https://iquilezles.org/articles/interiordistance/ + sdf_self = sdf_func1(points, compute_sdf_derivatives) + sdf_other = sdf_func2(points, compute_sdf_derivatives) + if compute_sdf_derivatives: + sdf_self, sdf_derives_self = sdf_self + sdf_other, sdf_derives_other = sdf_other + + computed_sdf = -np.maximum(-sdf_self, -sdf_other) + + if compute_sdf_derivatives: + computed_sdf_derives = -np.where( + sdf_self < sdf_other, + sdf_derives_self, + sdf_derives_other, + ) + return computed_sdf, computed_sdf_derives + + return computed_sdf + + return new_sdf_func + + return SDFMesh( + new_vectors, + new_normals, + make_union_new_sdf(self.sdf_func, other.sdf_func), + ) + + def __or__(self, other: "SDFMesh"): + return self.union(other) + + def __add__(self, other: "SDFMesh"): + return self.union(other) + + def difference(self, other: "SDFMesh"): + new_vectors = np.concatenate([self.vectors, other.vectors], axis=0) + new_normals = np.concatenate([self.face_normal, -other.face_normal], axis=0) + + def make_difference_new_sdf(sdf_func1, sdf_func2): + def new_sdf_func(points: np.ndarray, compute_sdf_derivatives: bool = False): + # Invert definition of sdf to make boolean operation accurate + # see: https://iquilezles.org/articles/interiordistance/ + sdf_self = sdf_func1(points, compute_sdf_derivatives) + sdf_other = sdf_func2(points, compute_sdf_derivatives) + if compute_sdf_derivatives: + sdf_self, sdf_derives_self = sdf_self + sdf_other, sdf_derives_other = sdf_other + + computed_sdf = -np.minimum(-sdf_self, sdf_other) + + if compute_sdf_derivatives: + computed_sdf_derives = np.where( + -sdf_self < sdf_other, + -sdf_derives_self, + sdf_derives_other, + ) + return computed_sdf, computed_sdf_derives + + return computed_sdf + + return new_sdf_func + + return SDFMesh( + new_vectors, + new_normals, + make_difference_new_sdf(self.sdf_func, other.sdf_func), + ) + + def __sub__(self, other: "SDFMesh"): + return self.difference(other) + + def intersection(self, other: "SDFMesh"): + new_vectors = np.concatenate([self.vectors, other.vectors], axis=0) + new_normals = np.concatenate([self.face_normal, other.face_normal], axis=0) + + def make_intersection_new_sdf(sdf_func1, sdf_func2): + def new_sdf_func(points: np.ndarray, compute_sdf_derivatives: bool = False): + # Invert definition of sdf to make boolean operation accurate + # see: https://iquilezles.org/articles/interiordistance/ + sdf_self = sdf_func1(points, compute_sdf_derivatives) + sdf_other = sdf_func2(points, compute_sdf_derivatives) + if compute_sdf_derivatives: + sdf_self, sdf_derives_self = sdf_self + sdf_other, sdf_derives_other = sdf_other + + computed_sdf = -np.minimum(-sdf_self, -sdf_other) + + if compute_sdf_derivatives: + computed_sdf_derives = np.where( + sdf_self > sdf_other, + -sdf_derives_self, + -sdf_derives_other, + ) + return computed_sdf, computed_sdf_derives + + return computed_sdf + + return new_sdf_func + + return SDFMesh( + new_vectors, + new_normals, + make_intersection_new_sdf(self.sdf_func, other.sdf_func), + ) + + def __and__(self, other: "SDFMesh"): + return self.intersection(other) + + def __str__(self) -> str: + """Return the name of class""" + return ", ".join( + [ + self.__class__.__name__, + f"num_faces = {self.vectors.shape[0]}", + f"bounds = {self.bounds}", + f"dim_keys = {self.dim_keys}", + ] + ) + + def area_of_triangles(v0, v1, v2): """Ref https://math.stackexchange.com/questions/128991/how-to-calculate-the-area-of-a-3d-triangle @@ -713,15 +1310,15 @@ def sample_in_triangle(v0, v1, v2, n, random="pseudo", criteria=None): xs, ys, zs = [], [], [] _size = 0 while _size < n: - r1 = sampler.sample(n, 1, random).flatten() - r2 = sampler.sample(n, 1, random).flatten() + r1 = sampler.sample(n, 1, random).ravel() + r2 = sampler.sample(n, 1, random).ravel() s1 = np.sqrt(r1) x = v0[0] * (1.0 - s1) + v1[0] * (1.0 - r2) * s1 + v2[0] * r2 * s1 y = v0[1] * (1.0 - s1) + v1[1] * (1.0 - r2) * s1 + v2[1] * r2 * s1 z = v0[2] * (1.0 - s1) + v1[2] * (1.0 - r2) * s1 + v2[2] * r2 * s1 if criteria is not None: - criteria_mask = criteria(x, y, z).flatten() + criteria_mask = criteria(x, y, z).ravel() x = x[criteria_mask] y = y[criteria_mask] z = z[criteria_mask] @@ -741,3 +1338,48 @@ def sample_in_triangle(v0, v1, v2, n, random="pseudo", criteria=None): zs = np.concatenate(zs, axis=0) return np.stack([xs, ys, zs], axis=1) + + +def make_sdf(vectors: np.ndarray): + def sdf_func(points: np.ndarray, compute_sdf_derivatives=False): + points = points.copy() + x_min, y_min, z_min = np.min(points, axis=0) + x_max, y_max, z_max = np.max(points, axis=0) + max_dis = max(max((x_max - x_min), (y_max - y_min)), (z_max - z_min)) + store_triangles = vectors.copy() + store_triangles[:, :, 0] -= x_min + store_triangles[:, :, 1] -= y_min + store_triangles[:, :, 2] -= z_min + store_triangles *= 1 / max_dis + store_triangles = store_triangles.reshape([-1, 3]) + points[:, 0] -= x_min + points[:, 1] -= y_min + points[:, 2] -= z_min + points *= 1 / max_dis + points = points.astype(np.float64).ravel() + + # compute sdf values + sdf = sdf_module.signed_distance_field( + store_triangles, + np.arange((store_triangles.shape[0])), + points, + include_hit_points=compute_sdf_derivatives, + ) + if compute_sdf_derivatives: + sdf, sdf_derives = sdf + + sdf = sdf.numpy() + sdf = np.expand_dims(max_dis * sdf, axis=1) + + if compute_sdf_derivatives: + sdf_derives = sdf_derives.numpy().reshape(-1) + sdf_derives = -(sdf_derives - points) + sdf_derives = np.reshape(sdf_derives, (sdf_derives.shape[0] // 3, 3)) + sdf_derives = sdf_derives / np.linalg.norm( + sdf_derives, axis=1, keepdims=True + ) + return sdf, sdf_derives + + return sdf + + return sdf_func diff --git a/ppsci/geometry/sdf.py b/ppsci/geometry/sdf.py new file mode 100644 index 000000000..bb3e26054 --- /dev/null +++ b/ppsci/geometry/sdf.py @@ -0,0 +1,198 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F401 + +# modified from: https://github.com/NVIDIA/modulus/blob/main/modulus/utils/sdf.py + +from __future__ import annotations + +import importlib.util +from typing import Tuple +from typing import overload + +from numpy import ndarray + +try: + import warp as wp + + @wp.kernel + def _bvh_query_distance( + mesh: wp.uint64, + points: wp.array(dtype=wp.vec3f), + max_dist: wp.float32, + sdf: wp.array(dtype=wp.float32), + sdf_hit_point: wp.array(dtype=wp.vec3f), + sdf_hit_point_id: wp.array(dtype=wp.int32), + ): + + """ + Computes the signed distance from each point in the given array `points` + to the mesh represented by `mesh`,within the maximum distance `max_dist`, + and stores the result in the array `sdf`. + + Parameters: + mesh (wp.uint64): The identifier of the mesh. + points (wp.array): An array of 3D points for which to compute the + signed distance. + max_dist (wp.float32): The maximum distance within which to search + for the closest point on the mesh. + sdf (wp.array): An array to store the computed signed distances. + sdf_hit_point (wp.array): An array to store the computed hit points. + sdf_hit_point_id (wp.array): An array to store the computed hit point ids. + + Returns: + None + """ + tid = wp.tid() + + res = wp.mesh_query_point_sign_normal(mesh, points[tid], max_dist) + + mesh_ = wp.mesh_get(mesh) + + p0 = mesh_.points[mesh_.indices[3 * res.face + 0]] + p1 = mesh_.points[mesh_.indices[3 * res.face + 1]] + p2 = mesh_.points[mesh_.indices[3 * res.face + 2]] + + p_closest = res.u * p0 + res.v * p1 + (1.0 - res.u - res.v) * p2 + + sdf[tid] = res.sign * wp.abs(wp.length(points[tid] - p_closest)) + sdf_hit_point[tid] = p_closest + sdf_hit_point_id[tid] = res.face + +except ModuleNotFoundError: + pass +except Exception: + raise + + +@overload +def signed_distance_field( + mesh_vertices: list[tuple[float, float, float]], + mesh_indices: ndarray, + input_points: list[tuple[float, float, float]], + max_dist: float = 1e8, + include_hit_points: bool = False, + include_hit_points_id: bool = False, +) -> wp.array: + ... + + +@overload +def signed_distance_field( + mesh_vertices: list[tuple[float, float, float]], + mesh_indices: ndarray, + input_points: list[tuple[float, float, float]], + max_dist: float = 1e8, + include_hit_points: bool = True, + include_hit_points_id: bool = False, +) -> Tuple[wp.array, wp.array]: + ... + + +@overload +def signed_distance_field( + mesh_vertices: list[tuple[float, float, float]], + mesh_indices: ndarray, + input_points: list[tuple[float, float, float]], + max_dist: float = 1e8, + include_hit_points: bool = False, + include_hit_points_id: bool = True, +) -> Tuple[wp.array, wp.array]: + ... + + +@overload +def signed_distance_field( + mesh_vertices: list[tuple[float, float, float]], + mesh_indices: ndarray, + input_points: list[tuple[float, float, float]], + max_dist: float = 1e8, + include_hit_points: bool = True, + include_hit_points_id: bool = True, +) -> Tuple[wp.array, wp.array, wp.array]: + ... + + +def signed_distance_field( + mesh_vertices: list[tuple[float, float, float]], + mesh_indices: ndarray, + input_points: list[tuple[float, float, float]], + max_dist: float = 1e8, + include_hit_points: bool = False, + include_hit_points_id: bool = False, +) -> wp.array: + """ + Computes the signed distance field (SDF) for a given mesh and input points. + + Args: + mesh_vertices (list[tuple[float, float, float]]): List of vertices defining the mesh. + mesh_indices (list[tuple[int, int, int]]): List of indices defining the triangles of the mesh. + input_points (list[tuple[float, float, float]]): List of input points for which to compute the SDF. + max_dist (float, optional): Maximum distance within which to search for + the closest point on the mesh. Default is 1e8. + include_hit_points (bool, optional): Whether to include hit points in + the output. Default is False. + include_hit_points_id (bool, optional): Whether to include hit point + IDs in the output. Default is False. + + Returns: + wp.array: An array containing the computed signed distance field. + + Example: + >>> mesh_vertices = [(0, 0, 0), (1, 0, 0), (0, 1, 0)] + >>> mesh_indices = np.array((0, 1, 2)) + >>> input_points = [(0.5, 0.5, 0.5)] + >>> signed_distance_field(mesh_vertices, mesh_indices, input_points).numpy() + Module modulus.utils.sdf load on device 'cuda:0' took ... + array([0.5], dtype=float32) + """ + if not importlib.util.find_spec("warp"): + raise ModuleNotFoundError("Please install warp with: pip install warp-lang") + + wp.init() + mesh = wp.Mesh( + wp.array(mesh_vertices, dtype=wp.vec3), + wp.array(mesh_indices, dtype=wp.int32), + ) + + sdf_points = wp.array(input_points, dtype=wp.vec3) + + sdf = wp.zeros(shape=sdf_points.shape, dtype=wp.float32) + sdf_hit_point = wp.zeros(shape=sdf_points.shape, dtype=wp.vec3f) + sdf_hit_point_id = wp.zeros(shape=sdf_points.shape, dtype=wp.int32) + + wp.launch( + kernel=_bvh_query_distance, + dim=len(sdf_points), + inputs=[ + mesh.id, + sdf_points, + max_dist, + sdf, + sdf_hit_point, + sdf_hit_point_id, + ], + ) + + if include_hit_points and include_hit_points_id: + return (sdf, sdf_hit_point, sdf_hit_point_id) + elif include_hit_points: + return (sdf, sdf_hit_point) + elif include_hit_points_id: + return (sdf, sdf_hit_point_id) + else: + return sdf