Skip to content

Commit

Permalink
Replace for-loop in extract_subsequences with single indexing operation.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 468706178
  • Loading branch information
RLaxDev authored and RLaxDev committed Aug 22, 2022
1 parent 31717e6 commit c792f9b
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions rlax/_src/model_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,13 @@ def extract_subsequences(
if traj_len < min_len:
raise AssertionError(
f'Expected len >= {min_len}, but trajectories length is: {traj_len}.')

batch_size = start_indices.shape[0]
batch_range = jnp.arange(batch_size)
num_subs = start_indices.shape[1]
slices = []
for i in range(num_subs):
slices.append(jnp.stack(
[trajectories[start_indices[:, i] + k, batch_range]
for k in range(subsequence_len)], axis=0))
return jnp.stack(slices, axis=2)
idx_arr = jnp.arange(subsequence_len)[:, None, None] * jnp.ones(
(subsequence_len, batch_size, num_subs), dtype=jnp.int32) + start_indices
return trajectories[idx_arr, batch_range[None, :, None], ...]


def sample_start_indices(
Expand Down

0 comments on commit c792f9b

Please sign in to comment.