Skip to content

Commit

Permalink
Use plain integers for frame indices
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed May 22, 2024
1 parent d765203 commit a60b683
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions src/jaxsim/api/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
# =======================


@functools.partial(jax.jit, static_argnames=["frame_idx"])
def idx_of_parent_link(
model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike
) -> jtp.Int:
def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) -> int:
"""
Get the index of the link to which the frame is rigidly attached.
Expand All @@ -39,11 +36,10 @@ def idx_of_parent_link(
F = ir.frames[frame_idx - model.number_of_links()]
L = ir.links_dict[F.parent.name].index

return jnp.array(L).astype(int)
return int(L)


@functools.partial(jax.jit, static_argnames=["frame_name"])
def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:
def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int:
"""
Convert the name of a frame to its index.
Expand All @@ -59,8 +55,9 @@ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:

if frame_name in frame_names:
idx_in_list = np.argwhere(frame_names == frame_name)
return jnp.array(idx_in_list + model.number_of_links()).squeeze().astype(int)
return jnp.array(-1).astype(int)
return int(idx_in_list.squeeze().tolist()) + model.number_of_links()

return -1


def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str:
Expand Down

0 comments on commit a60b683

Please sign in to comment.