
[BUG] exclude adds keys to the resulting tensordict #636

Closed
3 tasks done
btx0424 opened this issue Jan 24, 2024 · 1 comment · Fixed by #637
Labels
bug Something isn't working

Comments

@btx0424

btx0424 commented Jan 24, 2024

Describe the bug

`TensorDict.exclude` unexpectedly adds a key to the result when it is asked to exclude a nested key whose root key is not present in the tensordict.

To Reproduce

Running

import torch
from tensordict import TensorDict

next_td = TensorDict({
    'a': {
        'b': torch.zeros(64, 1),
        'c': torch.zeros(64, 1),
    },
}, [64])
next_preset = TensorDict({
    "d": torch.zeros(64, 3),
}, [64])
# ("a", "b") is a leaf key of next_td, but its root "a" is absent from next_preset
td = next_preset.exclude(("a", "b"))

print(td.keys(True, True))
print(td)

raises an exception, because "a" has somehow become a key of `td` even though it was never in `next_preset`.
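To make the symptom concrete without depending on tensordict itself, here is a minimal pure-Python model of the `_exclude` logic quoted further down, using nested dicts in place of TensorDicts and lists in place of tensors. The function name `exclude_buggy` is hypothetical and this is a sketch of the control flow, not the library code.

```python
def exclude_buggy(d, *keys):
    """Model of the pre-fix exclude logic on nested dicts (not the real tensordict code)."""
    result = dict(d)  # shallow copy, like the non-inplace path
    nested = {}
    for key in keys:
        key = (key,) if isinstance(key, str) else tuple(key)
        if len(key) == 1:
            result.pop(key[0], None)  # top-level key: just drop it
        else:
            # group remaining sub-keys by their root key
            nested.setdefault(key[0], []).append(key[1:])
    for root, sub_keys in nested.items():
        val = result.get(root, None)
        if val is not None:
            val = exclude_buggy(val, *sub_keys)
        # BUG: runs even when the root key is absent, so it stores root -> None
        result[root] = val
    return result


next_preset = {"d": [0.0, 0.0, 0.0]}
print(exclude_buggy(next_preset, ("a", "b")))
# {'d': [0.0, 0.0, 0.0], 'a': None}
```

The stray `'a': None` entry is the key that shows up in `td` in the report above.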

Importantly, this also breaks `TransformedEnv._step` in TorchRL:

# TransformedEnv
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
    tensordict = tensordict.clone(False)
    next_preset = tensordict.get("next", None)
    tensordict_in = self.transform.inv(tensordict)
    next_tensordict = self.base_env._step(tensordict_in)
    if next_preset is not None:
        # tensordict could already have a "next" key
        # this could be done more efficiently by not excluding but just passing
        # the necessary keys
        next_tensordict.update(
            next_preset.exclude(*next_tensordict.keys(True, True))
        )
    self.base_env._complete_done(self.base_env.full_done_spec, next_tensordict)
    # we want the input entries to remain unchanged
    next_tensordict = self.transform._step(tensordict, next_tensordict)
    return next_tensordict

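whenever `next_preset` is not None, because `next_preset.exclude(*next_tensordict.keys(True, True))` is then called with nested keys whose roots may be missing from `next_preset`.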
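The comment in `_step` notes that excluding is a roundabout way to merge: the apparent intent is to copy from `next_preset` only the entries that `next_tensordict` does not already have. A minimal sketch of that intent on plain nested dicts is below; the helper `fill_missing` is hypothetical (not part of TorchRL or tensordict), and it may differ from the exclude-then-update approach in edge cases such as a key that is a leaf on one side and a sub-dict on the other.

```python
def fill_missing(target, preset):
    """Copy entries from preset into target only where target lacks them (in place)."""
    for key, value in preset.items():
        if key not in target:
            target[key] = value
        elif isinstance(value, dict) and isinstance(target[key], dict):
            # both sides have a sub-dict: recurse to fill in missing nested entries
            fill_missing(target[key], value)
        # otherwise target already has this entry: keep target's value
    return target


next_tensordict = {"a": {"b": 1}, "reward": 0.5}
next_preset = {"a": {"b": 99, "c": 2}, "d": 3}
print(fill_missing(next_tensordict, next_preset))
# {'a': {'b': 1, 'c': 2}, 'reward': 0.5, 'd': 3}
```

Because it never looks up a root key that is absent from `target`, this formulation cannot introduce empty or `None` entries.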

Reason and Possible fixes

This appears to be an indentation error in `TensorDict._exclude` in `_td.py`:

def _exclude(self, ...):
    ...
    if keys_to_exclude is not None:
        for key, cur_keys in keys_to_exclude.items():
            val = _tensordict.get(key, None)
            if val is not None:
                val = val._exclude(
                    *cur_keys, inplace=inplace, set_shared=set_shared
                )
            # BUG: missing indent; this block should be nested under `if val is not None:`
            if not inplace:
                # when `key` is absent, val is None and this inserts key -> None into the result
                _tensordict[key] = val
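With the assignment moved under `if val is not None:`, the entry for a root key is only rewritten when that key actually exists. A minimal pure-Python sketch of the corrected control flow, again using nested dicts as stand-ins for TensorDicts (the name `exclude_fixed` is hypothetical, not tensordict API):

```python
def exclude_fixed(d, *keys):
    """Model of the corrected exclude logic on nested dicts (not the real tensordict code)."""
    result = dict(d)  # shallow copy: the input is left untouched
    nested = {}
    for key in keys:
        key = (key,) if isinstance(key, str) else tuple(key)
        if len(key) == 1:
            result.pop(key[0], None)  # top-level key: just drop it
        else:
            # group remaining sub-keys by their root key
            nested.setdefault(key[0], []).append(key[1:])
    for root, sub_keys in nested.items():
        val = result.get(root, None)
        if val is not None:
            # only rewrite the root entry when it actually exists
            result[root] = exclude_fixed(val, *sub_keys)
    return result


print(exclude_fixed({"d": [0.0, 0.0, 0.0]}, ("a", "b")))
# {'d': [0.0, 0.0, 0.0]}
print(exclude_fixed({"a": {"b": 0, "c": 1}, "d": 2}, ("a", "b")))
# {'a': {'c': 1}, 'd': 2}
```

Excluding a nested key whose root is missing is now a no-op, while excluding one whose root is present still removes the nested entry.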

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
btx0424 added the bug (Something isn't working) label on Jan 24, 2024
@vmoens
Contributor

vmoens commented Jan 24, 2024

Should be easy to fix, let me patch this!

vmoens linked a pull request on Jan 24, 2024 that will close this issue