Skip to content

Commit

Permalink
Feat/pendulum rollout example (#9)
Browse files Browse the repository at this point in the history
* added pendulum rollout as example

* fixed README
  • Loading branch information
simeon-ned authored and mattephi committed Oct 20, 2024
1 parent 05c60f9 commit 7153ac8
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 4 deletions.
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,18 @@ JAXADI comes with several examples to help you get started:

3. [Function Conversion](examples/02_convert.py): See how to fully convert CasADi functions to JAX.

4. [Pinocchio Integration](examples/03_pinocchio.py): Explore how to convert Pinocchio-based CasADi functions to JAX.
4. [Pendulum Rollout](examples/03_pendulum_rollout.py): Batched rollout of the nonlinear passive nonlinear pendulum

5. [Pinocchio Integration](examples/04_pinocchio.py): Explore how to convert Pinocchio-based CasADi functions to JAX.

> **Note**: To run the Pinocchio example, ensure you have Pinocchio properly installed in your environment.
6. [MJX Comparison](examples/05_mjx.py): Compare the transformed Pinnocchio forward kinematics with one provided by Mujoco MJX

> **Note**: To run the Pinocchio and MJX examples, ensure you have them properly installed in your environment.
## Performance Benchmarks

(Consider adding a section about performance comparisons between CasADi and JAXADI-translated functions)
<!-- ## Performance Benchmarks
(Consider adding a section about performance comparisons between CasADi and JAXADI-translated functions) -->

<!-- ## Contributing
Expand Down
101 changes: 101 additions & 0 deletions examples/03_pendulum_rollout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import timeit
import casadi as ca
import jax
import jax.numpy as jnp
import numpy as np
from jaxadi import convert

# Static parameters
dt = 0.02
g = 9.81 # Acceleration due to gravity
L = 1.0 # Length of the pendulum
b = 0.1 # Damping coefficient
I = 1.0 # Moment of inertia
# Test parameters
batch_size = 4096
timesteps = 100


# Define the uncontrolled pendulum model in CasADi
def casadi_pendulum_model():
state = ca.SX.sym("state", 2)
theta, omega = state[0], state[1]

theta_dot = omega
omega_dot = (-b * omega - (g / L) * ca.sin(theta)) / I

next_theta = theta + theta_dot * dt
next_omega = omega + omega_dot * dt

next_state = ca.vertcat(next_theta, next_omega)
return ca.Function("pendulum_model", [state], [next_state])


# Create CasADi function
casadi_model = casadi_pendulum_model()

# Convert CasADi function to JAX
jax_model = convert(casadi_model, compile=True)


# Function to generate random inputs
def generate_random_inputs(batch_size):
return np.random.uniform(-np.pi, np.pi, (batch_size, 2))


# CasADi: Sequential Evaluation
def casadi_sequential_rollout(initial_states):
batch_size = initial_states.shape[0]
rollout_states = np.zeros((timesteps + 1, batch_size, 2))

rollout_states[0] = initial_states
for t in range(1, timesteps + 1):
rollout_states[t] = np.array([casadi_model(state).full().flatten() for state in rollout_states[t - 1]])

return rollout_states


# JAX: Vectorized Evaluation
@jax.jit
def jax_vectorized_rollout(initial_states):
def single_step(state):
return jnp.array(jax_model(state)).reshape(
2,
)

def scan_fn(carry, _):
next_state = jax.vmap(single_step)(carry)
return next_state, next_state

_, rollout_states = jax.lax.scan(scan_fn, initial_states, None, length=timesteps)
return jnp.concatenate([initial_states[None, ...], rollout_states], axis=0)


# Generate random inputs
initial_states = generate_random_inputs(batch_size)

# Warm-up call for JAX
print("Performing warm-up call for JAX...")
_ = jax_vectorized_rollout(initial_states)
print("Warm-up call completed.")
# Performance comparison
print("\nPerformance comparison:")
# Generate new random inputs
initial_states = generate_random_inputs(batch_size)

print(f"CasADi sequential rollout ({batch_size} pendulums, {timesteps} timesteps):")
casadi_time = timeit.timeit(lambda: casadi_sequential_rollout(initial_states), number=1)
print(f"Time: {casadi_time:.4f} seconds")

print(f"\nJAX vectorized rollout ({batch_size} pendulums, {timesteps} timesteps):")
jax_time = timeit.timeit(lambda: np.array(jax_vectorized_rollout(initial_states)), number=1)
print(f"Time: {jax_time:.4f} seconds")

print(f"\nSpeedup factor: {casadi_time / jax_time:.2f}x")

# Verify results
print("\nVerifying results:")
casadi_results = casadi_sequential_rollout(initial_states[:10])
jax_results = np.array(jax_vectorized_rollout(initial_states[:10]))

print("First 10 rollouts match:", np.allclose(casadi_results, jax_results, atol=1e-4))
File renamed without changes.
File renamed without changes.

0 comments on commit 7153ac8

Please sign in to comment.