Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement new test for the simulation of a box created as mesh #290

Closed
wants to merge 38 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
4941ec3
Set env var when parsing from `robot-descriptions`
flferretti May 13, 2024
e76fe6c
Initial version of mesh support
lorycontixd May 13, 2024
a2c4109
Format and lint
flferretti May 14, 2024
078a3f8
Use already existing env var to solve mesh URIs
flferretti May 17, 2024
003cc71
Skip loading empty meshes
flferretti May 20, 2024
4332531
Added `networkx` as testing dependency
lorycontixd May 16, 2024
2e180c6
Address to reviews:
lorycontixd May 20, 2024
1bcf580
Added trimesh dependecy for conda-forge
lorycontixd May 21, 2024
9a596de
Moved mesh parsing logic inside mesh collision function
lorycontixd May 21, 2024
3d8ad9d
Implemented UniformSurfaceSampling for mesh point wrapping
lorycontixd May 21, 2024
67c97d2
Address reviews
lorycontixd May 21, 2024
0183c5b
Pre-commit
lorycontixd May 21, 2024
ff40470
Update `__eq__` magic and type hints
flferretti Jun 13, 2024
2c920ba
Removed unused lines in conftest
lorycontixd Jun 20, 2024
4bf1769
Fixed typo on logging message
lorycontixd Jun 20, 2024
136e42c
Removed unused import in conftest
lorycontixd Jun 20, 2024
8c46be3
Moved from jax.numpy to numpy in PlaneTerrain __eq__ magic to bypass …
lorycontixd Jul 5, 2024
71b0727
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2024
433602a
Implemented initial reviews from https://github.com/ami-iit/jaxsim/pu…
lorycontixd Jul 5, 2024
40e182d
First draft of new mesh wrapping algorithms
lorycontixd Jul 17, 2024
4b9e4b1
Implemented structure for new mesh wrapping algorithms
lorycontixd Jul 17, 2024
d761cc6
New mesh wrapping algorithms
lorycontixd Jul 17, 2024
d0579f5
New mesh wrapping algorithms with relative tests
lorycontixd Jul 18, 2024
1bc3b07
Renamed some parameters
lorycontixd Jul 18, 2024
94ecdb2
Restructured mesh mapping methods to follow inheritance
lorycontixd Jul 18, 2024
0f8503e
Run pre-commit
lorycontixd Jul 18, 2024
d9aabc7
Removed leftover parameters on create_mesh_collision
lorycontixd Jul 18, 2024
2c5b603
Removed wrong point selection & added logs
lorycontixd Jul 18, 2024
b9a49f3
Added string magics on wrapping methods
lorycontixd Jul 18, 2024
4570f1e
Added docstrings to mesh wrapping algorithms
lorycontixd Nov 14, 2024
932c4a4
Added string magics on wrapping methods
lorycontixd Jul 18, 2024
32e90b7
Implemented reviews
lorycontixd Nov 15, 2024
d417927
Fixed error on array sorting and relative test
lorycontixd Nov 15, 2024
9e64c05
Addressed reviews
lorycontixd Nov 15, 2024
29aa6f5
Updated variable names
lorycontixd Nov 15, 2024
57446b2
Added experimental feature warning for mesh parsing
lorycontixd Nov 15, 2024
9762d80
Merge branch 'add_mesh_support' of github.com:lorycontixd/jaxsim into…
lorycontixd Nov 15, 2024
fc251cf
Added int casting on mesh_enabled flag
lorycontixd Nov 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies:
- pptree
- qpax
- rod >= 0.3.3
- trimesh
- typing_extensions # python<3.12
# ====================================
# Optional dependencies from setup.cfg
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ dependencies = [
"qpax",
"rod >= 0.3.3",
"typing_extensions ; python_version < '3.12'",
"trimesh",
"manifold3d",
]

[project.optional-dependencies]
Expand All @@ -67,7 +69,7 @@ testing = [
"idyntree >= 12.2.1",
"pytest >=6.0",
"pytest-icdiff",
"robot-descriptions",
"robot-descriptions"
]
viz = [
"lxml",
Expand Down
8 changes: 7 additions & 1 deletion src/jaxsim/parsers/descriptions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from .collision import BoxCollision, CollidablePoint, CollisionShape, SphereCollision
from .collision import (
BoxCollision,
CollidablePoint,
CollisionShape,
MeshCollision,
SphereCollision,
)
from .joint import JointDescription, JointGenericAxis, JointType
from .link import LinkDescription
from .model import ModelDescription
19 changes: 19 additions & 0 deletions src/jaxsim/parsers/descriptions/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,22 @@ def __eq__(self, other: BoxCollision) -> bool:
return False

return hash(self) == hash(other)


@dataclasses.dataclass
class MeshCollision(CollisionShape):
center: jtp.VectorLike

def __hash__(self) -> int:
return hash(
(
hash(tuple(self.center.tolist())),
hash(self.collidable_points),
)
)

def __eq__(self, other: MeshCollision) -> bool:
if not isinstance(other, MeshCollision):
return False

return hash(self) == hash(other)
119 changes: 119 additions & 0 deletions src/jaxsim/parsers/rod/meshes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import numpy as np
import trimesh

VALID_AXIS = {"x": 0, "y": 1, "z": 2}


def parse_object_mapping_object(obj: trimesh.Trimesh | dict) -> trimesh.Trimesh:
if isinstance(obj, trimesh.Trimesh):
return obj
elif isinstance(obj, dict):
if obj["type"] == "box":
return trimesh.creation.box(extents=obj["extents"])
elif obj["type"] == "sphere":
return trimesh.creation.icosphere(subdivisions=4, radius=obj["radius"])
else:
raise ValueError(f"Invalid object type {obj['type']}")
else:
raise ValueError("Invalid object type")


def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray:
"""
Extracts the vertices of a mesh as points.
"""
return mesh.vertices


def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> np.ndarray:
"""
Extracts N random points from the surface of a mesh.

Args:
mesh: The mesh from which to extract points.
n: The number of points to extract.

Returns:
The extracted points (N x 3 array).
"""

return mesh.sample(n)


def extract_points_uniform_surface_sampling(
mesh: trimesh.Trimesh, n: int
) -> np.ndarray:
"""
Extracts N uniformly sampled points from the surface of a mesh.

Args:
mesh: The mesh from which to extract points.
n: The number of points to extract.

Returns:
The extracted points (N x 3 array).
"""

return trimesh.sample.sample_surface_even(mesh=mesh, count=n)[0]


def extract_points_select_points_over_axis(
mesh: trimesh.Trimesh, axis: str, direction: str, n: int
) -> np.ndarray:
"""
Extracts N points from a mesh along a specified axis. The points are selected based on their position along the axis.

Args:
mesh: The mesh from which to extract points.
axis: The axis along which to extract points.
direction: The direction along the axis from which to extract points. Valid values are "higher" and "lower".
n: The number of points to extract.

Returns:
The extracted points (N x 3 array).
"""

dirs = {"higher": np.s_[-n:], "lower": np.s_[:n]}
arr = mesh.vertices

# Sort the array in ascending order
arr.sort(axis=0) # Sort rows lexicographically first, then columnar
sorted_arr = arr[dirs[direction]]
return sorted_arr


def extract_points_aap(
mesh: trimesh.Trimesh,
axis: str,
upper: float | None = None,
lower: float | None = None,
) -> np.ndarray:
"""
Extracts points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis.

Args:
mesh: The mesh from which to extract points.
axis: The axis along which to extract points.
upper: The upper bound of the range.
lower: The lower bound of the range.

Returns:
The extracted points (N x 3 array).

Raises:
AssertionError: If the lower bound is greater than the upper bound.
"""

# Check bounds
upper = upper if upper is not None else np.inf
lower = lower if lower is not None else -np.inf
assert lower < upper, "Invalid bounds for axis-aligned plane"

# Logic

points = mesh.vertices[
(mesh.vertices[:, VALID_AXIS[axis]] >= lower)
& (mesh.vertices[:, VALID_AXIS[axis]] <= upper)
]

return points
12 changes: 12 additions & 0 deletions src/jaxsim/parsers/rod/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,18 @@ def extract_model_data(

collisions.append(sphere_collision)

if collision.geometry.mesh is not None and int(
os.environ.get("JAXSIM_COLLISION_MESH_ENABLED", "0")
):
logging.warning("Mesh collision support is still experimental.")
mesh_collision = utils.create_mesh_collision(
collision=collision,
link_description=links_dict[link.name],
method=utils.meshes.extract_points_vertices,
)

collisions.append(mesh_collision)

return SDFData(
model_name=sdf_model.name,
link_descriptions=links,
Expand Down
51 changes: 51 additions & 0 deletions src/jaxsim/parsers/rod/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import os
import pathlib
from collections.abc import Callable
from typing import TypeVar

import numpy as np
import numpy.typing as npt
import rod
import trimesh
from rod.utils.resolve_uris import resolve_local_uri

import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.math import Adjoint, Inertia
from jaxsim.parsers import descriptions
from jaxsim.parsers.rod import meshes

MeshMappingMethod = TypeVar("MeshMappingMethod", bound=Callable[..., npt.NDArray])


def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
Expand Down Expand Up @@ -202,3 +211,45 @@ def fibonacci_sphere(samples: int) -> npt.NDArray:
return descriptions.SphereCollision(
collidable_points=collidable_points, center=center_wrt_link
)


def create_mesh_collision(
collision: rod.Collision,
link_description: descriptions.LinkDescription,
method: MeshMappingMethod = None,
) -> descriptions.MeshCollision:

file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri))
_file_type = file.suffix.replace(".", "")
mesh = trimesh.load_mesh(file, file_type=_file_type)

if mesh.is_empty:
raise RuntimeError(f"Failed to process '{file}' with trimesh")

mesh.apply_scale(collision.geometry.mesh.scale)
logging.info(
msg=f"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{_file_type}'"
)

if method is None:
method = meshes.VertexExtraction()
logging.debug("Using default Vertex Extraction method for mesh wrapping")
else:
logging.debug(f"Using method {method} for mesh wrapping")

points = method(mesh=mesh)
logging.debug(f"Extracted {len(points)} points from mesh")

W_H_L = collision.pose.transform() if collision.pose is not None else np.eye(4)
# Extract translation from transformation matrix
W_p_L = H[:3, 3]
mesh_points_wrt_link = points @ H[:3, :3].T + W_p_L
collidable_points = [
descriptions.CollidablePoint(
parent_link=link_description,
position=point,
enabled=True,
)
for point in mesh_points_wrt_link
]
return descriptions.MeshCollision(collidable_points=collidable_points, center=W_p_L)
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,14 @@ def jaxsim_model_ur10() -> js.model.JaxSimModel:
"""

import robot_descriptions.ur10_description
from robot_descriptions._package_dirs import get_package_dirs

model_urdf_path = pathlib.Path(robot_descriptions.ur10_description.URDF_PATH)

os.environ["GAZEBO_MODEL_PATH"] = os.environ.get(
"GAZEBO_MODEL_PATH", ""
) + ":".join(get_package_dirs(robot_descriptions.ur10_description))

return build_jaxsim_model(model_description=model_urdf_path)


Expand Down
97 changes: 97 additions & 0 deletions tests/test_meshes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import trimesh

from jaxsim.parsers.rod import meshes


def test_mesh_wrapping_vertex_extraction():
"""Test the vertex extraction method on different meshes.
1. A simple box
2. A sphere
"""

# Test 1: A simple box
# First, create a box with origin at (0,0,0) and extents (3,3,3) -> points span from -1.5 to 1.5 on axis
mesh = trimesh.creation.box(
extents=[3.0, 3.0, 3.0],
)
points = meshes.extract_points_vertices(mesh=mesh)
assert len(points) == len(mesh.vertices)

# Test 2: A sphere
# The sphere is centered at the origin and has a radius of 1.0
mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0)
points = meshes.extract_points_vertices(mesh=mesh)
assert len(points) == len(mesh.vertices)


def test_mesh_wrapping_aap():
"""Test the AAP wrapping method on different meshes.
1. A simple box
1.1: Remove all points above x=0.0
1.2: Remove all points below y=0.0
2. A sphere
"""

# Test 1.1: Remove all points above x=0.0
# The expected result is that the number of points is halved
# First, create a box with origin at (0,0,0) and extents (3,3,3) -> points span from -1.5 to 1.5 on axis
mesh = trimesh.creation.box(
extents=[3.0, 3.0, 3.0],
)
points = meshes.extract_points_aap(mesh=mesh, axis="x", lower=0.0)
assert len(points) == len(mesh.vertices) // 2
assert all(points[:, 0] > 0.0)

# Test 1.2: Remove all points below y=0.0
# Again, the expected result is that the number of points is halved
points = meshes.extract_points_aap(mesh=mesh, axis="y", upper=0.0)
assert len(points) == len(mesh.vertices) // 2
assert all(points[:, 1] < 0.0)

# Test 2: A sphere
# The sphere is centered at the origin and has a radius of 1.0. Points are expected to be halved
mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0)
# Remove all points above y=0.0
points = meshes.extract_points_aap(mesh=mesh, axis="y", lower=0.0)
assert all(points[:, 1] >= 0.0)
assert len(points) < len(mesh.vertices)


def test_mesh_wrapping_points_over_axis():
"""
Test the points over axis method on different meshes.
1. A simple box
1.1: Select 10 points from the lower end of the x-axis
1.2: Select 10 points from the higher end of the y-axis
2. A sphere
"""

# Test 1.1: Remove 10 points from the lower end of the x-axis
# First, create a box with origin at (0,0,0) and extents (3,3,3) -> points span from -1.5 to 1.5 on axis
mesh = trimesh.creation.box(
extents=[3.0, 3.0, 3.0],
)
points = meshes.extract_points_select_points_over_axis(
mesh=mesh, axis="x", direction="lower", n=4
)
assert len(points) == 4
assert all(points[:, 0] < 0.0)

# Test 1.2: Select 10 points from the higher end of the y-axis
points = meshes.extract_points_select_points_over_axis(
mesh=mesh, axis="y", direction="higher", n=4
)
assert len(points) == 4
assert all(points[:, 1] > 0.0)

# Test 2: A sphere
# The sphere is centered at the origin and has a radius of 1.0
mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0)
sphere_n_vertices = len(mesh.vertices)

# Select 10 points from the higher end of the z-axis
points = meshes.extract_points_select_points_over_axis(
mesh=mesh, axis="z", direction="higher", n=sphere_n_vertices // 2
)
assert len(points) == sphere_n_vertices // 2
assert all(points[:, 2] >= 0.0)
Loading