
[Feature Request] Support tuple values in TensorDictModule in_keys arguments. #1099

bachdj-px opened this issue Nov 20, 2024 · 0 comments
Labels: enhancement (New feature or request)


bachdj-px commented Nov 20, 2024

Motivation

One limitation of passing a dictionary as in_keys to the current TensorDictModule is that a given key of the input TensorDict cannot be mapped to more than one argument of the wrapped function (or of the wrapped Module's forward).

For instance, one can do:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
module = TensorDictModule(lambda x, *, y: x + y, in_keys={'1': 'x', '2': 'y'}, out_keys=['z'])
td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(()) * 2}, []))
print(td['z'])  # tensor(3.)

but the following fails, as expected:

module = TensorDictModule(lambda x, *, y, t: x + y + t, in_keys={'1': 'x', '2': ('y', 't')}, out_keys=['z'])
module(TensorDict({'1': torch.ones(()), '2': torch.ones(()) * 2}, []))
# TypeError: keywords must be strings

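There are clear cases where passing the same input value to two different arguments of the wrapped function or module would be desirable, for instance self-attention, where torch.nn.MultiheadAttention's forward receives the same tensor as query, key and value.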
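A plain-Python illustration of why the current dictionary form cannot express this on its own: a dict literal silently keeps only the last value for a repeated key, so writing the same TensorDict key twice drops one of the mappings. The dictionary below is only a stand-in for the in_keys argument.

```python
# Attempt to map TensorDict key '2' to two keyword arguments by repeating it.
# Python keeps a single entry for '2', holding the last value written ('t').
in_keys = {'1': 'x', '2': 'y', '2': 't'}
print(in_keys)  # {'1': 'x', '2': 't'}
```

This is why the proposal below uses a tuple or list as the dictionary value rather than a repeated key.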

Could you clarify whether this is a deliberate design choice, or whether there is a reason this feature would cause problems?

Solution

A simple solution would be to allow tuples and lists as values in the in_keys dictionary and parse them accordingly. This can be done by iterating over each such value and appending one entry per keyword to both the list of in_keys and the list of keyword names:

if isinstance(in_keys, dict):
    # write the kwargs and create a list instead;
    # a tuple/list value maps one TensorDict key to several kwargs,
    # so that key is repeated once per kwarg in _in_keys
    _in_keys = []
    self._kwargs = []
    for key, value in in_keys.items():
        if isinstance(value, (tuple, list)):
            for _value in value:
                self._kwargs.append(_value)
                _in_keys.append(key)
        else:
            self._kwargs.append(value)
            _in_keys.append(key)
    in_keys = _in_keys
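The parsing step can be checked on its own, outside of TensorDictModule. The sketch below is a minimal, hypothetical helper (parse_dict_in_keys is not part of tensordict) that applies the same logic to a plain dict and returns the two lists instead of setting attributes on self.

```python
from typing import Dict, List, Sequence, Tuple, Union


def parse_dict_in_keys(
    in_keys: Dict[str, Union[str, Sequence[str]]],
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: flatten a {tensordict_key: kwarg(s)} mapping.

    A tuple/list value maps one TensorDict key to several keyword arguments,
    so that key is repeated once per keyword in the returned in_keys list.
    """
    flat_in_keys: List[str] = []
    kwarg_names: List[str] = []
    for key, value in in_keys.items():
        if isinstance(value, (tuple, list)):
            # One entry per keyword, all reading the same TensorDict key.
            for name in value:
                flat_in_keys.append(key)
                kwarg_names.append(name)
        else:
            # Plain string value: the existing one-to-one behaviour.
            flat_in_keys.append(key)
            kwarg_names.append(value)
    return flat_in_keys, kwarg_names


print(parse_dict_in_keys({'1': 'x', '2': ('y', 't')}))
# (['1', '2', '2'], ['x', 'y', 't'])
```

Plain string values go through the else branch unchanged, so existing dictionaries such as {'1': 'x', '2': 'y'} parse exactly as before.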

In our second example, this would result in:

in_keys = ["1", "2", "2"]
self._kwargs = ["x", "y", "t"]
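With these flattened lists, the call side only has to read each in_key once per occurrence and pair it with the matching keyword name. The sketch below assumes every input is passed to the wrapped callable as a keyword argument, uses a plain dict of floats as a stand-in for the TensorDict, and call_with_kwargs is a hypothetical name rather than a tensordict function.

```python
from typing import Any, Callable, Dict, List


def call_with_kwargs(
    fn: Callable[..., Any],
    data: Dict[str, Any],
    in_keys: List[str],
    kwarg_names: List[str],
) -> Any:
    """Hypothetical stand-in for the module call; `data` plays the TensorDict."""
    if len(in_keys) != len(kwarg_names):
        raise ValueError("in_keys and kwarg names must have the same length")
    # The same data key may appear several times; each occurrence feeds
    # a different keyword argument of `fn`.
    kwargs = {name: data[key] for key, name in zip(in_keys, kwarg_names)}
    return fn(**kwargs)


data = {'1': 1.0, '2': 2.0}
result = call_with_kwargs(
    lambda x, *, y, t: x + y + t,
    data,
    in_keys=["1", "2", "2"],
    kwarg_names=["x", "y", "t"],
)
print(result)  # 5.0
```

Because every keyword name is still a string, the call no longer raises "keywords must be strings", and key '2' is simply read twice.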

Additional context

I haven't identified any undesired side effects of supporting tuple or list values in in_keys.

Checklist

  • I have checked that there is no similar issue in the repo (required)