[BugFix] Fix Compose input spec transform (#2463)
Co-authored-by: Louis Faury <[email protected]>
Co-authored-by: Vincent Moens <[email protected]>
3 people authored Oct 4, 2024
1 parent 97ccbb7 commit b116151
Showing 3 changed files with 45 additions and 15 deletions.
29 changes: 29 additions & 0 deletions test/test_transforms.py
@@ -8675,6 +8675,35 @@ def test_compose_indexing(self):
assert last_t.scale == 4
assert last_t2.scale == 4

def test_compose_action_spec(self):
# Create a Compose transform that renames "action" to "action_1" and then to "action_2"
c = Compose(
RenameTransform(
in_keys=(),
out_keys=(),
in_keys_inv=("action",),
out_keys_inv=("action_1",),
),
RenameTransform(
in_keys=(),
out_keys=(),
in_keys_inv=("action_1",),
out_keys_inv=("action_2",),
),
)
base_env = ContinuousActionVecMockEnv()
env = TransformedEnv(base_env, c)

# Check the `full_action_spec`s
assert "action_2" in env.full_action_spec
# Ensure intermediate keys are no longer in the action spec
assert "action_1" not in env.full_action_spec
assert "action" not in env.full_action_spec

# Final check to ensure clean sampling from the action_spec
action = env.rand_action()
assert "action_2"

@pytest.mark.parametrize("device", get_default_devices())
def test_finitetensordictcheck(self, device):
ftd = FiniteTensorDictCheck()
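
A rough usage sketch (not part of the commit) of the env built in the new test_compose_action_spec above: once the spec fix is in, sampling and stepping go through the renamed key, and the inverse renames are assumed to map "action_2" back to the base env's "action".

# Hypothetical continuation of test_compose_action_spec (illustration only):
td = env.reset()
td = env.rand_action(td)
assert "action_2" in td.keys()  # the agent-facing action key after both renames
td = env.step(td)               # inv pass: action_2 -> action_1 -> action
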
29 changes: 15 additions & 14 deletions torchrl/collectors/collectors.py
@@ -1675,21 +1675,22 @@ def __init__(
self.cat_results = cat_results

def _check_replay_buffer_init(self):
-        try:
-            if not self.replay_buffer._storage.initialized:
-                if isinstance(self.create_env_fn, EnvCreator):
-                    fake_td = self.create_env_fn.tensordict
-                elif isinstance(self.create_env_fn, EnvBase):
-                    fake_td = self.create_env_fn.fake_tensordict()
-                else:
-                    fake_td = self.create_env_fn[0](
-                        **self.create_env_kwargs[0]
-                    ).fake_tensordict()
-                fake_td["collector", "traj_ids"] = torch.zeros((), dtype=torch.long)
+        is_init = getattr(self.replay_buffer._storage, "initialized", True)
+        if not is_init:
+            if isinstance(self.create_env_fn[0], EnvCreator):
+                fake_td = self.create_env_fn[0].tensordict
+            elif isinstance(self.create_env_fn[0], EnvBase):
+                fake_td = self.create_env_fn[0].fake_tensordict()
+            else:
+                fake_td = self.create_env_fn[0](
+                    **self.create_env_kwargs[0]
+                ).fake_tensordict()
+            fake_td["collector", "traj_ids"] = torch.zeros(
+                fake_td.shape, dtype=torch.long
+            )

-                self.replay_buffer._storage._init(fake_td)
-        except AttributeError:
-            pass
+            self.replay_buffer.add(fake_td)
+            self.replay_buffer.empty()

@classmethod
def _total_workers_from_env(cls, env_creators):
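
Note on the change above: the broad try/except AttributeError is replaced by getattr(..., "initialized", True), so a storage without that flag is treated as already initialized instead of masking unrelated attribute errors, and the buffer is seeded through the public add()/empty() pair rather than the private _storage._init(). A minimal sketch of the getattr idiom, with stand-in classes (names are illustrative, not torchrl's):

class LazyStorage:
    # Stand-in for a storage that tracks whether it has been laid out.
    initialized = False

class OpaqueStorage:
    # Stand-in for a storage that exposes no `initialized` attribute.
    pass

def needs_seeding(storage) -> bool:
    # Default to True ("already initialized") when the flag is absent,
    # so seeding is skipped rather than errors being silently swallowed.
    return not getattr(storage, "initialized", True)

assert needs_seeding(LazyStorage())        # lazy storage: seed it with a fake batch
assert not needs_seeding(OpaqueStorage())  # unknown storage: leave it alone
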
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/transforms.py
@@ -1094,7 +1094,7 @@ def transform_env_batch_size(self, batch_size: torch.batch_size):
return batch_size

def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
-        for t in self.transforms[::-1]:
+        for t in self.transforms:
input_spec = t.transform_input_spec(input_spec)
return input_spec

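
Why the one-line change above matters: in a Compose-like stack, each transform's transform_input_spec maps the spec of the env below it to the spec seen above it, so the innermost (first-registered) transform must run first. A simplified dict-based model (toy classes, not torchrl's spec machinery):

class Rename:
    # Toy stand-in for RenameTransform acting on a dict of spec entries.
    def __init__(self, src, dst):
        self.src, self.dst = src, dst

    def transform_input_spec(self, spec):
        spec = dict(spec)
        if self.src in spec:
            spec[self.dst] = spec.pop(self.src)
        return spec

transforms = [Rename("action", "action_1"), Rename("action_1", "action_2")]
base_spec = {"action": "Bounded(-1, 1)"}  # placeholder spec value

spec = base_spec
for t in transforms:            # forward order, as in the fix
    spec = t.transform_input_spec(spec)
assert set(spec) == {"action_2"}

spec = base_spec
for t in transforms[::-1]:      # old reversed order
    spec = t.transform_input_spec(spec)
assert set(spec) == {"action_1"}  # the intermediate key leaks, which the new test guards against
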
