
[Feature] Max Value Writer #1622

Merged
merged 22 commits into from
Oct 18, 2023

Conversation

albertbou92
Contributor

@albertbou92 albertbou92 commented Oct 10, 2023

Description

A Writer class for composable replay buffers that keeps the top elements based on some ranking key.
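The idea can be sketched with a small, self-contained class (a stdlib sketch, not the TorchRL implementation; the `MaxValueWriter` name, `rank_key` argument, and dict samples here are illustrative assumptions):

```python
import heapq

class MaxValueWriter:
    """Minimal sketch: keep only the top-`capacity` samples,
    ranked by a scalar value stored under `rank_key`."""

    def __init__(self, capacity, rank_key="reward"):
        self.capacity = capacity
        self.rank_key = rank_key
        self._heap = []     # min-heap of (rank, insertion_id, sample)
        self._next_id = 0   # tie-breaker so samples are never compared

    def add(self, sample):
        rank = float(sample[self.rank_key])
        entry = (rank, self._next_id, sample)
        self._next_id += 1
        if len(self._heap) < self.capacity:
            heapq.heappush(self._heap, entry)
        elif rank > self._heap[0][0]:
            # New sample outranks the current minimum: replace it.
            heapq.heapreplace(self._heap, entry)

    def samples(self):
        return [s for _, _, s in self._heap]

writer = MaxValueWriter(capacity=2)
for r in [1.0, 5.0, 3.0, 0.5]:
    writer.add({"reward": r})
# Only the two highest-reward samples survive.
ranks = sorted(s["reward"] for s in writer.samples())
```

Using a min-heap makes each insertion O(log capacity): the stored minimum is always at the root, so deciding whether a new sample enters is a single comparison.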

Motivation and Context

Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax close #15213 if this solves the issue #15213

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 10, 2023
@albertbou92 albertbou92 changed the title [Feature] Max Value Writer [Feature, WIP] Max Value Writer Oct 10, 2023
Contributor

@vmoens vmoens left a comment


Looks promising, thanks a lot for this!
Can we add it to the docs?

torchrl/data/replay_buffers/writers.py Show resolved Hide resolved
torchrl/data/replay_buffers/writers.py Outdated Show resolved Hide resolved
torchrl/data/replay_buffers/writers.py Outdated Show resolved Hide resolved
torchrl/data/replay_buffers/writers.py Outdated Show resolved Hide resolved
Comment on lines 155 to 156
for sample in data:
    self.add(sample)
Contributor


Is this efficient when data is a TensorDict?
Shouldn't we batch these ops?

Contributor Author

@albertbou92 albertbou92 Oct 11, 2023


How? If we need to check whether the values are higher than the currently stored values, we need to check them one by one.

Maybe we can pre-sort the tensordict values or something to try to make it a bit more efficient.
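One way the pre-sorting idea could look (a hypothetical sketch, not code from this PR): sort the incoming batch by rank, best first, and stop inserting as soon as a candidate fails to beat the stored minimum, since every remaining candidate ranks even lower:

```python
import heapq

def extend_sorted(heap, capacity, batch):
    """Insert a batch of (rank, sample) pairs best-first into a min-heap
    of at most `capacity` entries, with an early exit once a candidate
    cannot enter. (In practice a tie-breaker id would be needed so that
    samples with equal ranks are never compared directly.)"""
    for rank, sample in sorted(batch, key=lambda p: p[0], reverse=True):
        if len(heap) < capacity:
            heapq.heappush(heap, (rank, sample))
        elif rank > heap[0][0]:
            heapq.heapreplace(heap, (rank, sample))
        else:
            break  # everything that follows ranks even lower
    return heap

heap = []
batch = [(2.0, 0), (9.0, 1), (4.0, 2), (1.0, 3), (7.0, 4)]
extend_sorted(heap, 3, batch)
kept = sorted(r for r, _ in heap)
```

The sort costs O(n log n) up front, but in the common case where most of the batch does not qualify, the loop terminates after the first rejected candidate instead of testing every sample.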

@vmoens vmoens added the enhancement New feature or request label Oct 10, 2023
@albertbou92 albertbou92 changed the title [Feature, WIP] Max Value Writer [Feature] Max Value Writer Oct 17, 2023
Contributor

@vmoens vmoens left a comment


LGTM thanks for this! I left some suggestions

torchrl/data/replay_buffers/writers.py Outdated Show resolved Hide resolved
torchrl/data/replay_buffers/writers.py Outdated Show resolved Hide resolved
torchrl/data/replay_buffers/writers.py Outdated Show resolved Hide resolved
torchrl/data/replay_buffers/writers.py Outdated Show resolved Hide resolved
torchrl/data/replay_buffers/writers.py Outdated Show resolved Hide resolved
torchrl/data/replay_buffers/writers.py Outdated Show resolved Hide resolved
rank_data = data.get("_data").get(self._rank_key)

# Sum the rank key, in case it is a whole trajectory
rank_data = rank_data.sum().item()
Contributor


Is this safe?
Maybe we should document the expected shapes for this class, e.g.

[B, T]

but not

[B1, B2, T]

Another option is to check the number of dimensions of the ranking key OR the name of the last dim of the input tensordict (which should be "time").

Not raising any exception and just doing a plain sum could lead to surprising results, I think.

Contributor Author

@albertbou92 albertbou92 Oct 18, 2023


I added the first option. Since the ranking value has to be a single float, we only allow data of shape [] and [T] for the add method, and [B] and [B, T] for the extend method. If the data has a time dimension, we sum along it. If too many dimensions are provided, an error is raised.

I did not go for checking the dimension names because it seemed too restrictive. I don't think the time dimension is always labelled.
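The shape rule described above could be sketched in plain Python (nested lists standing in for tensors; the function names are illustrative, not the PR's actual code):

```python
def to_scalar_rank(rank_data):
    """For `add`: rank data of shape [] (a bare number) or [T] (a flat
    list) is reduced to one float; anything deeper raises, mirroring the
    rule that the ranking value must end up as a single scalar."""
    if isinstance(rank_data, (int, float)):
        return float(rank_data)       # shape []
    if all(isinstance(x, (int, float)) for x in rank_data):
        return float(sum(rank_data))  # shape [T]: sum over the time dim
    raise ValueError("rank key must have shape [] or [T] in add()")

def to_scalar_ranks(rank_data):
    """For `extend`: shape [B] or [B, T] becomes one float per element."""
    return [to_scalar_rank(x) for x in rank_data]
```

A batch shaped [B1, B2, T] fails inside `to_scalar_rank` rather than being silently summed, which matches the reviewer's concern about a plain sum producing surprising results.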

Contributor


Not always, but mostly: if you get your data from env.rollout or a collector, it will be.
If from there you store the data in a replay buffer, it will keep the tag.
But if you reshape or do other operations, it could go away.

torchrl/data/replay_buffers/writers.py Outdated Show resolved Hide resolved
torchrl/data/replay_buffers/writers.py Outdated Show resolved Hide resolved
Contributor

@vmoens vmoens left a comment


LGTM thanks so much!

torchrl/data/replay_buffers/writers.py Outdated Show resolved Hide resolved
torchrl/data/replay_buffers/writers.py Outdated Show resolved Hide resolved
@vmoens vmoens merged commit 55d667e into pytorch:main Oct 18, 2023
52 of 59 checks passed
@vmoens vmoens deleted the max_val_writer branch October 18, 2023 17:51