From 2985e3a829044a98286e8d3c4d0ac0e36d18ab6f Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 17 Jan 2024 12:12:18 +0000 Subject: [PATCH 1/5] Amend --- torchrl/envs/transforms/transforms.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 21bb542cb1d..61bbcec549b 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -572,7 +572,9 @@ def __init__( env = env.to(device) else: device = env.device - super().__init__(device=None, **kwargs) + super().__init__( + device=None, allow_done_after_reset=env._allow_done_after_reset, **kwargs + ) if isinstance(env, TransformedEnv): self._set_env(env.base_env, device) From d33e34839a9202ada9860090dd1fc554694be388 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 17 Jan 2024 14:17:37 +0000 Subject: [PATCH 2/5] Amend --- test/mocking_classes.py | 1 + test/test_transforms.py | 15 +++++++++++++++ torchrl/envs/common.py | 12 ++++++++++-- torchrl/envs/transforms/transforms.py | 14 +++++++++++--- 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 5dd855d65e2..ee04d949789 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -111,6 +111,7 @@ def __init__( super().__init__( device=kwargs.pop("device", "cpu"), dtype=torch.get_default_dtype(), + **kwargs, ) self.set_seed(seed) self.is_closed = False diff --git a/test/test_transforms.py b/test/test_transforms.py index 3ef633eee98..fa62cf37774 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -7691,6 +7691,21 @@ def test_independent_reward_specs_from_shared_env(self): assert base_env.reward_spec.space.minimum == -np.inf assert base_env.reward_spec.space.maximum == np.inf + def test_allow_done_after_reset(self): + base_env = ContinuousActionVecMockEnv(allow_done_after_reset=True) + assert base_env.allow_done_after_reset + t1 = TransformedEnv( + base_env, transform=RewardClipping(clamp_min=0, clamp_max=4) + ) + assert t1.allow_done_after_reset + with pytest.raises( + RuntimeError, + match="allow_done_after_reset is a read-only property for TransformedEnvs", + ): + t1.allow_done_after_reset = False + base_env.allow_done_after_reset = False + assert not t1.allow_done_after_reset + def test_nested_transformed_env(): base_env = ContinuousActionVecMockEnv() diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index eda8c859692..a236481132f 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -334,6 +334,14 @@ def run_type_checks(self) -> bool: def run_type_checks(self, run_type_checks: bool) -> None: self._run_type_checks = run_type_checks + @property + def allow_done_after_reset(self) -> bool: + return self._allow_done_after_reset + + @allow_done_after_reset.setter + def allow_done_after_reset(self, allow_done_after_reset: bool): + self._allow_done_after_reset = allow_done_after_reset + @property def batch_size(self) -> torch.Size: _batch_size = self.__dict__["_batch_size"] @@ -1522,7 +1530,7 @@ def _reset_check_done(self, tensordict, tensordict_reset): if reset_value is not None: for done_key in done_key_group: done_val = tensordict_reset.get(done_key) - if done_val[reset_value].any() and not self._allow_done_after_reset: + if done_val[reset_value].any() and not self.allow_done_after_reset: raise RuntimeError( f"Env done entry '{done_key}' was (partially) True after reset on specified '_reset' dimensions. This is not allowed." ) @@ -1540,7 +1548,7 @@ def _reset_check_done(self, tensordict, tensordict_reset): # we set the done val to tensordict, to make sure that # _update_during_reset does not pad the value tensordict.set(done_key, done_val) - elif not self._allow_done_after_reset: + elif not self.allow_done_after_reset: for done_key in done_key_group: if tensordict_reset.get(done_key).any(): raise RuntimeError( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 61bbcec549b..d5a1a0cd99b 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -572,9 +572,7 @@ def __init__( env = env.to(device) else: device = env.device - super().__init__( - device=None, allow_done_after_reset=env._allow_done_after_reset, **kwargs - ) + super().__init__(device=None, **kwargs) if isinstance(env, TransformedEnv): self._set_env(env.base_env, device) @@ -681,6 +679,16 @@ def run_type_checks(self, value): "run_type_checks is a read-only property for TransformedEnvs" ) + @property + def allow_done_after_reset(self) -> bool: + return self.base_env.allow_done_after_reset + + @allow_done_after_reset.setter + def allow_done_after_reset(self, value): + raise RuntimeError( + "allow_done_after_reset is a read-only property for TransformedEnvs" + ) + @property def _inplace_update(self): return self.base_env._inplace_update From 92cad1602d452b41e61223ceda63771dff72dfd9 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 17 Jan 2024 14:36:00 +0000 Subject: [PATCH 3/5] Amend --- test/mocking_classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index ee04d949789..a113d1f67d7 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -111,7 +111,7 @@ def __init__( super().__init__( device=kwargs.pop("device", "cpu"), dtype=torch.get_default_dtype(), - **kwargs, + allow_done_after_reset=kwargs.pop("allow_done_after_reset", False), ) self.set_seed(seed) self.is_closed = False From 73c0fb9bfdf542fbac5688fc19964b0d126becde Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 17 Jan 2024 20:59:50 +0000 Subject: [PATCH 4/5] empty From 76073b98de32756c0ce6ebab226ff899678a5d50 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 17 Jan 2024 21:13:06 +0000 Subject: [PATCH 5/5] Amend --- test/test_transforms.py | 12 ++++++------ torchrl/envs/common.py | 12 ++---------- torchrl/envs/transforms/transforms.py | 14 ++++++++------ 3 files changed, 16 insertions(+), 22 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 029dfe8b5bb..1f419e0426a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -7693,18 +7693,18 @@ def test_independent_reward_specs_from_shared_env(self): def test_allow_done_after_reset(self): base_env = ContinuousActionVecMockEnv(allow_done_after_reset=True) - assert base_env.allow_done_after_reset + assert base_env._allow_done_after_reset t1 = TransformedEnv( base_env, transform=RewardClipping(clamp_min=0, clamp_max=4) ) - assert t1.allow_done_after_reset + assert t1._allow_done_after_reset with pytest.raises( RuntimeError, - match="allow_done_after_reset is a read-only property for TransformedEnvs", + match="_allow_done_after_reset is a read-only property for TransformedEnvs", ): - t1.allow_done_after_reset = False - base_env.allow_done_after_reset = False - assert not t1.allow_done_after_reset + t1._allow_done_after_reset = False + base_env._allow_done_after_reset = False + assert not t1._allow_done_after_reset def test_nested_transformed_env(): diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 521d6740aa6..633ac2f78a3 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -334,14 +334,6 @@ def run_type_checks(self) -> bool: def run_type_checks(self, run_type_checks: bool) -> None: self._run_type_checks = run_type_checks - @property - def allow_done_after_reset(self) -> bool: - return self._allow_done_after_reset - - @allow_done_after_reset.setter - def allow_done_after_reset(self, allow_done_after_reset: bool): - self._allow_done_after_reset = allow_done_after_reset - @property def batch_size(self) -> torch.Size: _batch_size = self.__dict__["_batch_size"] @@ -1530,7 +1522,7 @@ def _reset_check_done(self, tensordict, tensordict_reset): if reset_value is not None: for done_key in done_key_group: done_val = tensordict_reset.get(done_key) - if done_val[reset_value].any() and not self.allow_done_after_reset: + if done_val[reset_value].any() and not self._allow_done_after_reset: raise RuntimeError( f"Env done entry '{done_key}' was (partially) True after reset on specified '_reset' dimensions. This is not allowed." ) @@ -1548,7 +1540,7 @@ def _reset_check_done(self, tensordict, tensordict_reset): # we set the done val to tensordict, to make sure that # _update_during_reset does not pad the value tensordict.set(done_key, done_val) - elif not self.allow_done_after_reset: + elif not self._allow_done_after_reset: for done_key in done_key_group: if tensordict_reset.get(done_key).any(): raise RuntimeError( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 98e991d2a0d..d36863dae3e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -572,7 +572,7 @@ def __init__( env = env.to(device) else: device = env.device - super().__init__(device=None, **kwargs) + super().__init__(device=None, allow_done_after_reset=None, **kwargs) if isinstance(env, TransformedEnv): self._set_env(env.base_env, device) @@ -680,13 +680,15 @@ def run_type_checks(self, value): ) @property - def allow_done_after_reset(self) -> bool: - return self.base_env.allow_done_after_reset + def _allow_done_after_reset(self) -> bool: + return self.base_env._allow_done_after_reset - @allow_done_after_reset.setter - def allow_done_after_reset(self, value): + @_allow_done_after_reset.setter + def _allow_done_after_reset(self, value): + if value is None: + return raise RuntimeError( - "allow_done_after_reset is a read-only property for TransformedEnvs" + "_allow_done_after_reset is a read-only property for TransformedEnvs" ) @property