# Patch summary: lets the mock environment constructor accept an
# `allow_done_after_reset` keyword, and makes `_allow_done_after_reset` on the
# transformed-env wrapper a read-only view of `base_env._allow_done_after_reset`.
# NOTE(review): line breaks below follow the hunk header counts; the leading
# indentation of the Python lines was reconstructed from syntax because the
# source had its whitespace collapsed -- confirm against the target files
# before applying.
#
# test/mocking_classes.py: pops an optional `allow_done_after_reset` entry
# from `kwargs` (default False) and forwards it to the parent constructor.
diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 31f681b8a48..fc5feb487b9 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -111,6 +111,7 @@ def __init__(
         super().__init__(
             device=kwargs.pop("device", "cpu"),
             dtype=torch.get_default_dtype(),
+            allow_done_after_reset=kwargs.pop("allow_done_after_reset", False),
         )
         self.set_seed(seed)
         self.is_closed = False
# test/test_transforms.py: adds `test_allow_done_after_reset`, which checks
# that (1) a TransformedEnv reports the flag of the env it wraps, (2) assigning
# the flag on the TransformedEnv raises RuntimeError with the read-only
# message, and (3) changing the flag on the base env is visible through the
# TransformedEnv.
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 6340dec842e..1f419e0426a 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -7691,6 +7691,21 @@ def test_independent_reward_specs_from_shared_env(self):
         assert base_env.reward_spec.space.minimum == -np.inf
         assert base_env.reward_spec.space.maximum == np.inf
 
+    def test_allow_done_after_reset(self):
+        base_env = ContinuousActionVecMockEnv(allow_done_after_reset=True)
+        assert base_env._allow_done_after_reset
+        t1 = TransformedEnv(
+            base_env, transform=RewardClipping(clamp_min=0, clamp_max=4)
+        )
+        assert t1._allow_done_after_reset
+        with pytest.raises(
+            RuntimeError,
+            match="_allow_done_after_reset is a read-only property for TransformedEnvs",
+        ):
+            t1._allow_done_after_reset = False
+        base_env._allow_done_after_reset = False
+        assert not t1._allow_done_after_reset
+
 
 def test_nested_transformed_env():
     base_env = ContinuousActionVecMockEnv()
# torchrl/envs/transforms/transforms.py:
#   * first hunk: the wrapper's `__init__` now passes
#     `allow_done_after_reset=None` to the parent constructor.
#   * second hunk: adds a `_allow_done_after_reset` property whose getter
#     returns `self.base_env._allow_done_after_reset`, and a setter that
#     returns without doing anything when given None and raises RuntimeError
#     for every other value.
# NOTE(review): the None pass-through presumably exists so the parent
# constructor can assign the None placeholder without hitting the read-only
# guard -- the parent `__init__` is not part of this patch, so confirm there.
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 6f115ec118e..d36863dae3e 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -572,7 +572,7 @@ def __init__(
             env = env.to(device)
         else:
             device = env.device
-        super().__init__(device=None, **kwargs)
+        super().__init__(device=None, allow_done_after_reset=None, **kwargs)
 
         if isinstance(env, TransformedEnv):
             self._set_env(env.base_env, device)
@@ -679,6 +679,18 @@ def run_type_checks(self, value):
             "run_type_checks is a read-only property for TransformedEnvs"
         )
 
+    @property
+    def _allow_done_after_reset(self) -> bool:
+        return self.base_env._allow_done_after_reset
+
+    @_allow_done_after_reset.setter
+    def _allow_done_after_reset(self, value):
+        if value is None:
+            return
+        raise RuntimeError(
+            "_allow_done_after_reset is a read-only property for TransformedEnvs"
+        )
+
     @property
     def _inplace_update(self):
         return self.base_env._inplace_update