From 2460112b67ff25e9a791d84a906595a9a5d2095a Mon Sep 17 00:00:00 2001 From: Matteo Hessel Date: Fri, 3 Jun 2022 14:35:19 -0700 Subject: [PATCH] Fix max_start_idx argument. PiperOrigin-RevId: 452845238 --- rlax/_src/model_learning.py | 8 ++++---- rlax/_src/model_learning_test.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/rlax/_src/model_learning.py b/rlax/_src/model_learning.py index 4775d11..3474d4c 100644 --- a/rlax/_src/model_learning.py +++ b/rlax/_src/model_learning.py @@ -37,8 +37,8 @@ def extract_subsequences( trajectories: A batch of trajectories, shape `[T, B, ...]`. start_indices: Time indices of start points, shape `[B, num_start_indices]`. subsequence_len: The length of subsequences extracted from `trajectories`. - max_valid_start_idx: The window used to construct the `start_idx`: i.e. the - `start_indices` should be from {0, ..., max_valid_start_idx - 1}. + max_valid_start_idx: the maximum valid start index, therefore the + `start_indices` should be from {0, ..., max_valid_start_idx}. Returns: A batch of subsequences, with @@ -46,7 +46,7 @@ def extract_subsequences( index. Output shape is: `[subsequence_len, B, num_start_indices, ...]`. """ if max_valid_start_idx is not None: - min_len = max_valid_start_idx + subsequence_len - 1 + min_len = max_valid_start_idx + subsequence_len traj_len = trajectories.shape[0] if traj_len < min_len: raise AssertionError( @@ -85,7 +85,7 @@ def _vchoose(key, entries): return jax.random.choice( key, entries, shape=(num_start_indices,), replace=False) - rollout_window = jnp.arange(max_valid_start_idx) + rollout_window = jnp.arange(max_valid_start_idx + 1) return _vchoose( jax.random.split(rng_key, batch_size), jnp.tile(rollout_window, (batch_size, 1))) diff --git a/rlax/_src/model_learning_test.py b/rlax/_src/model_learning_test.py index 95750dc..f8fcabe 100644 --- a/rlax/_src/model_learning_test.py +++ b/rlax/_src/model_learning_test.py @@ -53,7 +53,7 @@ def test_extract_subsequences_with_validation_bounds(self): with self.assertRaisesRegex(AssertionError, 'Expected len >='): model_learning.extract_subsequences( self.trajectories, self.invalid_start_indices, 1, - max_valid_start_idx=26) + max_valid_start_idx=24) if __name__ == '__main__':