diff --git a/environment.yml b/environment.yml index bc78942a2..2603b0c82 100644 --- a/environment.yml +++ b/environment.yml @@ -15,6 +15,7 @@ dependencies: - pptree - qpax - rod >= 0.3.3 + - trimesh - typing_extensions # python<3.12 # ==================================== # Optional dependencies from setup.cfg diff --git a/pyproject.toml b/pyproject.toml index 044ac8464..e335001d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,8 @@ dependencies = [ "qpax", "rod >= 0.3.3", "typing_extensions ; python_version < '3.12'", + "trimesh", + "manifold3d", ] [project.optional-dependencies] @@ -67,7 +69,7 @@ testing = [ "idyntree >= 12.2.1", "pytest >=6.0", "pytest-icdiff", - "robot-descriptions", + "robot-descriptions" ] viz = [ "lxml", diff --git a/src/jaxsim/parsers/descriptions/__init__.py b/src/jaxsim/parsers/descriptions/__init__.py index 9e180c155..ff3bf631d 100644 --- a/src/jaxsim/parsers/descriptions/__init__.py +++ b/src/jaxsim/parsers/descriptions/__init__.py @@ -1,4 +1,10 @@ -from .collision import BoxCollision, CollidablePoint, CollisionShape, SphereCollision +from .collision import ( + BoxCollision, + CollidablePoint, + CollisionShape, + MeshCollision, + SphereCollision, +) from .joint import JointDescription, JointGenericAxis, JointType from .link import LinkDescription from .model import ModelDescription diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py index 31ae17b97..7815af488 100644 --- a/src/jaxsim/parsers/descriptions/collision.py +++ b/src/jaxsim/parsers/descriptions/collision.py @@ -154,3 +154,22 @@ def __eq__(self, other: BoxCollision) -> bool: return False return hash(self) == hash(other) + + +@dataclasses.dataclass +class MeshCollision(CollisionShape): + center: jtp.VectorLike + + def __hash__(self) -> int: + return hash( + ( + hash(tuple(self.center.tolist())), + hash(self.collidable_points), + ) + ) + + def __eq__(self, other: MeshCollision) -> bool: + if not isinstance(other, MeshCollision): + return False + + return hash(self) == hash(other) diff --git a/src/jaxsim/parsers/rod/meshes.py b/src/jaxsim/parsers/rod/meshes.py new file mode 100644 index 000000000..5e2907de0 --- /dev/null +++ b/src/jaxsim/parsers/rod/meshes.py @@ -0,0 +1,119 @@ +import numpy as np +import trimesh + +VALID_AXIS = {"x": 0, "y": 1, "z": 2} + + +def parse_object_mapping_object(obj: trimesh.Trimesh | dict) -> trimesh.Trimesh: + if isinstance(obj, trimesh.Trimesh): + return obj + elif isinstance(obj, dict): + if obj["type"] == "box": + return trimesh.creation.box(extents=obj["extents"]) + elif obj["type"] == "sphere": + return trimesh.creation.icosphere(subdivisions=4, radius=obj["radius"]) + else: + raise ValueError(f"Invalid object type {obj['type']}") + else: + raise ValueError("Invalid object type") + + +def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray: + """ + Extracts the vertices of a mesh as points. + """ + return mesh.vertices + + +def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> np.ndarray: + """ + Extracts N random points from the surface of a mesh. + + Args: + mesh: The mesh from which to extract points. + n: The number of points to extract. + + Returns: + The extracted points (N x 3 array). + """ + + return mesh.sample(n) + + +def extract_points_uniform_surface_sampling( + mesh: trimesh.Trimesh, n: int +) -> np.ndarray: + """ + Extracts N uniformly sampled points from the surface of a mesh. + + Args: + mesh: The mesh from which to extract points. + n: The number of points to extract. + + Returns: + The extracted points (N x 3 array). + """ + + return trimesh.sample.sample_surface_even(mesh=mesh, count=n)[0] + + +def extract_points_select_points_over_axis( + mesh: trimesh.Trimesh, axis: str, direction: str, n: int +) -> np.ndarray: + """ + Extracts N points from a mesh along a specified axis. The points are selected based on their position along the axis. + + Args: + mesh: The mesh from which to extract points. + axis: The axis along which to extract points. + direction: The direction along the axis from which to extract points. Valid values are "higher" and "lower". + n: The number of points to extract. + + Returns: + The extracted points (N x 3 array). + """ + + dirs = {"higher": np.s_[-n:], "lower": np.s_[:n]} + arr = mesh.vertices + + # Sort the array in ascending order + arr.sort(axis=0) # Sort rows lexicographically first, then columnar + sorted_arr = arr[dirs[direction]] + return sorted_arr + + +def extract_points_aap( + mesh: trimesh.Trimesh, + axis: str, + upper: float | None = None, + lower: float | None = None, +) -> np.ndarray: + """ + Extracts points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis. + + Args: + mesh: The mesh from which to extract points. + axis: The axis along which to extract points. + upper: The upper bound of the range. + lower: The lower bound of the range. + + Returns: + The extracted points (N x 3 array). + + Raises: + AssertionError: If the lower bound is greater than the upper bound. + """ + + # Check bounds + upper = upper if upper is not None else np.inf + lower = lower if lower is not None else -np.inf + assert lower < upper, "Invalid bounds for axis-aligned plane" + + # Logic + + points = mesh.vertices[ + (mesh.vertices[:, VALID_AXIS[axis]] >= lower) + & (mesh.vertices[:, VALID_AXIS[axis]] <= upper) + ] + + return points diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index 082a4e0df..9947c43a7 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -337,6 +337,18 @@ def extract_model_data( collisions.append(sphere_collision) + if collision.geometry.mesh is not None and int( + os.environ.get("JAXSIM_COLLISION_MESH_ENABLED", "0") + ): + logging.warning("Mesh collision support is still experimental.") + mesh_collision = utils.create_mesh_collision( + collision=collision, + link_description=links_dict[link.name], + method=utils.meshes.extract_points_vertices, + ) + + collisions.append(mesh_collision) + return SDFData( model_name=sdf_model.name, link_descriptions=links, diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index 242bb0a2e..74ecb98a3 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -1,12 +1,21 @@ import os +import pathlib +from collections.abc import Callable +from typing import TypeVar import numpy as np import numpy.typing as npt import rod +import trimesh +from rod.utils.resolve_uris import resolve_local_uri import jaxsim.typing as jtp +from jaxsim import logging from jaxsim.math import Adjoint, Inertia from jaxsim.parsers import descriptions +from jaxsim.parsers.rod import meshes + +MeshMappingMethod = TypeVar("MeshMappingMethod", bound=Callable[..., npt.NDArray]) def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: @@ -202,3 +211,45 @@ def fibonacci_sphere(samples: int) -> npt.NDArray: return descriptions.SphereCollision( collidable_points=collidable_points, center=center_wrt_link ) + + +def create_mesh_collision( + collision: rod.Collision, + link_description: descriptions.LinkDescription, + method: MeshMappingMethod = None, +) -> descriptions.MeshCollision: + + file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri)) + _file_type = file.suffix.replace(".", "") + mesh = trimesh.load_mesh(file, file_type=_file_type) + + if mesh.is_empty: + raise RuntimeError(f"Failed to process '{file}' with trimesh") + + mesh.apply_scale(collision.geometry.mesh.scale) + logging.info( + msg=f"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{_file_type}'" + ) + + if method is None: + method = meshes.VertexExtraction() + logging.debug("Using default Vertex Extraction method for mesh wrapping") + else: + logging.debug(f"Using method {method} for mesh wrapping") + + points = method(mesh=mesh) + logging.debug(f"Extracted {len(points)} points from mesh") + + W_H_L = collision.pose.transform() if collision.pose is not None else np.eye(4) + # Extract translation from transformation matrix + W_p_L = H[:3, 3] + mesh_points_wrt_link = points @ H[:3, :3].T + W_p_L + collidable_points = [ + descriptions.CollidablePoint( + parent_link=link_description, + position=point, + enabled=True, + ) + for point in mesh_points_wrt_link + ] + return descriptions.MeshCollision(collidable_points=collidable_points, center=W_p_L) diff --git a/tests/conftest.py b/tests/conftest.py index 2ffb1c06d..a6a824d48 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -244,9 +244,14 @@ def jaxsim_model_ur10() -> js.model.JaxSimModel: """ import robot_descriptions.ur10_description + from robot_descriptions._package_dirs import get_package_dirs model_urdf_path = pathlib.Path(robot_descriptions.ur10_description.URDF_PATH) + os.environ["GAZEBO_MODEL_PATH"] = os.environ.get( + "GAZEBO_MODEL_PATH", "" + ) + ":".join(get_package_dirs(robot_descriptions.ur10_description)) + return build_jaxsim_model(model_description=model_urdf_path) diff --git a/tests/test_meshes.py b/tests/test_meshes.py new file mode 100644 index 000000000..7765ae15b --- /dev/null +++ b/tests/test_meshes.py @@ -0,0 +1,97 @@ +import trimesh + +from jaxsim.parsers.rod import meshes + + +def test_mesh_wrapping_vertex_extraction(): + """Test the vertex extraction method on different meshes. + 1. A simple box + 2. A sphere + """ + + # Test 1: A simple box + # First, create a box with origin at (0,0,0) and extents (3,3,3) -> points span from -1.5 to 1.5 on axis + mesh = trimesh.creation.box( + extents=[3.0, 3.0, 3.0], + ) + points = meshes.extract_points_vertices(mesh=mesh) + assert len(points) == len(mesh.vertices) + + # Test 2: A sphere + # The sphere is centered at the origin and has a radius of 1.0 + mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) + points = meshes.extract_points_vertices(mesh=mesh) + assert len(points) == len(mesh.vertices) + + +def test_mesh_wrapping_aap(): + """Test the AAP wrapping method on different meshes. + 1. A simple box + 1.1: Remove all points above x=0.0 + 1.2: Remove all points below y=0.0 + 2. A sphere + """ + + # Test 1.1: Remove all points above x=0.0 + # The expected result is that the number of points is halved + # First, create a box with origin at (0,0,0) and extents (3,3,3) -> points span from -1.5 to 1.5 on axis + mesh = trimesh.creation.box( + extents=[3.0, 3.0, 3.0], + ) + points = meshes.extract_points_aap(mesh=mesh, axis="x", lower=0.0) + assert len(points) == len(mesh.vertices) // 2 + assert all(points[:, 0] > 0.0) + + # Test 1.2: Remove all points below y=0.0 + # Again, the expected result is that the number of points is halved + points = meshes.extract_points_aap(mesh=mesh, axis="y", upper=0.0) + assert len(points) == len(mesh.vertices) // 2 + assert all(points[:, 1] < 0.0) + + # Test 2: A sphere + # The sphere is centered at the origin and has a radius of 1.0. Points are expected to be halved + mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) + # Remove all points above y=0.0 + points = meshes.extract_points_aap(mesh=mesh, axis="y", lower=0.0) + assert all(points[:, 1] >= 0.0) + assert len(points) < len(mesh.vertices) + + +def test_mesh_wrapping_points_over_axis(): + """ + Test the points over axis method on different meshes. + 1. A simple box + 1.1: Select 10 points from the lower end of the x-axis + 1.2: Select 10 points from the higher end of the y-axis + 2. A sphere + """ + + # Test 1.1: Remove 10 points from the lower end of the x-axis + # First, create a box with origin at (0,0,0) and extents (3,3,3) -> points span from -1.5 to 1.5 on axis + mesh = trimesh.creation.box( + extents=[3.0, 3.0, 3.0], + ) + points = meshes.extract_points_select_points_over_axis( + mesh=mesh, axis="x", direction="lower", n=4 + ) + assert len(points) == 4 + assert all(points[:, 0] < 0.0) + + # Test 1.2: Select 10 points from the higher end of the y-axis + points = meshes.extract_points_select_points_over_axis( + mesh=mesh, axis="y", direction="higher", n=4 + ) + assert len(points) == 4 + assert all(points[:, 1] > 0.0) + + # Test 2: A sphere + # The sphere is centered at the origin and has a radius of 1.0 + mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) + sphere_n_vertices = len(mesh.vertices) + + # Select 10 points from the higher end of the z-axis + points = meshes.extract_points_select_points_over_axis( + mesh=mesh, axis="z", direction="higher", n=sphere_n_vertices // 2 + ) + assert len(points) == sphere_n_vertices // 2 + assert all(points[:, 2] >= 0.0)