-
Notifications
You must be signed in to change notification settings - Fork 315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Hindsight Experience Replay Transform #1819
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1819
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 19 Unrelated FailuresAs of commit 90eef75 with merge base 57139bd (): NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
||
augmentation_td = TensorDict( | ||
{ | ||
"observation": sampled_td.get("observation").repeat_interleave( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we keep it a transform we probably need to specify all those tensordict keys ... Not sure what a better alternative would be. Any idea?
|
||
def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: | ||
augmentation_td = self.her_augmentation(tensordict) | ||
return torch.cat([tensordict, augmentation_td], dim=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As explained above. It doesnt feel like a transform
as we create a new tensordict and have to combine original and augmented data before adding to the replay buffer. I think ideally the "augmentations" would done directly after the collection. So as a postproc for collectors or as here in the example as an inverse_transform for the replay buffer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not a transform? I think it's pretty neat to use a transform. Who said a transform had to change things in-place?
Our API to modify samples at writing time is to use either a transform or a different writer. If you think this can be achieved with a writer I'm on board. But I don't think there's anything wrong with the transform.
An advantage of using a writer instead os that it feels more natural (transforms can be used with envs unless specified otherwise, writers are dedicated to RBs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ahmed-touati suggested we use a sampler for this rather than a transform. I'm not strongly opinionated on the matter, mostly because I need more context on what we're trying to achieve here.
Can you elaborate a bit more on what this transform does, maybe with a bunch of examples?
So HER is mainly used in goal-conditioned RL with sparse reward signals where the agent has to reach/achieve a goal state and only gets a reward (+1) when the goal state is achieved, otherwise no reward. The observation consists of three elements: the observation the agent sees, the state the agent had (could be x,y,z position), and the goal state the agent should reach (x,y,z). A typical task could be a robot that has to reach a goal position. The observation will include the agent position but its mostly added as additional information also helps here for understanding. Now as we have a sparse reward function most of the trajectories will have no learning signal for the agent as it might not be possible for the agent to reach the goal position randomly or by pure luck. So lets say you have a real transition (obs, action, reward, done, next obs, achieved_position, goal_position) for this tuple you now want to sample a new goal_position and then calculate the reward based on this new goal_position and the real achieved_position. So you then add the real transition (obs, action, reward, done, next obs, achieved_position, goal_position) but also the HER augmented transition (obs, action, new_reward, done, new next obs, achieved_position, new_goal_position). The sampling can happen in different ways but is not important for now. However, I think important will be that we need the reward function, Im not sure if we can pass it to the writer/sampler for the buffer, that's why my first thought was a transform. Most of the time the reward function might just be Euclidean distance but maybe for other tasks the user needs to provide a more sophisticated reward function. |
Why not? I would guess that even if it's a complex nn.Module you can still do pretty much everything with a well tailored function (at least nothing less than with a transform). |
Thanks for the context btw! |
Revisiting this I think it would make much more sense to do it with a writer. We want to augment current incoming data with new sampled goal states and store them all together in the buffer. I think this would be generally a good way to add other data augmentation strategies with writer instead of transforms. Having a closer look right now on the writer classes and will update the code here |
But this would not allow us to stack multiple augmentations on top of each other... so maybe not that ideal for augmentations |
You could still transform your data before passing it to the writer, but not after |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While hindsight experience replay is pretty useful, I think it falls under the category of specialized algorithm rather than a building block @vmoens
new_goals.append(splitted_achieved_goals[i][ids]) | ||
|
||
# calculate rewards given new desired goals and old achieved goals | ||
vmap_rewards = torch.vmap(distance_reward_function) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you wanna call self.reward_function
instead of distance_reward_function
. Also maybe the reward_function
should be a TensorDictModule such it can be more easily customized for a given environment. There is torchrl.modules.VmapModule
for wrapping TensorDictModules with vmap.
cat_rewards = torch.cat(rewards).reshape(b, t, self.samples, -1).squeeze(-1) | ||
cat_new_goals = torch.cat(new_goals).reshape(b, t, self.samples, -1) | ||
|
||
augmentation_td = TensorDict( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the augmentation_td
should still maintain other metadata that are related to the state rather than selecting only the keys: observation, action, terminated, truncated, ...
Not sure about that one. |
Description
Adds Hindsight Experience Replay (HER) Transform
Motivation and Context
The first draft for the HER transform. However, I am not sure if it should be a
Transform
or if we create an extraAugmentation
class as we are not transforming a single element in the tensordict but augmenting existing collection data. Could be interesting for future "data augmentation strategies", which I think we do not have until now.Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!