Skip to content

Commit

Permalink
Merge pull request #260 from ami-iit/refactor_contact_models
Browse files Browse the repository at this point in the history
Refactor contact models and add tests
  • Loading branch information
diegoferigo authored Oct 17, 2024
2 parents c8d238f + 1e5d13f commit 92d3e49
Show file tree
Hide file tree
Showing 10 changed files with 614 additions and 265 deletions.
77 changes: 58 additions & 19 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@ def collidable_point_kinematics(
the linear component of the mixed 6D frame velocity.
"""

from jaxsim.rbda import collidable_points

# Switch to inertial-fixed since the RBDAs expect velocities in this representation.
with data.switch_velocity_representation(VelRepr.Inertial):
W_p_Ci, W_ṗ_Ci = collidable_points.collidable_points_pos_vel(

W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
model=model,
base_position=data.base_position(),
base_quaternion=data.base_orientation(dcm=False),
Expand Down Expand Up @@ -304,6 +303,15 @@ def in_contact(


def estimate_good_soft_contacts_parameters(
*args, **kwargs
) -> jaxsim.rbda.contacts.ContactParamsTypes:

msg = "This method is deprecated, please use `{}`."
logging.warning(msg.format(estimate_good_contact_parameters.__name__))
return estimate_good_contact_parameters(*args, **kwargs)


def estimate_good_contact_parameters(
model: js.model.JaxSimModel,
*,
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
Expand All @@ -312,14 +320,9 @@ def estimate_good_soft_contacts_parameters(
damping_ratio: jtp.FloatLike = 1.0,
max_penetration: jtp.FloatLike | None = None,
**kwargs,
) -> (
jaxsim.rbda.contacts.RelaxedRigidContactsParams
| jaxsim.rbda.contacts.RigidContactsParams
| jaxsim.rbda.contacts.SoftContactsParams
| jaxsim.rbda.contacts.ViscoElasticContactsParams
):
) -> jaxsim.rbda.contacts.ContactParamsTypes:
"""
Estimate good parameters for soft-like contact models.
Estimate good contact parameters.
Args:
model: The model to consider.
Expand All @@ -332,12 +335,19 @@ def estimate_good_soft_contacts_parameters(
max_penetration:
The maximum penetration allowed in steady state when the robot is
supported by the configured number of active collidable points.
kwargs:
Additional model-specific parameters passed to the builder method of
the parameters class.
Returns:
The estimated good soft contacts parameters.
The estimated good contacts parameters.
Note:
This is primarily a convenience function for soft-like contact models.
However, it provides with some good default parameters also for the other ones.
Note:
This method provides a good starting point for the soft contacts parameters.
This method provides a good set of contacts parameters.
The user is encouraged to fine-tune the parameters based on the
specific application.
"""
Expand All @@ -364,6 +374,7 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
max_δ = (
max_penetration
if max_penetration is not None
# Consider as default a 0.5% of the model height.
else 0.005 * estimate_model_height(model=model)
)

Expand All @@ -381,8 +392,11 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
**dict(
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
)
| kwargs,
)

case contacts.ViscoElasticContacts():
Expand All @@ -396,15 +410,40 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
**kwargs,
**dict(
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
)
| kwargs,
)
)

case contacts.RigidContacts():
assert isinstance(model.contact_model, contacts.RigidContacts)

# Disable Baumgarte stabilization by default since it does not play
# well with the forward Euler integrator.
K = kwargs.get("K", 0.0)

parameters = contacts.RigidContactsParams.build(
mu=static_friction_coefficient,
**dict(
K=K,
D=2 * jnp.sqrt(K),
)
| kwargs,
)

case contacts.RelaxedRigidContacts():
assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)

parameters = contacts.RelaxedRigidContactsParams.build(
mu=static_friction_coefficient,
**kwargs,
)

case _:
logging.warning("The active contact model is not soft-like, no-op.")
parameters = model.contact_model.parameters
raise ValueError(f"Invalid contact model: {model.contact_model}")

return parameters

Expand Down
5 changes: 3 additions & 2 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):

state: ODEState

gravity: jtp.Array
gravity: jtp.Vector

contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)

Expand Down Expand Up @@ -224,7 +224,8 @@ def build(
jaxsim.rbda.contacts.SoftContacts
| jaxsim.rbda.contacts.ViscoElasticContacts,
):
contacts_params = js.contact.estimate_good_soft_contacts_parameters(

contacts_params = js.contact.estimate_good_contact_parameters(
model=model, standard_gravity=standard_gravity
)

Expand Down
20 changes: 8 additions & 12 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class JaxSimModel(JaxsimDataclass):
default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
)

# Note that this is the default contact model.
# Its parameters, if any, are then overridden from those stored in JaxSimModelData.
contact_model: jaxsim.rbda.contacts.ContactModel | None = dataclasses.field(
default=None, repr=False
)
Expand Down Expand Up @@ -2044,24 +2046,18 @@ def step(
M = js.model.free_floating_mass_matrix(model, data_tf)
W_p_C = js.contact.collidable_point_positions(model, data_tf)

# Compute the height of the terrain below each collidable point.
px, py, _ = W_p_C.T
terrain_height = jax.vmap(model.terrain.height)(px, py)

# Compute the contact state.
inactive_collidable_points, _ = (
jaxsim.rbda.contacts.RigidContacts.detect_contacts(
W_p_C=W_p_C,
terrain_height=terrain_height,
)
)
# Compute the penetration depth of the collidable points.
δ, *_ = jax.vmap(
jaxsim.rbda.contacts.common.compute_penetration_data,
in_axes=(0, 0, None),
)(W_p_C, jnp.zeros_like(W_p_C), model.terrain)

# Compute the impact velocity.
# It may be discontinuous in case new contacts are made.
BW_nu_post_impact = (
jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
data=data_tf,
inactive_collidable_points=inactive_collidable_points,
inactive_collidable_points=(δ <= 0),
M=M,
J_WC=J_WC,
)
Expand Down
7 changes: 7 additions & 0 deletions src/jaxsim/rbda/contacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,10 @@
from .rigid import RigidContacts, RigidContactsParams
from .soft import SoftContacts, SoftContactsParams
from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams

ContactParamsTypes = (
SoftContactsParams
| RigidContactsParams
| RelaxedRigidContactsParams
| ViscoElasticContactsParams
)
52 changes: 49 additions & 3 deletions src/jaxsim/rbda/contacts/common.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from __future__ import annotations

import abc
import functools
from typing import Any

import jax
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.terrain
import jaxsim.typing as jtp
Expand All @@ -14,6 +18,47 @@
from typing_extensions import Self


@functools.partial(jax.jit, static_argnames=("terrain",))
def compute_penetration_data(
p: jtp.VectorLike,
v: jtp.VectorLike,
terrain: jaxsim.terrain.Terrain,
) -> tuple[jtp.Float, jtp.Float, jtp.Vector]:
"""
Compute the penetration data (depth, rate, and terrain normal) of a collidable point.
Args:
p: The position of the collidable point.
v:
The linear velocity of the point (linear component of the mixed 6D velocity
of the implicit frame `C = (W_p_C, [W])` associated to the point).
terrain: The considered terrain.
Returns:
A tuple containing the penetration depth, the penetration velocity,
and the considered terrain normal.
"""

# Pre-process the position and the linear velocity of the collidable point.
W_ṗ_C = jnp.array(v).squeeze()
px, py, pz = jnp.array(p).squeeze()

# Compute the terrain normal and the contact depth.
= terrain.normal(x=px, y=py).squeeze()
h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz])

# Compute the penetration depth normal to the terrain.
δ = jnp.maximum(0.0, jnp.dot(h, ))

# Compute the penetration normal velocity.
δ_dot = -jnp.dot(W_ṗ_C, )

# Enforce the penetration rate to be zero when the penetration depth is zero.
δ_dot = jnp.where(δ > 0, δ_dot, 0.0)

return δ, δ_dot,


class ContactsParams(JaxsimDataclass):
"""
Abstract class representing the parameters of a contact model.
Expand Down Expand Up @@ -86,7 +131,7 @@ def compute_contact_forces(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
**kwargs,
) -> tuple[jtp.Vector, tuple[Any, ...]]:
) -> tuple[jtp.Matrix, tuple[Any, ...]]:
"""
Compute the contact forces.
Expand All @@ -95,8 +140,9 @@ def compute_contact_forces(
data: The data of the considered model.
Returns:
A tuple containing as first element the computed 6D contact force applied to the contact point and expressed in the world frame,
and as second element a tuple of optional additional information.
A tuple containing as first element the computed 6D contact force applied to
the contact points and expressed in the world frame, and as second element
a tuple of optional additional information.
"""

pass
Expand Down
Loading

0 comments on commit 92d3e49

Please sign in to comment.