-
Notifications
You must be signed in to change notification settings - Fork 319
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Max Value Writer #1622
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks promising thanks a lot for this
Can we add it to the docs?
for sample in data: | ||
self.add(sample) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this efficient when data is a TensorDict?
Shouldn't we batch these ops?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How? if we need to check whether or not the values are higher that the current stored values we need to check them one by one.
maybe we can pre-sort the td values or something and try to make it a bit more efficient
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks for this! I left some suggestions
rank_data = data.get("_data").get(self._rank_key) | ||
|
||
# Sum the rank key, in case it is a whole trajectory | ||
rank_data = rank_data.sum().item() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this safe?
Maybe we should document what are the expected shapes for this class, eg
[B, T]
but not
[B1, B2, T]
Another option is to check the number of dimensions of the ranking key OR the name of the last dim of the input tensordict (which should be "time").
Not raising any exception and just doing a plain sum could lead to surprising results I think
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added the first option. Since the ranking value has to be a single float we only allow data of the shape [] and [T] for the add
method and [B] and [B, T] for the extend
method. If data has a time dimension, we sum along it. If too many dimensions are provided, an error is raised.
I did not go for checking the dimension names because it seemed to restrictive. I don't think time dimension is always labelled
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not always but mostly
if you get your data from env.rollout or collector, it will.
If from there you store the data in a rb, it will keep the tag.
But if you reshape or do other stuff it could go away.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks so much!
Description
A Writer class for composable replay buffers that keeps the top elements based on some ranking key.
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax
close #15213
if this solves the issue #15213Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!