Skip to content

Commit

Permalink
feat: beat casadi
Browse files Browse the repository at this point in the history
  • Loading branch information
mattephi committed Sep 3, 2024
1 parent a2170c5 commit e1110ea
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions examples/03_pinocchio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import timeit

import jax
import casadi as ca
import jax.numpy as jnp
import pinocchio as pin
Expand Down Expand Up @@ -33,7 +33,8 @@

# Evaluate the function performance
q_val = ca.np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0, 0])
jax_q_val = jnp.array([[0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0], [0]])
jax_q_val = jnp.array([[0.1], [0.2], [0.3], [0.4], [
0.5], [0.6], [0.7], [0], [0]])

print("Casadi evaluation:")
print(fk(q_val))
Expand All @@ -44,6 +45,29 @@
# print("Performance comparison:")
# print("Casadi evaluation:")
# print(timeit.timeit(lambda: fk(q_val), number=100))

#
# print("JAX evaluation:")
# print(timeit.timeit(lambda: jax_fn(jax_q_val), number=100))


# Second part
# Casadi: Sequential Evaluation
N = int(1e7)
def casadi_sequential_evaluation():
for _ in range(N):
fk(q_val)


# JAX: Vectorized Evaluation using vmap
jax_q_vals = jnp.tile(jax_q_val, (N, 1, 1)) # Create a batch of 100 inputs
print(jax_q_vals.shape)
jax_fn_vectorized = jax.vmap(jax_fn, in_axes=(
1,), out_axes=1) # Vectorize the function

# Performance comparison
print("Performance comparison:")
print(f"Casadi sequential evaluation ({N} times):")
print(timeit.timeit(casadi_sequential_evaluation, number=1))

print("JAX vectorized evaluation using vmap:")
print(timeit.timeit(lambda: jax_fn_vectorized(jax_q_vals), number=1))

0 comments on commit e1110ea

Please sign in to comment.