Skip to content

Commit

Permalink
Fix max_start_idx argument.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 452845238
  • Loading branch information
mtthss authored and RLaxDev committed Jun 20, 2022
1 parent ed04afc commit 2460112
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions rlax/_src/model_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ def extract_subsequences(
trajectories: A batch of trajectories, shape `[T, B, ...]`.
start_indices: Time indices of start points, shape `[B, num_start_indices]`.
subsequence_len: The length of subsequences extracted from `trajectories`.
max_valid_start_idx: The window used to construct the `start_idx`: i.e. the
`start_indices` should be from {0, ..., max_valid_start_idx - 1}.
max_valid_start_idx: the maximum valid start index, therefore the
`start_indices` should be from {0, ..., max_valid_start_idx}.
Returns:
A batch of subsequences, with
`trajectories[start_indices[i, j]:start_indices[i, j] + n]` for each start
index. Output shape is: `[subsequence_len, B, num_start_indices, ...]`.
"""
if max_valid_start_idx is not None:
min_len = max_valid_start_idx + subsequence_len - 1
min_len = max_valid_start_idx + subsequence_len
traj_len = trajectories.shape[0]
if traj_len < min_len:
raise AssertionError(
Expand Down Expand Up @@ -85,7 +85,7 @@ def _vchoose(key, entries):
return jax.random.choice(
key, entries, shape=(num_start_indices,), replace=False)

rollout_window = jnp.arange(max_valid_start_idx)
rollout_window = jnp.arange(max_valid_start_idx + 1)
return _vchoose(
jax.random.split(rng_key, batch_size),
jnp.tile(rollout_window, (batch_size, 1)))
2 changes: 1 addition & 1 deletion rlax/_src/model_learning_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_extract_subsequences_with_validation_bounds(self):
with self.assertRaisesRegex(AssertionError, 'Expected len >='):
model_learning.extract_subsequences(
self.trajectories, self.invalid_start_indices, 1,
max_valid_start_idx=26)
max_valid_start_idx=24)


if __name__ == '__main__':
Expand Down

0 comments on commit 2460112

Please sign in to comment.