From 83fcb94b87bcb41937dba5cc72e47cfd66cc293b Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 15 Oct 2024 14:04:33 +0200 Subject: [PATCH 1/6] Update cartpole model and add missing script to generate it --- examples/assets/build_cartpole_urdf.py | 202 +++++++++++++++++++++++++ examples/assets/cartpole.urdf | 45 +++--- 2 files changed, 223 insertions(+), 24 deletions(-) create mode 100644 examples/assets/build_cartpole_urdf.py diff --git a/examples/assets/build_cartpole_urdf.py b/examples/assets/build_cartpole_urdf.py new file mode 100644 index 000000000..2563093c9 --- /dev/null +++ b/examples/assets/build_cartpole_urdf.py @@ -0,0 +1,202 @@ +import os + +if "ROD_LOGGING_LEVEL" not in os.environ: + os.environ["ROD_LOGGING_LEVEL"] = "WARNING" + +import numpy as np +import rod.kinematics.tree_transforms +from rod.builder import primitives + +if __name__ == "__main__": + + # ================ + # Model parameters + # ================ + + # Rail parameters. + rail_height = 1.2 + rail_length = 5.0 + rail_radius = 0.005 + rail_mass = 5.0 + + # Cart parameters. + cart_mass = 1.0 + cart_size = (0.1, 0.2, 0.05) + + # Pole parameters. + pole_mass = 0.5 + pole_length = 1.0 + pole_radius = 0.005 + + # ======================== + # Create the link builders + # ======================== + + rail_builder = primitives.CylinderBuilder( + name="rail", + mass=rail_mass, + radius=rail_radius, + length=rail_length, + ) + + cart_builder = primitives.BoxBuilder( + name="cart", + mass=cart_mass, + x=cart_size[0], + y=cart_size[1], + z=cart_size[2], + ) + + pole_builder = primitives.CylinderBuilder( + name="pole", + mass=pole_mass, + radius=pole_radius, + length=pole_length, + ) + + # ================= + # Create the joints + # ================= + + world_to_rail = rod.Joint( + name="world_to_rail", + type="fixed", + parent="world", + child=rail_builder.name, + pose=primitives.PrimitiveBuilder.build_pose( + relative_to="world", + ), + ) + + linear = rod.Joint( + name="linear", + type="prismatic", + parent=rail_builder.name, + child=cart_builder.name, + pose=primitives.PrimitiveBuilder.build_pose( + relative_to=rail_builder.name, + pos=np.array([0, 0, rail_height]), + ), + axis=rod.Axis( + xyz=rod.Xyz(xyz=[0, 1, 0]), + limit=rod.Limit( + upper=(rail_length / 2 - cart_size[1] / 2), + lower=-(rail_length / 2 - cart_size[1] / 2), + effort=500.0, + velocity=10.0, + ), + ), + ) + + pivot = rod.Joint( + name="pivot", + type="continuous", + parent=cart_builder.name, + child=pole_builder.name, + pose=primitives.PrimitiveBuilder.build_pose( + relative_to=cart_builder.name, + ), + axis=rod.Axis( + xyz=rod.Xyz(xyz=[1, 0, 0]), + limit=rod.Limit(), + ), + ) + + # ================ + # Create the links + # ================ + + rail_elements_pose = primitives.PrimitiveBuilder.build_pose( + pos=np.array([0, 0, rail_height]), + rpy=np.array([np.pi / 2, 0, 0]), + ) + + rail = ( + rail_builder.build_link( + name=rail_builder.name, + pose=primitives.PrimitiveBuilder.build_pose( + relative_to=world_to_rail.name, + ), + ) + .add_inertial(pose=rail_elements_pose) + .add_visual(pose=rail_elements_pose) + .add_collision(pose=rail_elements_pose) + .build() + ) + + cart = ( + cart_builder.build_link( + name=cart_builder.name, + pose=primitives.PrimitiveBuilder.build_pose(relative_to=linear.name), + ) + .add_inertial() + .add_visual() + .add_collision() + .build() + ) + + pole_elements_pose = primitives.PrimitiveBuilder.build_pose( + pos=np.array([0, 0, pole_length / 2]), + ) + + pole = ( + pole_builder.build_link( + name=pole_builder.name, + pose=primitives.PrimitiveBuilder.build_pose( + relative_to=pivot.name, + ), + ) + .add_inertial(pose=pole_elements_pose) + .add_visual(pose=pole_elements_pose) + .add_collision(pose=pole_elements_pose) + .build() + ) + + # =========== + # Build model + # =========== + + # Create ROD in-memory model. + model = rod.Model( + name="cartpole", + canonical_link=rail.name, + link=[ + rail, + cart, + pole, + ], + joint=[ + world_to_rail, + linear, + pivot, + ], + ) + + # Update the pose elements to be closer to those expected in URDF. + model.switch_frame_convention( + frame_convention=rod.FrameConvention.Urdf, explicit_frames=True + ) + + # ============== + # Get SDF string + # ============== + + # Create the top-level SDF object. + sdf = rod.Sdf(version="1.10", model=model) + + # Generate the SDF string. + # sdf_string = sdf.serialize(pretty=True, validate=True) + + # =============== + # Get URDF string + # =============== + + import rod.urdf.exporter + + # Convert the SDF to URDF. + urdf_string = rod.urdf.exporter.UrdfExporter( + pretty=True, indent=" " + ).to_urdf_string(sdf=sdf) + + # Print the URDF string. + print(urdf_string) diff --git a/examples/assets/cartpole.urdf b/examples/assets/cartpole.urdf index 8b8c9dc93..5bea4829a 100644 --- a/examples/assets/cartpole.urdf +++ b/examples/assets/cartpole.urdf @@ -1,5 +1,4 @@ - @@ -22,29 +21,29 @@ - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + - - + + @@ -71,13 +70,11 @@ - + - - + From 2616b910a56cfe7fbfe478a315d76f53a590dc18 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 17 Oct 2024 13:33:45 +0200 Subject: [PATCH 2/6] Update parallel computing example --- examples/Parallel_computing.ipynb | 317 ---------------- examples/jaxsim_as_physics_engine.ipynb | 481 ++++++++++++++++++++++++ 2 files changed, 481 insertions(+), 317 deletions(-) delete mode 100644 examples/Parallel_computing.ipynb create mode 100644 examples/jaxsim_as_physics_engine.ipynb diff --git a/examples/Parallel_computing.ipynb b/examples/Parallel_computing.ipynb deleted file mode 100644 index bec98076c..000000000 --- a/examples/Parallel_computing.ipynb +++ /dev/null @@ -1,317 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# `JAXsim` Showcase: Parallel Simulation of a free-falling body" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First, we install the necessary packages and import them." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# @title Imports and setup\n", - "import sys\n", - "\n", - "IS_COLAB = \"google.colab\" in sys.modules\n", - "\n", - "# Install JAX and Gazebo\n", - "if IS_COLAB:\n", - " !{sys.executable} -m pip install -qU jaxsim\n", - " !apt install -qq lsb-release wget gnupg\n", - " !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n", - " !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n", - " !apt -qq update\n", - " !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n", - "\n", - "# Set environment variable to avoid GPU out of memory errors\n", - "%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n", - "\n", - "import time\n", - "\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import rod\n", - "from rod.builder.primitives import SphereBuilder\n", - "\n", - "from jaxsim import logging\n", - "\n", - "logging.set_logging_level(logging.LoggingLevel.INFO)\n", - "logging.info(f\"Running on {jax.devices()}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We will use a simple sphere model to simulate a free-falling body. The spheres set will be composed of 9 spheres, each with a different position. The spheres will be simulated in parallel, and the simulation will be run for 3000 steps corresponding to 3 seconds of simulation.\n", - "\n", - "**Note**: Parallel simulations are independent of each other, the different position is imposed only to show the parallelization visually." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# @title Create a sphere model\n", - "model_sdf_string = rod.Sdf(\n", - " version=\"1.7\",\n", - " model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n", - " .build_model()\n", - " .add_link()\n", - " .add_inertial()\n", - " .add_visual()\n", - " .add_collision()\n", - " .build(),\n", - ").serialize(pretty=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n", - "\n", - "- `model`: an object that defines the dynamics of the system.\n", - "- `data`: an object that contains the state of the system.\n", - "- `integrator`: an object that defines the integration method.\n", - "- `integrator_state`: an object that contains the state of the integrator." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import jaxsim.api as js\n", - "from jaxsim import integrators\n", - "\n", - "dt = 0.001\n", - "integration_time = 1.5 # seconds\n", - "\n", - "model = js.model.JaxSimModel.build_from_model_description(\n", - " model_description=model_sdf_string\n", - ")\n", - "data = js.data.JaxSimModelData.build(model=model)\n", - "integrator = integrators.fixed_step.RungeKutta4SO3.build(\n", - " dynamics=js.ode.wrap_system_dynamics_for_integration(\n", - " model=model,\n", - " data=data,\n", - " system_dynamics=js.ode.system_dynamics,\n", - " ),\n", - ")\n", - "integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "It is possible to automatically choose a good set of parameters for the terrain. \n", - "\n", - "By default, in JaxSim a sphere primitive has 250 collision points. This can be modified by setting the `JAXSIM_COLLISION_SPHERE_POINTS` environment variable.\n", - "\n", - "Given that at its steady-state the sphere will act on two or three points, we can estimate the ground parameters by explicitly setting the number of active points to these values. \n", - "\n", - "Eventually, you can specify the maximum penetration depth of the sphere into the terrain by passing `max_penetraion` to the `estimate_good_soft_contacts_parameters` function." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data = data.replace(\n", - " contacts_params=js.contact.estimate_good_soft_contacts_parameters(\n", - " model=model,\n", - " number_of_active_collidable_points_steady_state=3,\n", - " max_penetration=None,\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's create a position vector for a 4x4 grid. Every sphere will be placed at a different height." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Primary Calculations\n", - "envs_per_row = 4 # @slider(2, 10, 1)\n", - "\n", - "env_spacing = 0.5\n", - "edge_len = env_spacing * (2 * envs_per_row - 1)\n", - "\n", - "\n", - "# Create Grid\n", - "def grid(edge_len, envs_per_row):\n", - " edge = jnp.linspace(-edge_len, edge_len, envs_per_row)\n", - " xx, yy = jnp.meshgrid(edge, edge)\n", - " zz = 0.2 + 0.1 * (\n", - " jnp.arange(envs_per_row**2) % envs_per_row\n", - " + jnp.arange(envs_per_row**2) // envs_per_row\n", - " )\n", - " zz = zz.reshape(envs_per_row, envs_per_row)\n", - " poses = jnp.stack([xx, yy, zz], axis=-1).reshape(envs_per_row**2, 3)\n", - " return poses\n", - "\n", - "\n", - "logging.info(f\"Simulating {envs_per_row**2} environments\")\n", - "poses = grid(edge_len, envs_per_row)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In order to parallelize the simulation, we first need to define a function `simulate` for a single element of the batch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Define a function to simulate a single model instance\n", - "def simulate(\n", - " data: js.data.JaxSimModelData, integrator_state: dict, pose: jnp.array\n", - ") -> tuple:\n", - " # Set the base position to the initial pose\n", - " data = data.reset_base_position(base_position=pose)\n", - "\n", - " # Create a list to store the base position over time\n", - " x_t_i = []\n", - "\n", - " for _ in range(int(integration_time // dt)):\n", - " data, integrator_state = js.model.step(\n", - " dt=dt,\n", - " model=model,\n", - " data=data,\n", - " integrator=integrator,\n", - " integrator_state=integrator_state,\n", - " joint_force_references=None,\n", - " link_forces=None,\n", - " )\n", - " x_t_i.append(data.base_position())\n", - "\n", - " return x_t_i" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We will make use of `jax.vmap` to simulate multiple models in parallel. This is a very powerful feature of JAX that allows to write code that is very similar to the single-model case, but can be executed in parallel on multiple models.\n", - "In order to do so, we need to first apply `jax.vmap` to the `simulate` function, and then call the resulting function with the batch of different poses as input.\n", - "\n", - "Note that in our case we are vectorizing over the `pose` argument of the function `simulate`, this correspond to the value assigned to the `in_axes` parameter of `jax.vmap`:\n", - "\n", - "`in_axes=(None, None, 0)` means that the first two arguments of `simulate` are not vectorized, while the third argument is vectorized over the zero-th dimension." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Define a function to simulate multiple model instances\n", - "simulate_vectorized = jax.vmap(simulate, in_axes=(None, None, 0))\n", - "\n", - "# Run and time the simulation\n", - "now = time.perf_counter()\n", - "\n", - "x_t = simulate_vectorized(data, integrator_state, poses)\n", - "\n", - "comp_time = time.perf_counter() - now\n", - "\n", - "logging.info(\n", - " f\"Running simulation with {envs_per_row**2} models took {comp_time} seconds.\"\n", - ")\n", - "logging.info(\n", - " f\"This corresponds to an RTF (Real Time Factor) of {(envs_per_row**2 * integration_time / comp_time):.2f}\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's extract the data from the simulation and plot it. We expect to see the height time series of each sphere starting from a different value." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "\n", - "plt.plot(np.arange(len(x_t)) * dt, np.array(x_t)[:, :, 2])\n", - "plt.grid(True)\n", - "plt.xlabel(\"Time [s]\")\n", - "plt.ylabel(\"Height [m]\")\n", - "plt.title(\"Trajectory of the model's base\")\n", - "plt.show()" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuClass": "premium", - "gpuType": "V100", - "private_outputs": true, - "provenance": [ - { - "file_id": "1QsuS7EJhdPEHxxAu9XwozvA7eb4ZnlAb", - "timestamp": 1701993737024 - } - ], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/jaxsim_as_physics_engine.ipynb b/examples/jaxsim_as_physics_engine.ipynb new file mode 100644 index 000000000..b67331c8b --- /dev/null +++ b/examples/jaxsim_as_physics_engine.ipynb @@ -0,0 +1,481 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "H-WgcgGQaTG7" + }, + "source": [ + "# JaxSim as a hardware-accelerated parallel physics engine\n", + "\n", + "JaxSim was originally developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\n", + "\n", + "In this notebook, you'll learn how to use the key APIs to load a simple robot model (a sphere) and simulate multiple trajectories in parallel on GPUs.\n", + "\n", + "\n", + " \"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SgOSnrSscEkt" + }, + "source": [ + "## Prepare the environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fdqvAqMDaTG9" + }, + "outputs": [], + "source": [ + "# @title Imports and setup\n", + "import sys\n", + "from IPython.display import clear_output\n", + "\n", + "IS_COLAB = \"google.colab\" in sys.modules\n", + "\n", + "# Install JAX and Gazebo\n", + "if IS_COLAB:\n", + " !{sys.executable} -m pip install --pre -qU jaxsim\n", + " !apt install -qq lsb-release wget gnupg\n", + " !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n", + " !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n", + " !apt -qq update\n", + " !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n", + "\n", + " clear_output()\n", + "\n", + "# Set environment variable to avoid GPU out of memory errors\n", + "%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n", + "\n", + "# ================\n", + "# Notebook imports\n", + "# ================\n", + "\n", + "import os\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jaxsim.api as js\n", + "import rod\n", + "from jaxsim import integrators, logging\n", + "from rod.builder.primitives import SphereBuilder\n", + "\n", + "logging.set_logging_level(logging.LoggingLevel.WARNING)\n", + "print(f\"Running on {jax.devices()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QtCCUhdpdGFH" + }, + "source": [ + "## Prepare the simulation\n", + "\n", + "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. This is done using the [`ami-iit/rod`][rod] library, which processes these formats.\n", + "\n", + "The `rod` library also allows creating in-memory models that can be serialized to SDF or URDF. We'll use this functionality to build a sphere model, which will later be used to create the JaxSim model.\n", + "\n", + "[sdformat]: http://sdformat.org/\n", + "[urdf]: http://wiki.ros.org/urdf/\n", + "[rod]: https://github.com/ami-iit/rod" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "0emoMQhCaTG_" + }, + "outputs": [], + "source": [ + "# @title Create the model description of a sphere\n", + "\n", + "# Create a SDF model.\n", + "# The builder takes care to compute the right inertia tensor for you.\n", + "rod_sdf = rod.Sdf(\n", + " version=\"1.7\",\n", + " model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n", + " .build_model()\n", + " .add_link()\n", + " .add_inertial()\n", + " .add_visual()\n", + " .add_collision()\n", + " .build(),\n", + ")\n", + "\n", + "# Rod allows to update the frames w.r.t. the poses are expressed.\n", + "rod_sdf.model.switch_frame_convention(\n", + " frame_convention=rod.FrameConvention.Urdf, explicit_frames=True\n", + ")\n", + "\n", + "# Serialize the model to a SDF string.\n", + "model_sdf_string = rod_sdf.serialize(pretty=True)\n", + "print(model_sdf_string)\n", + "\n", + "# JaxSim currently only supports collisions between points attached to bodies\n", + "# and a ground surface modeled as a heighmap sampled from a smooth function.\n", + "# While this approach is universal as it applies to generic meshes, the number\n", + "# of considered points greatly affects the performance. Spheres, by default,\n", + "# are discretized with 250 points. It's too much for this simple example.\n", + "# This number can be decreased with the following environment variable.\n", + "os.environ[\"JAXSIM_COLLISION_SPHERE_POINTS\"] = \"50\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NqjuZKvOaTG_" + }, + "source": [ + "### Create the model and its data\n", + "\n", + "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n", + "\n", + "- `model`: an object that defines the dynamics of the system.\n", + "- `data`: an object that contains the state of the system.\n", + "- `integrator`: an object that defines the integration method.\n", + "- `integrator_state`: an object that contains the state of the integrator." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "etQ577cFaTHA" + }, + "outputs": [], + "source": [ + "# Create the JaxSim model.\n", + "# This is shared among all the parallel instances.\n", + "model = js.model.JaxSimModel.build_from_model_description(\n", + " model_description=model_sdf_string, time_step=0.001\n", + ")\n", + "\n", + "# Create the data of a single model.\n", + "# We will create a vectorized instance later.\n", + "data_single = js.data.JaxSimModelData.zero(model=model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FJF-HoWaiK9J" + }, + "source": [ + "### Select the contact model\n", + "\n", + "JaxSim offers several contact models, with the default being the non-linear Hunt/Crossley soft contact model. This model supports stick/slip transitions and fully accounts for friction cones.\n", + "\n", + "While it is faster than other models, it requires careful parameter tuning and may need a small time step $\\Delta t$, unless a variable-step integrator is used.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VAaitHRKjnwc" + }, + "outputs": [], + "source": [ + "import jaxsim\n", + "\n", + "# Operate on a copy of the model.\n", + "# When validate=True, this context manager ensures that the PyTree structure\n", + "# of the object is not altered. This is a nice feature of JaxSim to spot\n", + "# earlier user logic that might trigger unwanted JIT recompilations.\n", + "# In this case, we need to disable validation since PyTree structure might\n", + "# change if you use a contact model different from the default.\n", + "with model.editable(validate=False) as model:\n", + "\n", + " # The SoftContacts class can be replaced with a different contact model.\n", + " model.contact_model = jaxsim.rbda.contacts.SoftContacts.build(\n", + " # JaxSim provides the following helper that estimates good contact\n", + " # parameters. While they might not be optimal, usually are a good\n", + " # starting point. Users are encouraged to fine-tune them.\n", + " parameters=js.contact.estimate_good_contact_parameters(\n", + " model=model,\n", + " number_of_active_collidable_points_steady_state=4,\n", + " max_penetration=0.001,\n", + " )\n", + " )\n", + "\n", + "# Print the contact parameters.\n", + "# Note that these parameters are the nominal parameters shared among\n", + "# all parallel instances. If needed, they can be overidden in the\n", + "# vectorized data object that will be created later.\n", + "print(model.contact_model.parameters)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6REY2bq3lc_k" + }, + "source": [ + "### Select the integrator\n", + "\n", + "JaxSim offers various integrators, ranging from basic ones like `ForwardEuler` to higher-order methods like `RungeKutta4`. You can explore the available integrators in the following modules:\n", + "\n", + "- `jaxsim.integrators.fixed_step`\n", + "- `jaxsim.integrators.variable_step`\n", + "\n", + "The `*SO3` variants update the integration scheme by integrating more accurately the base orientation on the $\\text{SO}(3)$ manifold." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "o86Teq5piVGj" + }, + "outputs": [], + "source": [ + "# Create the integrator.\n", + "integrator = integrators.fixed_step.Heun2SO3.build(\n", + " dynamics=js.ode.wrap_system_dynamics_for_integration(\n", + " model=model,\n", + " data=data_single,\n", + " system_dynamics=js.ode.system_dynamics,\n", + " ),\n", + ")\n", + "\n", + "# Initialize the integrator.\n", + "integrator_state = integrator.init(\n", + " x0=data_single.state,\n", + " t0=0.0,\n", + " dt=model.time_step,\n", + ")\n", + "\n", + "# Initialize the simulated time.\n", + "T = jnp.arange(start=0, stop=1.0, step=model.time_step)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V6IeD2B3m4F0" + }, + "source": [ + "## Sample a batch of trajectories in parallel\n", + "\n", + "With the provided resources, you can step through an open-loop trajectory on a single model using `jaxsim.api.model.step`.\n", + "\n", + "In this notebook, we'll focus on running parallel steps. We'll use JAX's automatic vectorization to apply the step function to batched data.\n", + "\n", + "Note that these parallel simulations are independent — models don't interact, so there's no need to avoid initial collisions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vtEn0aIzr_2j" + }, + "outputs": [], + "source": [ + "# @title Generate batched initial data\n", + "\n", + "# Create a random JAX key.\n", + "key = jax.random.PRNGKey(seed=0)\n", + "\n", + "# Split subkeys for sampling random initial data.\n", + "batch_size = 10\n", + "key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n", + "\n", + "# Create the batched data by sampling the height from [0.5, 0.6] meters.\n", + "data_batch_t0 = jax.vmap(\n", + " lambda key: js.data.random_model_data(\n", + " model=model,\n", + " key=key,\n", + " base_pos_bounds=([0, 0, 0.3], [0, 0, 0.6]),\n", + " base_vel_lin_bounds=(0, 0),\n", + " base_vel_ang_bounds=(0, 0),\n", + " )\n", + ")(jnp.vstack(subkeys))\n", + "\n", + "print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position()[0:10])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0tQPfsl6uxHm" + }, + "outputs": [], + "source": [ + "# @title Create parallel step function\n", + "\n", + "import functools\n", + "from typing import Any\n", + "\n", + "\n", + "@jax.jit\n", + "def step_single(\n", + " model: js.model.JaxSimModel,\n", + " data: js.data.JaxSimModelData,\n", + " integrator_state: dict[str, Any],\n", + ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n", + "\n", + " # Close step over static arguments.\n", + " return js.model.step(\n", + " model=model,\n", + " data=data,\n", + " integrator=integrator,\n", + " integrator_state=integrator_state,\n", + " link_forces=None,\n", + " joint_force_references=None,\n", + " )\n", + "\n", + "\n", + "@jax.jit\n", + "@functools.partial(jax.vmap, in_axes=(None, 0, None))\n", + "def step_parallel(\n", + " model: js.model.JaxSimModel,\n", + " data: js.data.JaxSimModelData,\n", + " integrator_state: dict[str, Any],\n", + ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n", + "\n", + " return step_single(\n", + " model=model, data=data, integrator_state=integrator_state\n", + " )\n", + "\n", + "\n", + "# The first run will be slow since JAX needs to JIT-compile the functions.\n", + "_ = step_single(model, data_single, integrator_state)\n", + "_ = step_parallel(model, data_batch_t0, integrator_state)\n", + "\n", + "# Benchmark the execution of a single step.\n", + "print(\"\\nSingle simulation step:\")\n", + "%timeit step_single(model, data_single, integrator_state)\n", + "\n", + "# On hardware accelerators, there's a range of batch_size values where\n", + "# increasing the number of parallel instances doesn't affect computation time.\n", + "# This range depends on the GPU/TPU specifications.\n", + "print(f\"\\nParallel simulation steps (batch_size={batch_size} on {jax.devices()[0]}):\")\n", + "%timeit step_parallel(model, data_batch_t0, integrator_state)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VNwzT2JQ1n15" + }, + "outputs": [], + "source": [ + "# @title Run parallel simulation\n", + "\n", + "data = data_batch_t0\n", + "data_trajectory_list = []\n", + "\n", + "for _ in T:\n", + "\n", + " data, integrator_state = step_parallel(model, data, integrator_state)\n", + " data_trajectory_list.append(data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y6n720Cr3G44" + }, + "source": [ + "## Visualize trajectory" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BLPODyKr3Lyg" + }, + "outputs": [], + "source": [ + "# Convert a list of PyTrees to a batched PyTree.\n", + "# This operation is called 'tree transpose' in JAX.\n", + "data_trajectory = jax.tree.map(\n", + " lambda *leafs: jnp.stack(leafs), *data_trajectory_list\n", + ")\n", + "\n", + "print(f\"W_p_B: shape={data_trajectory.base_position().shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-jxJXy5r3RMt" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "plt.plot(T, data_trajectory.base_position()[:, 0:5, 2])\n", + "plt.grid(True)\n", + "plt.xlabel(\"Time [s]\")\n", + "plt.ylabel(\"Height [m]\")\n", + "plt.title(\"Height trajectory of the sphere\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N92-WjPFGuua" + }, + "source": [ + "# Conclusions\n", + "\n", + "This notebook introduced the key APIs of JaxSim as a hardware-accelerated parallel physics engine. Key takeaways:\n", + "\n", + "- **Contact models**: trajectories are sensitive to the contact model used. Explore the `jaxsim.rbda.contacts` package to find the best fit, as each model comes with trade-offs.\n", + "- **Integrator selection**: the choice of integrator affects both accuracy and speed. Experiment with options in the `jaxsim.integrators` package to optimize for your application and hardware accelerator.\n", + "- **Time step**: the interaction between contact models and integrators depends on the integration step $\\Delta t$. Choose the largest stable time step that guarantees for stable simulations.\n", + "- **Automatic vectorization**: this notebook demonstrated one way to use `jax.vmap`, but there are many other approaches. As you become more familiar with JAX, you'll discover better methods tailored to your needs.\n", + "- **Advanced applications**: Combine `jax.jit` and `jax.vmap` with `jax.grad`, `jax.jacfwd`, and `jax.jacrev` for gradient-based learning and other advanced tasks (not covered here).\n", + "\n", + "Have fun!" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuClass": "premium", + "gpuType": "T4", + "private_outputs": true, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 2babb50de87ac2a82cd55d8ca3f12bd7697c4981 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 17 Oct 2024 13:33:53 +0200 Subject: [PATCH 3/6] Update PD controller example --- examples/PD_controller.ipynb | 328 -------------- examples/jaxsim_for_robot_controllers.ipynb | 478 ++++++++++++++++++++ 2 files changed, 478 insertions(+), 328 deletions(-) delete mode 100644 examples/PD_controller.ipynb create mode 100644 examples/jaxsim_for_robot_controllers.ipynb diff --git a/examples/PD_controller.ipynb b/examples/PD_controller.ipynb deleted file mode 100644 index 378d53cdf..000000000 --- a/examples/PD_controller.ipynb +++ /dev/null @@ -1,328 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# `JAXsim` Showcase: PD Controller\n", - "\n", - "First, we install the necessary packages and import them.\n", - "\n", - "\n", - " \"Open\n", - "" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# @title Imports and setup\n", - "from IPython.display import clear_output\n", - "import os\n", - "import sys\n", - "\n", - "IS_COLAB = \"google.colab\" in sys.modules\n", - "\n", - "# Install JAX and Gazebo SDF\n", - "if IS_COLAB:\n", - " !{sys.executable} -m pip install -qU jaxsim[viz]\n", - " !apt install -qq lsb-release wget gnupg\n", - " !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n", - " !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n", - " !apt -qq update\n", - " !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n", - "\n", - " # Install dependencies for visualization on Colab and ReadTheDocs\n", - " !sudo apt update\n", - " !apt install libosmesa6-dev\n", - " clear_output()\n", - "\n", - "\n", - "import jax\n", - "import jax.numpy as jnp\n", - "from jaxsim import logging\n", - "\n", - "logging.set_logging_level(logging.LoggingLevel.INFO)\n", - "logging.info(f\"Running on {jax.devices()}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We will use a simple cartpole model for this example. The cartpole model is a 2D model with a cart that can move horizontally and a pole that can rotate around the cart. The state of the cartpole is given by the position of the cart, the angle of the pole, the velocity of the cart, and the angular velocity of the pole. The control input is the horizontal force applied to the cart." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# @title Fetch the URDF file\n", - "import requests\n", - "\n", - "url = \"https://raw.githubusercontent.com/ami-iit/jaxsim/main/examples/assets/cartpole.urdf\"\n", - "\n", - "response = requests.get(url)\n", - "if response.status_code == 200:\n", - " model_urdf_string = response.text\n", - "else:\n", - " logging.error(\"Failed to fetch data\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n", - "\n", - "- `model`: an object that defines the dynamics of the system.\n", - "- `data`: an object that contains the state of the system.\n", - "- `integrator`: an object that defines the integration method.\n", - "- `integrator_state`: an object that contains the state of the integrator." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import jaxsim.api as js\n", - "from jaxsim import integrators\n", - "\n", - "dt = 0.01\n", - "integration_time = 5.0\n", - "num_steps = int(integration_time / dt)\n", - "\n", - "model = js.model.JaxSimModel.build_from_model_description(\n", - " model_description=model_urdf_string\n", - ")\n", - "data = js.data.JaxSimModelData.build(model=model)\n", - "integrator = integrators.fixed_step.RungeKutta4SO3.build(\n", - " dynamics=js.ode.wrap_system_dynamics_for_integration(\n", - " model=model,\n", - " data=data,\n", - " system_dynamics=js.ode.system_dynamics,\n", - " ),\n", - ")\n", - "integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's reset the cartpole to a random state." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "random_positions = jax.random.uniform(\n", - " minval=-1.0, maxval=1.0, shape=(model.dofs(),), key=jax.random.PRNGKey(0)\n", - ")\n", - "\n", - "data = data.reset_joint_positions(positions=random_positions)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The visualization is done using mujoco package, to be able to render easily the animations also on Google Colab. If you are not interested in the animation, execute but do not try to understand deeply this cell." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# @title Set up MuJoCo renderer\n", - "os.environ[\"MUJOCO_GL\"] = \"osmesa\"\n", - "\n", - "from jaxsim.mujoco import MujocoModelHelper, MujocoVideoRecorder\n", - "from jaxsim.mujoco.loaders import UrdfToMjcf, MujocoCamera\n", - "\n", - "mjcf_string, assets = UrdfToMjcf.convert(\n", - " urdf=model.built_from,\n", - " cameras=MujocoCamera.build_from_target_view(\n", - " camera_name=\"cartpole_camera\",\n", - " lookat=jnp.array([0.0, data.joint_positions()[0], 1.2]),\n", - " distance=3,\n", - " azimut=150,\n", - " elevation=-10,\n", - " ),\n", - ")\n", - "mj_model_helper = MujocoModelHelper.build_from_xml(\n", - " mjcf_description=mjcf_string, assets=assets\n", - ")\n", - "\n", - "# Create the video recorder.\n", - "recorder = MujocoVideoRecorder(\n", - " model=mj_model_helper.model,\n", - " data=mj_model_helper.data,\n", - " fps=int(1 / 0.010),\n", - " width=320 * 2,\n", - " height=240 * 2,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's see how the model behaves when not controlled:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import mediapy as media\n", - "\n", - "for _ in range(num_steps):\n", - " data, integrator_state = js.model.step(\n", - " dt=dt,\n", - " model=model,\n", - " data=data,\n", - " integrator=integrator,\n", - " integrator_state=integrator_state,\n", - " joint_force_references=None,\n", - " link_forces=None,\n", - " )\n", - "\n", - " mj_model_helper.set_joint_positions(\n", - " positions=data.joint_positions(), joint_names=model.joint_names()\n", - " )\n", - "\n", - " recorder.record_frame(camera_name=\"cartpole_camera\")\n", - "\n", - "media.show_video(recorder.frames, fps=1 / dt)\n", - "recorder.frames = []" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's now define the PD controller. We will use the following equations:\n", - "\n", - "\\begin{align} \n", - "\\mathbf{M}\\ddot{s} + \\underbrace{\\mathbf{C}\\dot{s} + \\mathbf{G}}_{\\mathbf{H}} = \\tau \\\\\n", - "\\tau = \\mathbf{H} - \\mathbf{K}_p(s - s_d) - \\mathbf{K}_d(\\dot{s} - \\dot{s}_d)\n", - "\\end{align}\n", - "\n", - "where $\\mathbf{M}$ is the mass matrix, $\\mathbf{C}$ is the Coriolis matrix, $\\mathbf{G}$ is the gravity vector, $\\mathbf{K}_p$ is the proportional gain matrix, $\\mathbf{K}_d$ is the derivative gain matrix, $s$ is the position vector, $\\dot{s}$ is the velocity vector, $\\ddot{s}$ is the acceleration vector, and $s_d$ and $\\dot{s}_d$ are the desired position and velocity vectors, respectively." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Define the PD gains\n", - "KP = 10.0\n", - "KD = 6.0\n", - "\n", - "\n", - "def pd_controller(\n", - " data: js.data.JaxSimModelData, q_d: jax.Array, q_dot_d: jax.Array\n", - ") -> jax.Array:\n", - "\n", - " # Compute the gravity compensation term\n", - " H = js.model.free_floating_bias_forces(model=model, data=data)[6:]\n", - "\n", - " q = data.joint_positions()\n", - " q_dot = data.joint_velocities()\n", - "\n", - " return H + KP * (q_d - q) + KD * (q_dot_d - q_dot)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can use the `pd_controller` function to compute the torque to apply to the cartpole. Our aim is to stabilize the cartpole in the upright position, so we set the desired position `q_d` to 0 and the desired velocity `q_dot_d` to 0." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for _ in range(num_steps):\n", - " control_torques = pd_controller(\n", - " data=data,\n", - " q_d=jnp.array([0.0, 0.0]),\n", - " q_dot_d=jnp.array([0.0, 0.0]),\n", - " )\n", - "\n", - " data, integrator_state = js.model.step(\n", - " dt=dt,\n", - " model=model,\n", - " data=data,\n", - " integrator=integrator,\n", - " integrator_state=integrator_state,\n", - " joint_force_references=control_torques,\n", - " link_forces=None,\n", - " )\n", - "\n", - " mj_model_helper.set_joint_positions(\n", - " positions=data.joint_positions(), joint_names=model.joint_names()\n", - " )\n", - "\n", - " recorder.record_frame(camera_name=\"cartpole_camera\")\n", - "\n", - "media.show_video(recorder.frames, fps=1 / dt)\n", - "recorder.frames = []" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuClass": "premium", - "gpuType": "V100", - "private_outputs": true, - "provenance": [ - { - "file_id": "1QsuS7EJhdPEHxxAu9XwozvA7eb4ZnlAb", - "timestamp": 1701993737024 - } - ], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/jaxsim_for_robot_controllers.ipynb b/examples/jaxsim_for_robot_controllers.ipynb new file mode 100644 index 000000000..6bb5ea4ae --- /dev/null +++ b/examples/jaxsim_for_robot_controllers.ipynb @@ -0,0 +1,478 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "EhPy6FgiZH4d" + }, + "source": [ + "# JaxSim for developing closed-loop robot controllers\n", + "\n", + "Originally developed as a **hardware-accelerated physics engine**, JaxSim has expanded its capabilities to become a full-featured **JAX-based multibody dynamics library**.\n", + "\n", + "In this notebook, you'll explore how to combine these two core features. Specifically, you'll learn how to load a robot model and design a model-based controller for closed-loop simulations.\n", + "\n", + "\n", + " \"Open\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vsf1AlxdZH4f" + }, + "outputs": [], + "source": [ + "# @title Prepare the environment\n", + "from IPython.display import clear_output\n", + "import sys\n", + "\n", + "IS_COLAB = \"google.colab\" in sys.modules\n", + "\n", + "# Install JAX, sdformat, and other notebook dependencies.\n", + "if IS_COLAB:\n", + " !{sys.executable} -m pip install --pre -qU jaxsim[viz]\n", + " !apt install -qq lsb-release wget gnupg\n", + " !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n", + " !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n", + " !apt -qq update\n", + " !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n", + "\n", + " # Install dependencies for visualization on Colab and ReadTheDocs.\n", + " !sudo apt update\n", + " !apt install libosmesa6-dev\n", + " clear_output()\n", + "\n", + "# ================\n", + "# Notebook imports\n", + "# ================\n", + "\n", + "import os\n", + "\n", + "\n", + "os.environ[\"MUJOCO_GL\"] = \"osmesa\"\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jaxsim.mujoco\n", + "from jaxsim import logging\n", + "\n", + "logging.set_logging_level(logging.LoggingLevel.WARNING)\n", + "print(f\"Running on {jax.devices()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kN-b9nOsZH4g" + }, + "source": [ + "We will use a simple cartpole model for this example. The cartpole model is a 2D model with a cart that can move horizontally and a pole that can rotate around the cart. The state of the cartpole is given by the position of the cart, the angle of the pole, the velocity of the cart, and the angular velocity of the pole. The control input is the horizontal force applied to the cart." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5aLqrZDqR5LA" + }, + "source": [ + "## Prepare the simulation\n", + "\n", + "JaxSim supports loading robot models from both [SDF][sdformat] and [URDF][urdf] files, utilizing the [`ami-iit/rod`][rod] library for processing these formats.\n", + "\n", + "The `rod` library library can read URDF files and validates them internally using [`gazebosim/sdformat`][sdformat_github]. In this example, we'll load a cart-pole model, which will be used to create the JaxSim simulation model.\n", + "\n", + "[sdformat]: http://sdformat.org/\n", + "[urdf]: http://wiki.ros.org/urdf/\n", + "[rod]: https://github.com/ami-iit/rod\n", + "[sdformat_github]: https://github.com/gazebosim/sdformat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PZM7hEvFZH4h" + }, + "outputs": [], + "source": [ + "# @title Download the URDF model\n", + "\n", + "import requests\n", + "\n", + "url = \"https://raw.githubusercontent.com/ami-iit/jaxsim/main/examples/assets/cartpole.urdf\"\n", + "\n", + "response = requests.get(url)\n", + "\n", + "if response.status_code == 200:\n", + " model_urdf_string = response.text\n", + "else:\n", + " logging.error(\"Failed to fetch data\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "M5XsKehvZH4j" + }, + "outputs": [], + "source": [ + "# @title Create the model and its data\n", + "\n", + "import jaxsim.api as js\n", + "from jaxsim import integrators\n", + "\n", + "# Create the model from the model description.\n", + "model = js.model.JaxSimModel.build_from_model_description(\n", + " model_description=model_urdf_string,\n", + " time_step=0.010,\n", + ")\n", + "\n", + "# Create the data storing the simulation state.\n", + "data_zero = js.data.JaxSimModelData.zero(model=model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jk9csR5ETgn1" + }, + "outputs": [], + "source": [ + "# @title Select the integrator\n", + "\n", + "# Create the integrator.\n", + "integrator = integrators.fixed_step.RungeKutta4.build(\n", + " dynamics=js.ode.wrap_system_dynamics_for_integration(\n", + " model=model,\n", + " data=data_zero,\n", + " system_dynamics=js.ode.system_dynamics,\n", + " ),\n", + ")\n", + "\n", + "# Initialize the integrator.\n", + "integrator_state = integrator.init(\n", + " x0=data_zero.state,\n", + " t0=0.0,\n", + " dt=model.time_step,\n", + ")\n", + "\n", + "# Initialize the simulated time.\n", + "T = jnp.arange(start=0, stop=5.0, step=model.time_step)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bo6Ke5nAWL-S" + }, + "source": [ + "## Prepare the MuJoCo renderer\n", + "\n", + "For visualization purpose, we use the passive viewer of the MuJoCo simulator. It allows to either open an interactive windows when used locally or record a video when used in notebooks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "j1_I2i5TZH4n" + }, + "outputs": [], + "source": [ + "# Create the MJCF resources from the URDF.\n", + "mjcf_string, assets = jaxsim.mujoco.UrdfToMjcf.convert(\n", + " urdf=model.built_from,\n", + " # Create the camera used by the recorder.\n", + " cameras=jaxsim.mujoco.loaders.MujocoCamera.build_from_target_view(\n", + " camera_name=\"cartpole_camera\",\n", + " lookat=js.link.com_position(\n", + " model=model,\n", + " data=data_zero,\n", + " link_index=js.link.name_to_idx(model=model, link_name=\"cart\"),\n", + " in_link_frame=False,\n", + " ),\n", + " distance=3,\n", + " azimut=150,\n", + " elevation=-10,\n", + " ),\n", + ")\n", + "\n", + "# Create a helper to operate on the MuJoCo model and data.\n", + "mj_model_helper = jaxsim.mujoco.MujocoModelHelper.build_from_xml(\n", + " mjcf_description=mjcf_string, assets=assets\n", + ")\n", + "\n", + "# Create the video recorder.\n", + "recorder = jaxsim.mujoco.MujocoVideoRecorder(\n", + " model=mj_model_helper.model,\n", + " data=mj_model_helper.data,\n", + " fps=int(1 / model.time_step),\n", + " width=320 * 2,\n", + " height=240 * 2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DpRvvGujZH4o" + }, + "source": [ + "## Open-loop simulation\n", + "\n", + "Now, let's run a simulation to demonstrate the open-loop dynamics of the system." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gSWzcsKWZH4p" + }, + "outputs": [], + "source": [ + "import mediapy as media\n", + "\n", + "\n", + "# Create a random joint position.\n", + "# For a random full state, you can use jaxsim.api.data.random_model_data.\n", + "random_joint_positions = jax.random.uniform(\n", + " minval=-1.0,\n", + " maxval=1.0,\n", + " shape=(model.dofs(),),\n", + " key=jax.random.PRNGKey(0),\n", + ")\n", + "\n", + "# Reset the state to the random joint positions.\n", + "data = data_zero.reset_joint_positions(positions=random_joint_positions)\n", + "\n", + "\n", + "for _ in T:\n", + "\n", + " # Step the JaxSim simulation.\n", + " data, integrator_state = js.model.step(\n", + " model=model,\n", + " data=data,\n", + " integrator=integrator,\n", + " integrator_state=integrator_state,\n", + " joint_force_references=None,\n", + " link_forces=None,\n", + " )\n", + "\n", + " # Update the MuJoCo data.\n", + " mj_model_helper.set_joint_positions(\n", + " positions=data.joint_positions(), joint_names=model.joint_names()\n", + " )\n", + "\n", + " # Record a new video frame.\n", + " recorder.record_frame(camera_name=\"cartpole_camera\")\n", + "\n", + "\n", + "# Play the video.\n", + "media.show_video(recorder.frames, fps=recorder.fps)\n", + "recorder.frames = []" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "j1rguK3UZH4p" + }, + "source": [ + "## Closed-loop simulation\n", + "\n", + "Next, let's design a simple computed torque controller. The equations of motion for the cart-pole system are given by:\n", + "\n", + "$$\n", + "M_{ss}(\\mathbf{s}) \\, \\ddot{\\mathbf{s}} + \\mathbf{h}_s(\\mathbf{s}, \\dot{\\mathbf{s}}) = \\boldsymbol{\\tau}\n", + ",\n", + "$$\n", + "\n", + "where:\n", + "\n", + "- $\\mathbf{s} \\in \\mathbb{R}^n$ are the joint positions.\n", + "- $\\dot{\\mathbf{s}} \\in \\mathbb{R}^n$ are the joint velocities.\n", + "- $\\ddot{\\mathbf{s}} \\in \\mathbb{R}^n$ are the joint accelerations.\n", + "- $\\boldsymbol{\\tau} \\in \\mathbb{R}^n$ are the joint torques.\n", + "- $M_{ss} \\in \\mathbb{R}^{n \\times n}$ is the mass matrix.\n", + "- $\\mathbf{h}_s \\in \\mathbb{R}^n$ is the vector of bias forces.\n", + "\n", + "JaxSim computes these quantities for floating-base systems, so we specifically focus on the joint-related portions by marking them with subscripts.\n", + "\n", + "Since no external forces or joint friction are present, we can extend a PD controller with a feed-forward term that includes gravity compensation:\n", + "\n", + "$$\n", + "\\begin{cases}\n", + "\\boldsymbol{\\tau} &= M_{ss} \\, \\ddot{\\mathbf{s}}^* + \\mathbf{h}_s \\\\\n", + "\\ddot{\\mathbf{s}}^* &= \\ddot{\\mathbf{s}}^\\text{des} - k_p(\\mathbf{s} - \\mathbf{s}^{\\text{des}}) - k_d(\\mathbf{s}^{\\text{des}} - \\dot{\\mathbf{s}}^{\\text{des}})\n", + "\\end{cases}\n", + "\\quad\n", + ",\n", + "$$\n", + "\n", + "where $\\tilde{\\mathbf{s}} = \\left(\\mathbf{s} - \\mathbf{s}^\\text{des}\\right)$ is the joint position error.\n", + "\n", + "With this control law, the closed-loop system dynamics simplifies to:\n", + "\n", + "$$\n", + "\\ddot{\\tilde{\\mathbf{s}}} = -k_p \\tilde{\\mathbf{s}} - k_d \\dot{\\tilde{\\mathbf{s}}}\n", + ",\n", + "$$\n", + "\n", + "which converges asymptotically to zero, ensuring stability." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rfMTCMyGZH4q" + }, + "outputs": [], + "source": [ + "# @title Create the computed torque controller\n", + "\n", + "# Define the PD gains\n", + "kp = 10.0\n", + "kd = 6.0\n", + "\n", + "\n", + "def computed_torque_controller(\n", + " data: js.data.JaxSimModelData,\n", + " s_des: jax.Array,\n", + " s_dot_des: jax.Array,\n", + ") -> jax.Array:\n", + "\n", + " # Compute the gravity compensation term.\n", + " hs = js.model.free_floating_bias_forces(model=model, data=data)[6:]\n", + "\n", + " # Compute the joint-related portion of the floating-base mass matrix.\n", + " Mss = js.model.free_floating_mass_matrix(model=model, data=data)[6:, 6:]\n", + "\n", + " # Get the current joint positions and velocities.\n", + " s = data.joint_positions()\n", + " ṡ = data.joint_velocities()\n", + "\n", + " # Compute the actuated joint torques.\n", + " s_star = - kp * (s - s_des) - kd * (ṡ - s_dot_des)\n", + " τ = Mss @ s_star + hs\n", + "\n", + " return τ" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ERAUisywZH4q" + }, + "source": [ + "Now, we can use the `pd_controller` function to compute the torque to apply to the cartpole. Our aim is to stabilize the cartpole in the upright position, so we set the desired position `q_d` to 0 and the desired velocity `q_dot_d` to 0." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8YmDdGDVZH4q" + }, + "outputs": [], + "source": [ + "# @title Run the simulation\n", + "\n", + "# Initialize the data.\n", + "data = data_zero.reset_joint_positions(\n", + " model=model,\n", + " positions=jnp.array([-0.25, jnp.deg2rad(160)]),\n", + ").reset_joint_velocities(\n", + " model=model,\n", + " velocities=jnp.array([3.00, jnp.deg2rad(10) / model.time_step]),\n", + ")\n", + "\n", + "for _ in T:\n", + "\n", + " # Get the actuated torques from the computed torque controller.\n", + " τ = computed_torque_controller(\n", + " data=data,\n", + " s_des=jnp.array([0.0, 0.0]),\n", + " s_dot_des=jnp.array([0.0, 0.0]),\n", + " )\n", + "\n", + " # Step the JaxSim simulation.\n", + " data, integrator_state = js.model.step(\n", + " model=model,\n", + " data=data,\n", + " integrator=integrator,\n", + " integrator_state=integrator_state,\n", + " joint_force_references=τ,\n", + " )\n", + "\n", + " # Update the MuJoCo data.\n", + " mj_model_helper.set_joint_positions(\n", + " positions=data.joint_positions(), joint_names=model.joint_names()\n", + " )\n", + "\n", + " # Record a new video frame.\n", + " recorder.record_frame(camera_name=\"cartpole_camera\")\n", + "\n", + "media.show_video(recorder.frames, fps=recorder.fps)\n", + "recorder.frames = []" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sZ76QqeWeMQz" + }, + "source": [ + "## Conclusions\n", + "\n", + "In this notebook, we explored how to use JaxSim for developing a closed-loop controller for a robot model. Key takeaways include:\n", + "\n", + "- We performed an open-loop simulation to understand the dynamics of the system without control.\n", + "- We implemented a computed torque controller with PD feedback and a feed-forward gravity compensation term, enabling the stabilization of the system by controlling joint torques.\n", + "- The closed-loop simulation can leverage hardware acceleration on GPUs and TPUs, with the ability to use `jax.vmap` for parallel sampling through automatic vectorization.\n", + "\n", + "JaxSim's closed-loop support can be extended to more advanced, model-based reactive controllers and planners for trajectory optimization. To explore optimization-based methods, consider the following JAX-based projects for hardware-accelerated control and planning:\n", + "\n", + "- [`deepmind/optax`](https://github.com/deepmind/optax)\n", + "- [`google/jaxopt`](https://github.com/google/jaxopt)\n", + "- [`patrick-kidger/lineax`](https://github.com/patrick-kidger/lineax)\n", + "- [`patrick-kidger/optimistix`](https://github.com/patrick-kidger/optimistix)\n", + "- [`kevin-tracy/qpax`](https://github.com/kevin-tracy/qpax)\n", + "\n", + "Additionally, if your controllers or planners require the derivatives of the dynamics with respect to the state or inputs, you can obtain them using automatic differentiation directly through JaxSim's API." + ] + } + ], + "metadata": { + "colab": { + "gpuClass": "premium", + "private_outputs": true, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 234c3772bb5b7e27ad7897e9452996c341b31fe7 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 16 Oct 2024 14:34:54 +0200 Subject: [PATCH 4/6] Update examples readme --- examples/README.md | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/examples/README.md b/examples/README.md index 1d0ae20ef..b278bd8e8 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,40 +1,41 @@ -# JAXsim Notebook Examples +# JaxSim Examples -This folder includes a Jupyter Notebook demonstrating the practical usage of JAXsim for system simulations. +This folder contains Jupyter notebooks that demonstrate the practical usage of JaxSim. -### Examples +## Featured examples -- [PD_controller](./PD_controller.ipynb) - Open In Colab - A simple example demonstrating the use of JAXsim to simulate a PD controller with gravity compensation for a 2-DOF cartpole. +| Notebook | Google Colab | Description | +| :--- | :---: | :--- | +| [`jaxsim_as_physics_engine.ipynb`](./jaxsim_as_physics_engine.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_physics_engine] | An example demonstrating how to simulate vectorized models in parallel. | +| [`jaxsim_for_robot_controllers.ipynb`](./jaxsim_for_robot_controllers.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_closed_loop] | A basic example showing how to simulate a PD controller with gravity compensation for a 2-DOF cart-pole. | -- [Parallel_computing](./Parallel_computing.ipynb) - Open In Colab - An example demonstrating how to simulate vectorized models in parallel using JAXsim. +[colab_badge]: https://colab.research.google.com/assets/colab-badge.svg +[ipynb_jaxsim_closed_loop]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_for_robot_controllers.ipynb +[ipynb_jaxsim_as_physics_engine]: https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/jaxsim_as_physics_engine.ipynb -> [!TIP] -> Stay tuned for more examples! +## How to run the examples -## Running the Examples +You can run the JaxSim examples with hardware acceleration in two ways. -To execute these examples utilizing JAXsim with hardware acceleration, there are a couple of options available: +### Option 1: Google Colab (recommended) -### Option 1: Google Colab (Recommended) +The easiest way is to use the provided Google Colab links to run the notebooks in a hosted environment +with no setup required. -The simplest way to run the examples is by accessing the provided Google Colab notebook link mentioned above. This will enable you to execute the examples in a hosted environment. +### Option 2: Local execution with `pixi` -### Option 2: Local Execution with `pixi` +To run the examples locally, first install `pixi` following the [official documentation][pixi_installation]: -For local execution, follow these steps: - -1. **Install `pixi`:** - -As per the [official documentation](https://pixi.sh/#installation): +[pixi_installation]: https://pixi.sh/#installation ```bash curl -fsSL https://pixi.sh/install.sh | bash ``` -2. **Run the Example Notebook:** +Then, from the repository's root directory, execute the example notebooks using: -Use `pixi run examples` from the project source directory to execute the example notebook locally. +```bash +pixi run examples +``` -This command will automatically handle the installation of necessary dependencies and execute the examples within a self-contained environment +This command will automatically handle all necessary dependencies and run the examples in a self-contained environment. From 1bf67ef7bcd2d3d65d2014c32c768f2928137afa Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 17 Oct 2024 13:09:11 +0200 Subject: [PATCH 5/6] Update website --- docs/examples.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/examples.rst b/docs/examples.rst index 3c6685650..dfea9a2a4 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -16,29 +16,29 @@ Example Notebooks .. raw:: html -
+
.. only:: html - :doc:`_collections/examples/Parallel_computing` + :doc:`_collections/examples/jaxsim_as_physics_engine` .. raw:: html -
Parallel Simulation
+
JaxSim as a hardware-accelerated parallel physics engine
.. raw:: html -
+
.. only:: html - :doc:`_collections/examples/PD_controller` + :doc:`_collections/examples/jaxsim_for_robot_controllers` .. raw:: html -
PD Controller
+
JaxSim for developing closed-loop robot controllers
.. raw:: html From bdf5c8ea6d9d6bd7790e3e6d93d8f061eececf57 Mon Sep 17 00:00:00 2001 From: Diego Ferigo Date: Thu, 17 Oct 2024 15:19:57 +0200 Subject: [PATCH 6/6] Apply suggestions from code review Co-authored-by: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> --- examples/jaxsim_as_physics_engine.ipynb | 2 +- examples/jaxsim_for_robot_controllers.ipynb | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/jaxsim_as_physics_engine.ipynb b/examples/jaxsim_as_physics_engine.ipynb index b67331c8b..3c306b730 100644 --- a/examples/jaxsim_as_physics_engine.ipynb +++ b/examples/jaxsim_as_physics_engine.ipynb @@ -122,7 +122,7 @@ "print(model_sdf_string)\n", "\n", "# JaxSim currently only supports collisions between points attached to bodies\n", - "# and a ground surface modeled as a heighmap sampled from a smooth function.\n", + "# and a ground surface modeled as a heightmap sampled from a smooth function.\n", "# While this approach is universal as it applies to generic meshes, the number\n", "# of considered points greatly affects the performance. Spheres, by default,\n", "# are discretized with 250 points. It's too much for this simple example.\n", diff --git a/examples/jaxsim_for_robot_controllers.ipynb b/examples/jaxsim_for_robot_controllers.ipynb index 6bb5ea4ae..bc1c10ebb 100644 --- a/examples/jaxsim_for_robot_controllers.ipynb +++ b/examples/jaxsim_for_robot_controllers.ipynb @@ -41,7 +41,7 @@ " !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n", "\n", " # Install dependencies for visualization on Colab and ReadTheDocs.\n", - " !sudo apt update\n", + " !apt -qq update\n", " !apt install libosmesa6-dev\n", " clear_output()\n", "\n", @@ -385,10 +385,14 @@ "# @title Run the simulation\n", "\n", "# Initialize the data.\n", + "# Set the joint positions.\n", "data = data_zero.reset_joint_positions(\n", " model=model,\n", " positions=jnp.array([-0.25, jnp.deg2rad(160)]),\n", - ").reset_joint_velocities(\n", + ")\n", + "\n", + "# Set the joint velocities.\n", + "data = data.reset_joint_velocities(\n", " model=model,\n", " velocities=jnp.array([3.00, jnp.deg2rad(10) / model.time_step]),\n", ")\n", @@ -439,7 +443,7 @@ "\n", "JaxSim's closed-loop support can be extended to more advanced, model-based reactive controllers and planners for trajectory optimization. To explore optimization-based methods, consider the following JAX-based projects for hardware-accelerated control and planning:\n", "\n", - "- [`deepmind/optax`](https://github.com/deepmind/optax)\n", + "- [`deepmind/optax`](https://github.com/google-deepmind/optax)\n", "- [`google/jaxopt`](https://github.com/google/jaxopt)\n", "- [`patrick-kidger/lineax`](https://github.com/patrick-kidger/lineax)\n", "- [`patrick-kidger/optimistix`](https://github.com/patrick-kidger/optimistix)\n",