From bc1b77a42b6ecb775f4942957d90e5ee94c5ac60 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 10 Oct 2024 17:24:24 +0200 Subject: [PATCH 01/13] Share the `compute_penetration_data` function among contact models --- src/jaxsim/api/model.py | 18 ++---- src/jaxsim/rbda/contacts/common.py | 45 ++++++++++++++ src/jaxsim/rbda/contacts/rigid.py | 95 +++++++++--------------------- src/jaxsim/rbda/contacts/soft.py | 37 ++---------- 4 files changed, 84 insertions(+), 111 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 39192f4a4..3621c0c59 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2044,24 +2044,18 @@ def step( M = js.model.free_floating_mass_matrix(model, data_tf) W_p_C = js.contact.collidable_point_positions(model, data_tf) - # Compute the height of the terrain below each collidable point. - px, py, _ = W_p_C.T - terrain_height = jax.vmap(model.terrain.height)(px, py) - - # Compute the contact state. - inactive_collidable_points, _ = ( - jaxsim.rbda.contacts.RigidContacts.detect_contacts( - W_p_C=W_p_C, - terrain_height=terrain_height, - ) - ) + # Compute the penetration depth of the collidable points. + δ, *_ = jax.vmap( + jaxsim.rbda.contacts.common.compute_penetration_data, + in_axes=(0, 0, None), + )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) # Compute the impact velocity. # It may be discontinuous in case new contacts are made. BW_nu_post_impact = ( jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity( data=data_tf, - inactive_collidable_points=inactive_collidable_points, + inactive_collidable_points=(δ <= 0), M=M, J_WC=J_WC, ) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 7639a9738..5a2a6648a 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -1,8 +1,12 @@ from __future__ import annotations import abc +import functools from typing import Any +import jax +import jax.numpy as jnp + import jaxsim.api as js import jaxsim.terrain import jaxsim.typing as jtp @@ -14,6 +18,47 @@ from typing_extensions import Self +@functools.partial(jax.jit, static_argnames=("terrain",)) +def compute_penetration_data( + p: jtp.VectorLike, + v: jtp.VectorLike, + terrain: jaxsim.terrain.Terrain, +) -> tuple[jtp.Float, jtp.Float, jtp.Vector]: + """ + Compute the penetration data (depth, rate, and terrain normal) of a collidable point. + + Args: + p: The position of the collidable point. + v: + The linear velocity of the point (linear component of the mixed 6D velocity + of the implicit frame `C = (W_p_C, [W])` associated to the point). + terrain: The considered terrain. + + Returns: + A tuple containing the penetration depth, the penetration velocity, + and the considered terrain normal. + """ + + # Pre-process the position and the linear velocity of the collidable point. + W_ṗ_C = jnp.array(v).squeeze() + px, py, pz = jnp.array(p).squeeze() + + # Compute the terrain normal and the contact depth. + n̂ = terrain.normal(x=px, y=py).squeeze() + h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz]) + + # Compute the penetration depth normal to the terrain. + δ = jnp.maximum(0.0, jnp.dot(h, n̂)) + + # Compute the penetration normal velocity. + δ_dot = -jnp.dot(W_ṗ_C, n̂) + + # Enforce the penetration rate to be zero when the penetration depth is zero. + δ_dot = jnp.where(δ > 0, δ_dot, 0.0) + + return δ, δ_dot, n̂ + + class ContactsParams(JaxsimDataclass): """ Abstract class representing the parameters of a contact model. diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index bfacba19a..6c3dd0f67 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -13,6 +13,7 @@ from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr from jaxsim.terrain import FlatTerrain, Terrain +from . import common from .common import ContactModel, ContactsParams try: @@ -170,46 +171,6 @@ def build( _solver_options_values=tuple(solver_options.values()), ) - @staticmethod - def detect_contacts( - W_p_C: jtp.ArrayLike, - terrain_height: jtp.ArrayLike, - ) -> tuple[jtp.Vector, jtp.Vector]: - """ - Detect contacts between the collidable points and the terrain. - - Args: - W_p_C: The position of the collidable points. - terrain_height: The height of the terrain at the collidable point position. - - Returns: - A tuple containing the activation state of the collidable points - and the contact penetration depth h. - """ - - # TODO: reduce code duplication with js.contact.in_contact - def detect_contact( - W_p_C: jtp.ArrayLike, - terrain_height: jtp.FloatLike, - ) -> tuple[jtp.Bool, jtp.Float]: - """ - Detect contacts between the collidable points and the terrain. - """ - - # Unpack the position of the collidable point. - _, _, pz = W_p_C.squeeze() - - inactive = pz > terrain_height - - # Compute contact penetration depth - h = jnp.maximum(0.0, terrain_height - pz) - - return inactive, h - - inactive_collidable_points, h = jax.vmap(detect_contact)(W_p_C, terrain_height) - - return inactive_collidable_points, h - @staticmethod def compute_impact_velocity( inactive_collidable_points: jtp.ArrayLike, @@ -332,13 +293,13 @@ def compute_contact_forces( model=model, data=data ) - terrain_height = jax.vmap(self.terrain.height)(position[:, 0], position[:, 1]) - n_collidable_points = model.kin_dyn_parameters.contact_parameters.point.shape[0] + # Get the number of collidable points. + n_collidable_points = len(model.kin_dyn_parameters.contact_parameters.body) - # Compute the activation state of the collidable points - inactive_collidable_points, h = RigidContacts.detect_contacts( - W_p_C=position, - terrain_height=terrain_height, + # Compute the penetration depth and velocity of the collidable points. + # Note that this function considers the penetration in the normal direction. + δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( + position, velocity, self.terrain ) # Compute the Delassus matrix. @@ -379,12 +340,12 @@ def compute_contact_forces( CW_J_dot_WC_BW=J̇_WC_BW, ).flatten() - # Compute stabilization term - ḣ = velocity[:, 2].squeeze() + # Compute stabilization term. baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term( - inactive_collidable_points=inactive_collidable_points, - h=h, - ḣ=ḣ, + inactive_collidable_points=(δ <= 0), + δ=δ, + δ_dot=δ_dot, + n=n̂, K=self.parameters.K, D=self.parameters.D, ).flatten() @@ -395,7 +356,7 @@ def compute_contact_forces( Q = delassus_matrix q = free_contact_acc G = RigidContacts._compute_ineq_constraint_matrix( - inactive_collidable_points=inactive_collidable_points, mu=self.parameters.mu + inactive_collidable_points=(δ <= 0), mu=self.parameters.mu ) h_bounds = RigidContacts._compute_ineq_bounds( n_collidable_points=n_collidable_points @@ -497,33 +458,35 @@ def _linear_acceleration_of_collidable_points( @staticmethod def _compute_baumgarte_stabilization_term( inactive_collidable_points: jtp.ArrayLike, - h: jtp.ArrayLike, - ḣ: jtp.ArrayLike, + δ: jtp.ArrayLike, + δ_dot: jtp.ArrayLike, + n: jtp.ArrayLike, K: jtp.FloatLike, D: jtp.FloatLike, ) -> jtp.Array: + def baumgarte_stabilization( inactive: jtp.BoolLike, - h: jtp.FloatLike, - ḣ: jtp.FloatLike, + δ: jtp.FloatLike, + δ_dot: jtp.FloatLike, + n: jtp.ArrayLike, k_baumgarte: jtp.FloatLike, d_baumgarte: jtp.FloatLike, ) -> jtp.Array: + baumgarte_term = jax.lax.cond( inactive, - lambda h, ḣ, K, D: jnp.zeros(shape=(3,)), - lambda h, ḣ, K, D: jnp.zeros(shape=(3,)).at[2].set(K * h + D * ḣ), - *( - h, - ḣ, - k_baumgarte, - d_baumgarte, - ), + lambda δ, δ_dot, n, K, D: jnp.zeros(3), + # This is equivalent to: K*(pT - p)⋅n̂ + D*(0 - v)⋅n̂, + # where pT is the point on the terrain surface vertical to p. + lambda δ, δ_dot, n, K, D: (K * δ + D * δ_dot) * n, + *(δ, δ_dot, n, k_baumgarte, d_baumgarte), ) + return baumgarte_term baumgarte_term = jax.vmap( - baumgarte_stabilization, in_axes=(0, 0, 0, None, None) - )(inactive_collidable_points, h, ḣ, K, D) + baumgarte_stabilization, in_axes=(0, 0, 0, 0, None, None) + )(inactive_collidable_points, δ, δ_dot, n, K, D) return baumgarte_term diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 7019344d1..7295d5743 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -14,7 +14,7 @@ from jaxsim.math import StandardGravity from jaxsim.terrain import FlatTerrain, Terrain -from .common import ContactModel, ContactsParams +from . import common try: from typing import Self @@ -23,7 +23,7 @@ @jax_dataclasses.pytree_dataclass -class SoftContactsParams(ContactsParams): +class SoftContactsParams(common.ContactsParams): """Parameters of the soft contacts model.""" K: jtp.Float = dataclasses.field( @@ -189,7 +189,7 @@ def valid(self) -> jtp.BoolLike: @jax_dataclasses.pytree_dataclass -class SoftContacts(ContactModel): +class SoftContacts(common.ContactModel): """Soft contacts model.""" parameters: SoftContactsParams = dataclasses.field( @@ -277,9 +277,7 @@ def hunt_crossley_contact_model( μ = mu # Compute the penetration depth, its rate, and the considered terrain normal. - δ, δ̇, n̂ = SoftContacts.compute_penetration_data( - p=W_p_C, v=W_ṗ_C, terrain=terrain - ) + δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain) # There are few operations like computing the norm of a vector with zero length # or computing the square root of zero that are problematic in an AD context. @@ -450,30 +448,3 @@ def compute_contact_forces( )(W_p_C, W_ṗ_C, m) return W_f, (ṁ,) - - @staticmethod - @jax.jit - def compute_penetration_data( - p: jtp.VectorLike, - v: jtp.VectorLike, - terrain: jaxsim.terrain.Terrain, - ) -> tuple[jtp.Float, jtp.Float, jtp.Vector]: - - # Pre-process the position and the linear velocity of the collidable point. - W_ṗ_C = jnp.array(v).squeeze() - px, py, pz = jnp.array(p).squeeze() - - # Compute the terrain normal and the contact depth. - n̂ = terrain.normal(x=px, y=py).squeeze() - h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz]) - - # Compute the penetration depth normal to the terrain. - δ = jnp.maximum(0.0, jnp.dot(h, n̂)) - - # Compute the penetration normal velocity. - δ̇ = -jnp.dot(W_ṗ_C, n̂) - - # Enforce the penetration rate to be zero when the penetration depth is zero. - δ̇ = jnp.where(δ > 0, δ̇, 0.0) - - return δ, δ̇, n̂ From 4ebadf9819865cc5de1edbf95e120eb4a84246e7 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 17 Oct 2024 11:18:29 +0200 Subject: [PATCH 02/13] Refactor rigid contact model --- src/jaxsim/rbda/contacts/rigid.py | 107 +++++++++++++++++------------- 1 file changed, 61 insertions(+), 46 deletions(-) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 6c3dd0f67..63418ef06 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -264,28 +264,32 @@ def compute_contact_forces( # contact parameters are not compatible. model, data = self.initialize_model_and_data(model=model, data=data) - # Import qpax just in this method + # Import qpax privately just in this method. import qpax - link_forces = ( - link_forces + link_forces = jnp.atleast_2d( + jnp.array(link_forces, dtype=float).squeeze() if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) ) - joint_force_references = ( - joint_force_references + joint_force_references = jnp.atleast_1d( + jnp.array(joint_force_references, dtype=float).squeeze() if joint_force_references is not None else jnp.zeros((model.number_of_joints(),)) ) - # Compute kin-dyn quantities used in the contact model + # Compute kin-dyn quantities used in the contact model. with data.switch_velocity_representation(VelRepr.Mixed): + + BW_ν = data.generalized_velocity() + M = js.model.free_floating_mass_matrix(model=model, data=data) + J_WC = js.contact.jacobian(model=model, data=data) + J̇_WC = js.contact.jacobian_derivative(model=model, data=data) + W_H_C = js.contact.transforms(model=model, data=data) - J̇_WC_BW = js.contact.jacobian_derivative(model=model, data=data) - BW_ν = data.generalized_velocity() # Compute the position and linear velocities (mixed representation) of # all collidable points belonging to the robot. @@ -302,14 +306,7 @@ def compute_contact_forces( position, velocity, self.terrain ) - # Compute the Delassus matrix. - delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC) - - # Add regularization for better numerical conditioning. - delassus_matrix = delassus_matrix + self.regularization_delassus * jnp.eye( - delassus_matrix.shape[0] - ) - + # Build a references object to simplify converting link forces. references = js.references.JaxSimModelReferences.build( model=model, data=data, @@ -318,10 +315,12 @@ def compute_contact_forces( joint_force_references=joint_force_references, ) + # Compute the generalized free acceleration. with ( references.switch_velocity_representation(VelRepr.Mixed), data.switch_velocity_representation(VelRepr.Mixed), ): + BW_ν̇_free = jnp.hstack( js.ode.system_acceleration( model=model, @@ -333,11 +332,13 @@ def compute_contact_forces( ) ) + # Compute the free linear acceleration of the collidable points. + # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C. free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points( BW_nu=BW_ν, BW_nu_dot=BW_ν̇_free, CW_J_WC_BW=J_WC, - CW_J_dot_WC_BW=J̇_WC_BW, + CW_J_dot_WC_BW=J̇_WC, ).flatten() # Compute stabilization term. @@ -350,47 +351,55 @@ def compute_contact_forces( D=self.parameters.D, ).flatten() - free_contact_acc -= baumgarte_term + # Compute the Delassus matrix. + delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC) + + # Initialize regularization term of the Delassus matrix for + # better numerical conditioning. + Iε = self.regularization_delassus * jnp.eye(delassus_matrix.shape[0]) + + # Construct the quadratic cost function. + Q = delassus_matrix + Iε + q = free_contact_acc - baumgarte_term - # Setup optimization problem - Q = delassus_matrix - q = free_contact_acc + # Construct the inequality constraints. G = RigidContacts._compute_ineq_constraint_matrix( inactive_collidable_points=(δ <= 0), mu=self.parameters.mu ) h_bounds = RigidContacts._compute_ineq_bounds( n_collidable_points=n_collidable_points ) + + # Construct the equality constraints. A = jnp.zeros((0, 3 * n_collidable_points)) b = jnp.zeros((0,)) - # Solve the optimization problem - solution, *_ = qpax.solve_qp( + # Solve the following optimization problem with qpax: + # + # min_{x} 0.5 x⊤ Q x + q⊤ x + # + # s.t. A x = b + # G x ≤ h + # + # TODO: add possibility to notify if the QP problem did not converge. + solution, _, _, _, converged, _ = qpax.solve_qp( # noqa: F841 Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options ) - f_C_lin = solution.reshape(-1, 3) - - # Transform linear contact forces to 6D - CW_f_C = jnp.hstack( - ( - f_C_lin, - jnp.zeros((f_C_lin.shape[0], 3)), - ) - ) + # Reshape the optimized solution to be a matrix of 3D contact forces. + CW_fl_C = solution.reshape(-1, 3) - # Transform the contact forces to inertial-fixed representation + # Convert the contact forces from mixed to inertial-fixed representation. W_f_C = jax.vmap( - lambda CW_f_C, W_H_C: ModelDataWithVelocityRepresentation.other_representation_to_inertial( - array=CW_f_C, - transform=W_H_C, - other_representation=VelRepr.Mixed, - is_force=True, + lambda CW_fl_C, W_H_C: ( + ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=jnp.zeros(6).at[0:3].set(CW_fl_C), + transform=W_H_C, + other_representation=VelRepr.Mixed, + is_force=True, + ) ), - )( - CW_f_C, - W_H_C, - ) + )(CW_fl_C, W_H_C) return W_f_C, () @@ -399,6 +408,7 @@ def _delassus_matrix( M: jtp.MatrixLike, J_WC: jtp.MatrixLike, ) -> jtp.Matrix: + sl = jnp.s_[:, 0:3, :] J_WC_lin = jnp.vstack(J_WC[sl]) @@ -409,6 +419,7 @@ def _delassus_matrix( def _compute_ineq_constraint_matrix( inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike ) -> jtp.Matrix: + def compute_G_single_point(mu: float, c: float) -> jtp.Matrix: """ Compute the inequality constraint matrix for a single collidable point @@ -436,6 +447,7 @@ def compute_G_single_point(mu: float, c: float) -> jtp.Matrix: @staticmethod def _compute_ineq_bounds(n_collidable_points: jtp.FloatLike) -> jtp.Vector: + n_constraints = 6 * n_collidable_points return jnp.zeros(shape=(n_constraints,)) @@ -446,13 +458,16 @@ def _linear_acceleration_of_collidable_points( CW_J_WC_BW: jtp.MatrixLike, CW_J_dot_WC_BW: jtp.MatrixLike, ) -> jtp.Matrix: - CW_J̇_WC_BW = CW_J_dot_WC_BW + BW_ν = BW_nu BW_ν̇ = BW_nu_dot + CW_J̇_WC_BW = CW_J_dot_WC_BW + # Compute the linear acceleration of the collidable points. + # Since we use doubly-mixed jacobians, this corresponds to W_p̈_C. CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇ - CW_a_WC = CW_a_WC.reshape(-1, 6) + CW_a_WC = CW_a_WC.reshape(-1, 6) return CW_a_WC[:, 0:3].squeeze() @staticmethod @@ -465,7 +480,7 @@ def _compute_baumgarte_stabilization_term( D: jtp.FloatLike, ) -> jtp.Array: - def baumgarte_stabilization( + def baumgarte_stabilization_of_single_point( inactive: jtp.BoolLike, δ: jtp.FloatLike, δ_dot: jtp.FloatLike, @@ -486,7 +501,7 @@ def baumgarte_stabilization( return baumgarte_term baumgarte_term = jax.vmap( - baumgarte_stabilization, in_axes=(0, 0, 0, 0, None, None) + baumgarte_stabilization_of_single_point, in_axes=(0, 0, 0, 0, None, None) )(inactive_collidable_points, δ, δ_dot, n, K, D) return baumgarte_term From 2d8af318c1fb9a22ee23ab083ab4c4a305b756ad Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 8 Oct 2024 16:01:39 +0200 Subject: [PATCH 03/13] Fix estimation of good K/D for generic Hunt/Crossley models --- src/jaxsim/rbda/contacts/soft.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 7295d5743..ff5427a4f 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -161,7 +161,9 @@ def build_default_from_jaxsim_model( f_average = m * g / number_of_active_collidable_points_steady_state # Compute the stiffness to get the desired steady-state penetration. - K = f_average / jnp.power(δ_max, 3 / 2) + # Note that this is dependent on the non-linear exponent used in + # the damping term of the Hunt/Crossley model. + K = f_average / jnp.power(δ_max, 1 + p) # Compute the damping using the damping ratio. critical_damping = 2 * jnp.sqrt(K * m) From cd1584de0480de8e63bcd119054acfb0fa560a2e Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 9 Oct 2024 14:42:12 +0200 Subject: [PATCH 04/13] Get the right contact parameters from data Fixes a regression that prevented overriding the nominal parameters stored in model with those stored in data --- src/jaxsim/rbda/contacts/relaxed_rigid.py | 5 +++-- src/jaxsim/rbda/contacts/rigid.py | 9 +++++---- src/jaxsim/rbda/contacts/soft.py | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 6737a14b3..fefb0b003 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -250,6 +250,7 @@ def compute_contact_forces( # This will raise an exception if either the contact model or the # contact parameters are not compatible. model, data = self.initialize_model_and_data(model=model, data=data) + assert isinstance(data.contacts_params, RelaxedRigidContactsParams) link_forces = ( link_forces @@ -274,7 +275,7 @@ def compute_contact_forces( def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: x, y, z = jax.tree.map(jnp.squeeze, (x, y, z)) - n̂ = self.terrain.normal(x=x, y=y).squeeze() + n̂ = model.terrain.normal(x=x, y=y).squeeze() h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)]) return jnp.dot(h, n̂) @@ -320,7 +321,7 @@ def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: model=model, penetration=δ, velocity=velocity, - parameters=self.parameters, + parameters=data.contacts_params, ) G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0] diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 63418ef06..8e40c1c1a 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -263,6 +263,7 @@ def compute_contact_forces( # This will raise an exception if either the contact model or the # contact parameters are not compatible. model, data = self.initialize_model_and_data(model=model, data=data) + assert isinstance(data.contacts_params, RigidContactsParams) # Import qpax privately just in this method. import qpax @@ -303,7 +304,7 @@ def compute_contact_forces( # Compute the penetration depth and velocity of the collidable points. # Note that this function considers the penetration in the normal direction. δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( - position, velocity, self.terrain + position, velocity, model.terrain ) # Build a references object to simplify converting link forces. @@ -347,8 +348,8 @@ def compute_contact_forces( δ=δ, δ_dot=δ_dot, n=n̂, - K=self.parameters.K, - D=self.parameters.D, + K=data.contacts_params.K, + D=data.contacts_params.D, ).flatten() # Compute the Delassus matrix. @@ -364,7 +365,7 @@ def compute_contact_forces( # Construct the inequality constraints. G = RigidContacts._compute_ineq_constraint_matrix( - inactive_collidable_points=(δ <= 0), mu=self.parameters.mu + inactive_collidable_points=(δ <= 0), mu=data.contacts_params.mu ) h_bounds = RigidContacts._compute_ineq_bounds( n_collidable_points=n_collidable_points diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index ff5427a4f..8b626bc00 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -444,8 +444,8 @@ def compute_contact_forces( position=p, velocity=v, tangential_deformation=m, - parameters=self.parameters, - terrain=self.terrain, + parameters=data.contacts_params, + terrain=model.terrain, ) )(W_p_C, W_ṗ_C, m) From 411416194a58d6955f872c3f2b4cbf6f6c01ee24 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 17 Oct 2024 11:18:39 +0200 Subject: [PATCH 05/13] Refactor relaxed-rigid contacts --- src/jaxsim/rbda/contacts/relaxed_rigid.py | 202 ++++++++++++++-------- 1 file changed, 131 insertions(+), 71 deletions(-) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index fefb0b003..bc2e46ecb 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -12,11 +12,10 @@ import jaxsim.api as js import jaxsim.typing as jtp from jaxsim import logging -from jaxsim.api.common import VelRepr -from jaxsim.math import Adjoint +from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr from jaxsim.terrain.terrain import FlatTerrain, Terrain -from .common import ContactModel, ContactsParams +from . import common try: from typing import Self @@ -25,7 +24,7 @@ @jax_dataclasses.pytree_dataclass -class RelaxedRigidContactsParams(ContactsParams): +class RelaxedRigidContactsParams(common.ContactsParams): """Parameters of the relaxed rigid contacts model.""" # Time constant @@ -116,14 +115,24 @@ def build( ) -> Self: """Create a `RelaxedRigidContactsParams` instance""" + def default(name: str): + return cls.__dataclass_fields__[name].default_factory() + return cls( - **{ - field: jnp.array(locals().get(field, default), dtype=default.dtype) - for field, default in map( - lambda f: (f, cls.__dataclass_fields__[f].default), - filter(lambda f: f != "__mutability__", cls.__dataclass_fields__), - ) - } + time_constant=jnp.array( + time_constant or default("time_constant"), dtype=float + ), + damping_coefficient=jnp.array( + damping_coefficient or default("damping_coefficient"), dtype=float + ), + d_min=jnp.array(d_min or default("d_min"), dtype=float), + d_max=jnp.array(d_max or default("d_max"), dtype=float), + width=jnp.array(width or default("width"), dtype=float), + midpoint=jnp.array(midpoint or default("midpoint"), dtype=float), + power=jnp.array(power or default("power"), dtype=float), + stiffness=jnp.array(stiffness or default("stiffness"), dtype=float), + damping=jnp.array(damping or default("damping"), dtype=float), + mu=jnp.array(mu or default("mu"), dtype=float), ) def valid(self) -> jtp.BoolLike: @@ -142,7 +151,7 @@ def valid(self) -> jtp.BoolLike: @jax_dataclasses.pytree_dataclass -class RelaxedRigidContacts(ContactModel): +class RelaxedRigidContacts(common.ContactModel): """Relaxed rigid contacts model.""" parameters: RelaxedRigidContactsParams = dataclasses.field( @@ -252,17 +261,17 @@ def compute_contact_forces( model, data = self.initialize_model_and_data(model=model, data=data) assert isinstance(data.contacts_params, RelaxedRigidContactsParams) - link_forces = ( - link_forces + link_forces = jnp.atleast_2d( + jnp.array(link_forces, dtype=float).squeeze() if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) - ) + ).astype(float) - joint_force_references = ( - joint_force_references + joint_force_references = jnp.atleast_1d( + jnp.array(joint_force_references, dtype=float).squeeze() if joint_force_references is not None else jnp.zeros(model.number_of_joints()) - ) + ).astype(float) references = js.references.JaxSimModelReferences.build( model=model, @@ -272,7 +281,7 @@ def compute_contact_forces( joint_force_references=joint_force_references, ) - def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: + def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: x, y, z = jax.tree.map(jnp.squeeze, (x, y, z)) n̂ = model.terrain.normal(x=x, y=y).squeeze() @@ -287,19 +296,19 @@ def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: ) # Compute the activation state of the collidable points - δ = jax.vmap(_detect_contact)(*position.T) + δ = jax.vmap(detect_contact)(*position.T) + + # Compute the transforms of the implicit frames corresponding to the + # collidable points. + W_H_C = js.contact.transforms(model=model, data=data) with ( references.switch_velocity_representation(VelRepr.Mixed), data.switch_velocity_representation(VelRepr.Mixed), ): - M = js.model.free_floating_mass_matrix(model=model, data=data) - Jl_WC = jnp.vstack( - jax.vmap(lambda J, height: J * (height < 0))( - js.contact.jacobian(model=model, data=data)[:, :3, :], δ - ) - ) - W_H_C = js.contact.transforms(model=model, data=data) + + BW_ν = data.generalized_velocity() + BW_ν̇_free = jnp.hstack( js.ode.system_acceleration( model=model, @@ -310,20 +319,31 @@ def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: ), ) ) - BW_ν = data.generalized_velocity() + + M = js.model.free_floating_mass_matrix(model=model, data=data) + + Jl_WC = jnp.vstack( + jax.vmap(lambda J, height: J * (height < 0))( + js.contact.jacobian(model=model, data=data)[:, :3, :], δ + ) + ) + J̇_WC = jnp.vstack( jax.vmap(lambda J̇, height: J̇ * (height < 0))( js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ ), ) - a_ref, R, K, D = self._regularizers( - model=model, - penetration=δ, - velocity=velocity, - parameters=data.contacts_params, - ) + # Compute the regularization terms. + a_ref, R, K, D = self._regularizers( + model=model, + penetration=δ, + velocity=velocity, + parameters=data.contacts_params, + ) + # Compute the Delassus matrix and the free mixed linear acceleration of + # the collidable points. G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0] CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν @@ -331,26 +351,40 @@ def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: A = G + R b = CW_al_free_WC - a_ref + # Create the objective function to minimize as a lambda computing the cost + # from the optimized variables x. objective = lambda x, A, b: jnp.sum(jnp.square(A @ x + b)) + # ======================================== + # Helper function to run the L-BFGS solver + # ======================================== + def run_optimization( - init_params: jtp.Array, + init_params: jtp.Vector, fun: Callable, - opt: optax.GradientTransformation, - maxiter: jtp.Int, - tol: jtp.Float, - **kwargs, - ): + opt: optax.GradientTransformationExtraArgs, + maxiter: int, + tol: float, + ) -> tuple[jtp.Vector, optax.OptState]: + + # Get the function to compute the loss and the gradient w.r.t. its inputs. value_and_grad_fn = optax.value_and_grad_from_state(fun) - def step(carry): + # Initialize the carry of the following loop. + OptimizationCarry = tuple[jtp.Vector, optax.OptState] + init_carry: OptimizationCarry = (init_params, opt.init(params=init_params)) + + def step(carry: OptimizationCarry) -> OptimizationCarry: + params, state = carry + value, grad = value_and_grad_fn( params, state=state, A=A, b=b, ) + updates, state = opt.update( updates=grad, state=state, @@ -361,22 +395,32 @@ def step(carry): A=A, b=b, ) + params = optax.apply_updates(params, updates) + return params, state - def continuing_criterion(carry): + def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: + _, state = carry + iter_num = optax.tree_utils.tree_get(state, "count") grad = optax.tree_utils.tree_get(state, "grad") err = optax.tree_utils.tree_l2_norm(grad) + return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol)) - init_carry = (init_params, opt.init(init_params)) final_params, final_state = jax.lax.while_loop( continuing_criterion, step, init_carry ) + return final_params, final_state + # ====================================== + # Compute the contact forces with L-BFGS + # ====================================== + + # Initialize the optimized forces with a linear Hunt/Crossley model. init_params = ( K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ) + D[:, jnp.newaxis] * velocity @@ -391,28 +435,30 @@ def continuing_criterion(carry): maxiter = solver_options.pop("maxiter") # Compute the 3D linear force in C[W] frame. - CW_f_Ci, _ = run_optimization( + solution, _ = run_optimization( init_params=init_params, - A=A, - b=b, - maxiter=maxiter, - opt=optax.lbfgs(**solver_options), fun=objective, + opt=optax.lbfgs(**solver_options), tol=tol, + maxiter=maxiter, ) - CW_f_Ci = CW_f_Ci.reshape((-1, 3)) - - def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array: - W_Xf_CW = Adjoint.from_transform( - W_H_C.at[0:3, 0:3].set(jnp.eye(3)), - inverse=True, - ).T - return W_Xf_CW @ jnp.hstack([CW_fl, jnp.zeros(3)]) - - W_f_C = jax.vmap(mixed_to_inertial)(W_H_C, CW_f_Ci) + # Reshape the optimized solution to be a matrix of 3D contact forces. + CW_fl_C = solution.reshape(-1, 3) + + # Convert the contact forces from mixed to inertial-fixed representation. + W_f_C = jax.vmap( + lambda CW_fl_C, W_H_C: ( + ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=jnp.zeros(6).at[0:3].set(CW_fl_C), + transform=W_H_C, + other_representation=VelRepr.Mixed, + is_force=True, + ) + ), + )(CW_fl_C, W_H_C) - return W_f_C, (None,) + return W_f_C, () @staticmethod def _regularizers( @@ -434,13 +480,28 @@ def _regularizers( A tuple containing the reference acceleration, the regularization matrix, the stiffness, and the damping. """ - Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ, *_ = jax_dataclasses.astuple( - parameters + # Extract the parameters of the contact model. + Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ = ( + getattr(parameters, field) + for field in ( + "time_constant", + "damping_coefficient", + "d_min", + "d_max", + "width", + "midpoint", + "power", + "stiffness", + "damping", + "mu", + ) ) - def _imp_aref( - penetration: jtp.Array, - velocity: jtp.Array, + # Compute the 6D inertia matrices of all links. + M_L = js.model.link_spatial_inertia_matrices(model=model) + + def imp_aref( + penetration: jtp.Array, velocity: jtp.Array ) -> tuple[jtp.Array, jtp.Array]: """ Calculates impedance and offset acceleration in constraint frame. @@ -475,7 +536,7 @@ def _imp_aref( return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f) - def _compute_row( + def compute_row( *, link_idx: jtp.Float, penetration: jtp.Array, @@ -483,7 +544,7 @@ def _compute_row( ) -> tuple[jtp.Array, jtp.Array]: # Compute the reference acceleration. - ξ, a_ref, K, D = _imp_aref( + ξ, a_ref, K, D = imp_aref( penetration=penetration, velocity=velocity, ) @@ -497,12 +558,10 @@ def _compute_row( return jax.tree.map(lambda x: x * (penetration < 0), (a_ref, R, K, D)) - M_L = js.model.link_spatial_inertia_matrices(model=model) - a_ref, R, K, D = jax.tree.map( - jnp.concatenate, - ( - *jax.vmap(_compute_row)( + f=jnp.concatenate, + tree=( + *jax.vmap(compute_row)( link_idx=jnp.array( model.kin_dyn_parameters.contact_parameters.body ), @@ -511,4 +570,5 @@ def _compute_row( ), ), ) + return a_ref, jnp.diag(R), K, D From 10204ea18c7ab83d63ee759089cb98bd10de5b0a Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 9 Oct 2024 14:46:41 +0200 Subject: [PATCH 06/13] Add new ContactParamsTypes alias --- src/jaxsim/api/contact.py | 7 +------ src/jaxsim/rbda/contacts/__init__.py | 7 +++++++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index b736fc0e9..c0b9d4026 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -312,12 +312,7 @@ def estimate_good_soft_contacts_parameters( damping_ratio: jtp.FloatLike = 1.0, max_penetration: jtp.FloatLike | None = None, **kwargs, -) -> ( - jaxsim.rbda.contacts.RelaxedRigidContactsParams - | jaxsim.rbda.contacts.RigidContactsParams - | jaxsim.rbda.contacts.SoftContactsParams - | jaxsim.rbda.contacts.ViscoElasticContactsParams -): +) -> jaxsim.rbda.contacts.ContactParamsTypes: """ Estimate good parameters for soft-like contact models. diff --git a/src/jaxsim/rbda/contacts/__init__.py b/src/jaxsim/rbda/contacts/__init__.py index 71bd1647d..06646f14d 100644 --- a/src/jaxsim/rbda/contacts/__init__.py +++ b/src/jaxsim/rbda/contacts/__init__.py @@ -4,3 +4,10 @@ from .rigid import RigidContacts, RigidContactsParams from .soft import SoftContacts, SoftContactsParams from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams + +ContactParamsTypes = ( + SoftContactsParams + | RigidContactsParams + | RelaxedRigidContactsParams + | ViscoElasticContactsParams +) From 4e450871a98997585aff92e3f4cff60bc01269c7 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 10 Oct 2024 18:24:24 +0200 Subject: [PATCH 07/13] Add new jaxsim.api.contact.estimate_good_contact_parameters function --- src/jaxsim/api/contact.py | 96 +++++++++++++++++++++++++++++++++++---- src/jaxsim/api/data.py | 3 +- 2 files changed, 90 insertions(+), 9 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index c0b9d4026..4efa0b08e 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -9,7 +9,6 @@ import jaxsim.exceptions import jaxsim.terrain import jaxsim.typing as jtp -from jaxsim import logging from jaxsim.math import Adjoint, Cross, Transform from jaxsim.rbda import contacts @@ -303,6 +302,58 @@ def in_contact( return links_in_contact +def estimate_good_contact_parameters( + model: js.model.JaxSimModel, + *, + standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity, + static_friction_coefficient: jtp.FloatLike = 0.5, + number_of_active_collidable_points_steady_state: jtp.IntLike = 1, + damping_ratio: jtp.FloatLike = 1.0, + max_penetration: jtp.FloatLike | None = None, + **kwargs, +) -> jaxsim.rbda.contacts.ContactParamsTypes: + """ + Estimate good contact parameters. + + Args: + model: The model to consider. + standard_gravity: The standard gravity constant. + static_friction_coefficient: The static friction coefficient. + number_of_active_collidable_points_steady_state: + The number of active collidable points in steady state supporting + the weight of the robot. + damping_ratio: The damping ratio. + max_penetration: + The maximum penetration allowed in steady state when the robot is + supported by the configured number of active collidable points. + kwargs: + Additional model-specific parameters passed to the builder method of + the parameters class. + + Returns: + The estimated good soft contacts parameters. + + Note: + This is primarily a convenience function for soft-like contact models. + However, it provides with some good default parameters also for the other ones. + + Note: + This method provides a good set of contacts parameters. + The user is encouraged to fine-tune the parameters based on the + specific application. + """ + + return estimate_good_soft_contacts_parameters( + model=model, + standard_gravity=standard_gravity, + static_friction_coefficient=static_friction_coefficient, + number_of_active_collidable_points_steady_state=number_of_active_collidable_points_steady_state, + damping_ratio=damping_ratio, + max_penetration=max_penetration, + **kwargs, + ) + + def estimate_good_soft_contacts_parameters( model: js.model.JaxSimModel, *, @@ -359,6 +410,7 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: max_δ = ( max_penetration if max_penetration is not None + # Consider as default a 0.5% of the model height. else 0.005 * estimate_model_height(model=model) ) @@ -376,8 +428,11 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: max_penetration=max_δ, number_of_active_collidable_points_steady_state=nc, damping_ratio=damping_ratio, - p=model.contact_model.parameters.p, - q=model.contact_model.parameters.q, + **dict( + p=model.contact_model.parameters.p, + q=model.contact_model.parameters.q, + ) + | kwargs, ) case contacts.ViscoElasticContacts(): @@ -391,15 +446,40 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: max_penetration=max_δ, number_of_active_collidable_points_steady_state=nc, damping_ratio=damping_ratio, - p=model.contact_model.parameters.p, - q=model.contact_model.parameters.q, - **kwargs, + **dict( + p=model.contact_model.parameters.p, + q=model.contact_model.parameters.q, + ) + | kwargs, ) ) + case contacts.RigidContacts(): + assert isinstance(model.contact_model, contacts.RigidContacts) + + # Disable Baumgarte stabilization by default since it does not play + # well with the forward Euler integrator. + K = kwargs.get("K", 0.0) + + parameters = contacts.RigidContactsParams.build( + mu=static_friction_coefficient, + **dict( + K=K, + D=2 * jnp.sqrt(K), + ) + | kwargs, + ) + + case contacts.RelaxedRigidContacts(): + assert isinstance(model.contact_model, contacts.RelaxedRigidContacts) + + parameters = contacts.RelaxedRigidContactsParams.build( + mu=static_friction_coefficient, + **kwargs, + ) + case _: - logging.warning("The active contact model is not soft-like, no-op.") - parameters = model.contact_model.parameters + raise ValueError(f"Invalid contact model: {model.contact_model}") return parameters diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index f9d141c0a..e50b95bca 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -224,7 +224,8 @@ def build( jaxsim.rbda.contacts.SoftContacts | jaxsim.rbda.contacts.ViscoElasticContacts, ): - contacts_params = js.contact.estimate_good_soft_contacts_parameters( + + contacts_params = js.contact.estimate_good_contact_parameters( model=model, standard_gravity=standard_gravity ) From 1da329ae3e5f37f70ecbede7402c88d4f414e5eb Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 11 Oct 2024 10:55:26 +0200 Subject: [PATCH 08/13] Add new tests for all contact models --- tests/test_simulations.py | 240 +++++++++++++++++++++++++++++++++++++- 1 file changed, 238 insertions(+), 2 deletions(-) diff --git a/tests/test_simulations.py b/tests/test_simulations.py index bc928a957..c513ae1c2 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -1,3 +1,5 @@ +import functools + import jax import jax.numpy as jnp import pytest @@ -5,8 +7,8 @@ import jaxsim.api as js import jaxsim.integrators import jaxsim.rbda +import jaxsim.typing as jtp from jaxsim import VelRepr -from jaxsim.utils import Mutability def test_box_with_external_forces( @@ -102,7 +104,7 @@ def test_box_with_zero_gravity( model = jaxsim_model_box # Move the terrain (almost) infinitely far away from the box. - with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + with model.editable(validate=False) as model: model.terrain = jaxsim.terrain.FlatTerrain.build(height=-1e9) # Split the PRNG key. @@ -186,3 +188,237 @@ def test_box_with_zero_gravity( + 0.5 * LW_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2, abs=1e-3, ) + + +def run_simulation( + model: js.model.JaxSimModel, + data_t0: js.data.JaxSimModelData, + dt: jtp.FloatLike, + tf: jtp.FloatLike, +) -> js.data.JaxSimModelData: + + @functools.cache + def get_integrator() -> tuple[jaxsim.integrators.Integrator, dict[str, jtp.PyTree]]: + + # Create the integrator. + integrator = jaxsim.integrators.fixed_step.Heun2.build( + fsal_enabled_if_supported=False, + dynamics=js.ode.wrap_system_dynamics_for_integration( + model=model, + data=data_t0, + system_dynamics=js.ode.system_dynamics, + ), + ) + + # Initialize the integrator state. + integrator_state_t0 = integrator.init(x0=data_t0.state, t0=0.0, dt=dt) + + return integrator, integrator_state_t0 + + # Initialize the integration horizon. + T_ns = jnp.arange(start=0.0, stop=int(tf * 1e9), step=int(dt * 1e9)).astype(int) + + # Initialize the simulation data. + integrator = None + integrator_state = None + data = data_t0.copy() + + for t_ns in T_ns: + + match model.contact_model: + + case jaxsim.rbda.contacts.ViscoElasticContacts(): + + data, _ = jaxsim.rbda.contacts.visco_elastic.step( + model=model, + data=data, + dt=dt, + ) + + case _: + + integrator, integrator_state = ( + get_integrator() if t_ns == 0 else (integrator, integrator_state) + ) + + data, integrator_state = js.model.step( + model=model, + data=data, + dt=dt, + integrator=integrator, + integrator_state=integrator_state, + ) + + return data + + +def test_simulation_with_soft_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + with model.editable(validate=False) as model: + + model.contact_model = jaxsim.rbda.contacts.SoftContacts.build( + terrain=model.terrain, + ) + + # Initialize the maximum penetration of each collidable point at steady state. + max_penetration = 0.001 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + contacts_params=js.contact.estimate_good_contact_parameters( + model=model, + number_of_active_collidable_points_steady_state=4, + static_friction_coefficient=1.0, + damping_ratio=1.0, + max_penetration=0.001, + ), + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) + + assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2]) + assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2) + + +def test_simulation_with_visco_elastic_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + with model.editable(validate=False) as model: + + model.contact_model = jaxsim.rbda.contacts.ViscoElasticContacts.build( + terrain=model.terrain, + ) + + # Initialize the maximum penetration of each collidable point at steady state. + max_penetration = 0.001 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + contacts_params=js.contact.estimate_good_contact_parameters( + model=model, + number_of_active_collidable_points_steady_state=4, + static_friction_coefficient=1.0, + damping_ratio=1.0, + max_penetration=0.001, + ), + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) + + assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2]) + assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2) + + +def test_simulation_with_rigid_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + with model.editable(validate=False) as model: + + model.contact_model = jaxsim.rbda.contacts.RigidContacts.build( + terrain=model.terrain, + ) + + # Initialize the maximum penetration of each collidable point at steady state. + # This model is rigid, so we expect (almost) no penetration. + max_penetration = 0.000 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + # In order to achieve almost no penetration, we need to use a fairly large + # Baumgarte stabilization term. + contacts_params=js.contact.estimate_good_contact_parameters( + model=model, + K=100_000, + ), + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) + + assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2]) + assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2) + + +def test_simulation_with_relaxed_rigid_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + with model.editable(validate=False) as model: + + model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts.build( + terrain=model.terrain, + ) + + # Initialize the maximum penetration of each collidable point at steady state. + # This model is quasi-rigid, so we expect (almost) no penetration. + max_penetration = 0.000 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + # In order to achieve almost no penetration, we need to use a fairly large + # Baumgarte stabilization term. + contacts_params=js.contact.estimate_good_contact_parameters( + model=model, + time_constant=0.001, + ), + ) + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) + + # With this contact model, we need to slightly adjust the tolerance on xy. + assert data_tf.base_position()[0:2] == pytest.approx( + data_t0.base_position()[0:2], abs=0.000_010 + ) + assert data_tf.base_position()[2] + max_penetration == pytest.approx( + box_height / 2, abs=0.000_100 + ) From bc39bcc6652d9517bc1c9c07f421f2ad0bfb8e99 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 9 Oct 2024 14:54:40 +0200 Subject: [PATCH 09/13] Minor changes --- src/jaxsim/api/contact.py | 5 ++--- src/jaxsim/api/data.py | 2 +- src/jaxsim/api/model.py | 2 ++ 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 4efa0b08e..49c8ccda8 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -35,11 +35,10 @@ def collidable_point_kinematics( the linear component of the mixed 6D frame velocity. """ - from jaxsim.rbda import collidable_points - # Switch to inertial-fixed since the RBDAs expect velocities in this representation. with data.switch_velocity_representation(VelRepr.Inertial): - W_p_Ci, W_ṗ_Ci = collidable_points.collidable_points_pos_vel( + + W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( model=model, base_position=data.base_position(), base_quaternion=data.base_orientation(dcm=False), diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index e50b95bca..36ec1eb16 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -34,7 +34,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): state: ODEState - gravity: jtp.Array + gravity: jtp.Vector contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 3621c0c59..0c9e63f62 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -40,6 +40,8 @@ class JaxSimModel(JaxsimDataclass): default_factory=jaxsim.terrain.FlatTerrain.build, repr=False ) + # Note that this is the default contact model. + # Its parameters, if any, are then overridden from those stored in JaxSimModelData. contact_model: jaxsim.rbda.contacts.ContactModel | None = dataclasses.field( default=None, repr=False ) From c0ce67479d16c1b60a1f0f5297a6fc965b1bef9c Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 9 Oct 2024 14:59:04 +0200 Subject: [PATCH 10/13] Update signature of ContactModel.compute_contact_forces --- src/jaxsim/rbda/contacts/common.py | 7 ++++--- src/jaxsim/rbda/contacts/relaxed_rigid.py | 4 ++-- src/jaxsim/rbda/contacts/rigid.py | 4 ++-- src/jaxsim/rbda/contacts/soft.py | 13 ++++++++++++- src/jaxsim/rbda/contacts/visco_elastic.py | 2 +- 5 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 5a2a6648a..e6892704a 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -131,7 +131,7 @@ def compute_contact_forces( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, **kwargs, - ) -> tuple[jtp.Vector, tuple[Any, ...]]: + ) -> tuple[jtp.Matrix, tuple[Any, ...]]: """ Compute the contact forces. @@ -140,8 +140,9 @@ def compute_contact_forces( data: The data of the considered model. Returns: - A tuple containing as first element the computed 6D contact force applied to the contact point and expressed in the world frame, - and as second element a tuple of optional additional information. + A tuple containing as first element the computed 6D contact force applied to + the contact points and expressed in the world frame, and as second element + a tuple of optional additional information. """ pass diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index bc2e46ecb..b991b7bbd 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -238,7 +238,7 @@ def compute_contact_forces( *, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, - ) -> tuple[jtp.Vector, tuple[Any, ...]]: + ) -> tuple[jtp.Matrix, tuple]: """ Compute the contact forces. @@ -252,7 +252,7 @@ def compute_contact_forces( Optional `(n_joints,)` vector of joint forces. Returns: - A tuple containing the contact forces. + A tuple containing as first element the computed contact forces. """ # Initialize the model and data this contact model is operating on. diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 8e40c1c1a..bdbbb2937 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -242,7 +242,7 @@ def compute_contact_forces( *, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, - ) -> tuple[jtp.Vector, tuple[Any, ...]]: + ) -> tuple[jtp.Matrix, tuple]: """ Compute the contact forces. @@ -256,7 +256,7 @@ def compute_contact_forces( Optional `(n_joints,)` vector of joint forces. Returns: - A tuple containing the contact forces. + A tuple containing as first element the computed contact forces. """ # Initialize the model and data this contact model is operating on. diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 8b626bc00..4af693527 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -423,7 +423,18 @@ def compute_contact_forces( self, model: js.model.JaxSimModel, data: js.data.JaxSimModelData, - ) -> tuple[jtp.Vector, tuple[jtp.Vector]]: + ) -> tuple[jtp.Matrix, tuple[jtp.Matrix]]: + """ + Compute the contact forces. + + Args: + model: The model to consider. + data: The data of the considered model. + + Returns: + A tuple containing as first element the computed contact forces, and as + second element the derivative of the material deformation. + """ # Initialize the model and data this contact model is operating on. # This will raise an exception if either the contact model or the diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index 4ba0233e5..99da7f9bd 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -266,7 +266,7 @@ def compute_contact_forces( dt: jtp.FloatLike | None = None, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, - ) -> tuple[jtp.Vector, tuple[Any, ...]]: + ) -> tuple[jtp.Matrix, tuple[jtp.Matrix, jtp.Matrix]]: """ Compute the contact forces. From 6c91b3bd1dd30b07aa0be84110e1a816d8ea327a Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 9 Oct 2024 15:02:42 +0200 Subject: [PATCH 11/13] Fix max_squarings argument of visco-elastic contact model --- src/jaxsim/rbda/contacts/visco_elastic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index 99da7f9bd..25019fcfc 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -195,7 +195,7 @@ class ViscoElasticContacts(common.ContactModel): default_factory=FlatTerrain ) - max_squarings: jax_dataclasses.Static[int] = 25 + max_squarings: jax_dataclasses.Static[int] = dataclasses.field(default=25) @classmethod def build( @@ -239,7 +239,7 @@ def build( parameters=parameters, terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), max_squarings=int( - max_squarings or cls.__dataclass_fields__["max_squarings"].default() + max_squarings or cls.__dataclass_fields__["max_squarings"].default ), ) From 70aec8983a305f53029538ab9eeddc772aa139e7 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 17 Oct 2024 11:11:01 +0200 Subject: [PATCH 12/13] Apply suggestions from code review Co-authored-by: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> --- src/jaxsim/rbda/contacts/relaxed_rigid.py | 4 ++-- tests/test_simulations.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index b991b7bbd..25fb35b5e 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -265,13 +265,13 @@ def compute_contact_forces( jnp.array(link_forces, dtype=float).squeeze() if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) - ).astype(float) + ) joint_force_references = jnp.atleast_1d( jnp.array(joint_force_references, dtype=float).squeeze() if joint_force_references is not None else jnp.zeros(model.number_of_joints()) - ).astype(float) + ) references = js.references.JaxSimModelReferences.build( model=model, diff --git a/tests/test_simulations.py b/tests/test_simulations.py index c513ae1c2..f93edcf0b 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -402,11 +402,11 @@ def test_simulation_with_relaxed_rigid_contacts( model=model, base_position=jnp.array([0.0, 0.0, box_height * 2]), velocity_representation=VelRepr.Inertial, - # In order to achieve almost no penetration, we need to use a fairly large - # Baumgarte stabilization term. + # For this contact model, the following method is practically no-op. + # Let's leave it there for consistency and to make sure that nothing + # gets broken if it is updated in the future. contacts_params=js.contact.estimate_good_contact_parameters( model=model, - time_constant=0.001, ), ) # =========================================== @@ -415,7 +415,7 @@ def test_simulation_with_relaxed_rigid_contacts( data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) - # With this contact model, we need to slightly adjust the tolerance on xy. + # With this contact model, we need to slightly increase the tolerances. assert data_tf.base_position()[0:2] == pytest.approx( data_t0.base_position()[0:2], abs=0.000_010 ) From 1e5d13f7c3db6c2b4bf8e466d74a6f2f19430ce5 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 10 Oct 2024 17:59:00 +0200 Subject: [PATCH 13/13] Deprecate jaxsim.api.contacts.estimate_good_soft_contacts_parameters Co-authored-by: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> --- src/jaxsim/api/contact.py | 57 ++++++++------------------------------- 1 file changed, 11 insertions(+), 46 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 49c8ccda8..95655be0f 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -9,6 +9,7 @@ import jaxsim.exceptions import jaxsim.terrain import jaxsim.typing as jtp +from jaxsim import logging from jaxsim.math import Adjoint, Cross, Transform from jaxsim.rbda import contacts @@ -301,6 +302,15 @@ def in_contact( return links_in_contact +def estimate_good_soft_contacts_parameters( + *args, **kwargs +) -> jaxsim.rbda.contacts.ContactParamsTypes: + + msg = "This method is deprecated, please use `{}`." + logging.warning(msg.format(estimate_good_contact_parameters.__name__)) + return estimate_good_contact_parameters(*args, **kwargs) + + def estimate_good_contact_parameters( model: js.model.JaxSimModel, *, @@ -330,7 +340,7 @@ def estimate_good_contact_parameters( the parameters class. Returns: - The estimated good soft contacts parameters. + The estimated good contacts parameters. Note: This is primarily a convenience function for soft-like contact models. @@ -342,51 +352,6 @@ def estimate_good_contact_parameters( specific application. """ - return estimate_good_soft_contacts_parameters( - model=model, - standard_gravity=standard_gravity, - static_friction_coefficient=static_friction_coefficient, - number_of_active_collidable_points_steady_state=number_of_active_collidable_points_steady_state, - damping_ratio=damping_ratio, - max_penetration=max_penetration, - **kwargs, - ) - - -def estimate_good_soft_contacts_parameters( - model: js.model.JaxSimModel, - *, - standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity, - static_friction_coefficient: jtp.FloatLike = 0.5, - number_of_active_collidable_points_steady_state: jtp.IntLike = 1, - damping_ratio: jtp.FloatLike = 1.0, - max_penetration: jtp.FloatLike | None = None, - **kwargs, -) -> jaxsim.rbda.contacts.ContactParamsTypes: - """ - Estimate good parameters for soft-like contact models. - - Args: - model: The model to consider. - standard_gravity: The standard gravity constant. - static_friction_coefficient: The static friction coefficient. - number_of_active_collidable_points_steady_state: - The number of active collidable points in steady state supporting - the weight of the robot. - damping_ratio: The damping ratio. - max_penetration: - The maximum penetration allowed in steady state when the robot is - supported by the configured number of active collidable points. - - Returns: - The estimated good soft contacts parameters. - - Note: - This method provides a good starting point for the soft contacts parameters. - The user is encouraged to fine-tune the parameters based on the - specific application. - """ - def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: """ Displacement between the CoM and the lowest collidable point using zero