Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 21, 2023
1 parent de1ecf2 commit 6fec99e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -7412,7 +7412,8 @@ def test_onlinedt_tensordict_keys(self):
loss_fn = OnlineDTLoss(actor)

default_keys = {
"action": "action",
"action_pred": "action",
"action_target": "action",
}

self.tensordict_keys_test(
Expand Down
6 changes: 4 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4737,12 +4737,14 @@ def reset_keys(self):
def _check_match(reset_keys, in_keys):
# if this is called, the length of reset_keys and in_keys must match
for reset_key, in_key in zip(reset_keys, in_keys):
if isinstance(reset_key, str) ^ isinstance(in_key, str):
# having _reset at the root and the reward_key ("agent", "reward") is allowed
# but having ("agent", "_reset") and "reward" isn't
if isinstance(reset_key, tuple) and isinstance(in_key, str):
return False
if (
isinstance(reset_key, tuple)
and isinstance(in_key, tuple)
and reset_key[:-1] != in_key[:-1]
and in_key[:(len(reset_key)-1)] != reset_key[:-1]
):
return False
return True
Expand Down

0 comments on commit 6fec99e

Please sign in to comment.