
[Feature] Lightning integration example #2057

Open · wants to merge 9 commits into main
Conversation

svnv-svsv-jm

Description

This PR offers a convenient `lightning.pytorch.LightningModule` base class from which one can inherit to train a torchrl model with Lightning.
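For readers skimming the thread, a minimal sketch of the pattern this PR enables. The class name `BaseRL` is taken from the PR's snippets below; the constructor signature and loss handling are illustrative assumptions, not the PR's actual API:

```python
# Illustrative sketch only: the constructor and loss handling below are
# assumptions, not this PR's actual implementation.
import lightning.pytorch as pl
import torch
from tensordict import TensorDict
from torchrl.objectives import LossModule


class BaseRL(pl.LightningModule):
    """Minimal Lightning wrapper around a torchrl loss module."""

    def __init__(self, loss_module: LossModule, lr: float = 3e-4) -> None:
        super().__init__()
        self.loss_module = loss_module
        self.lr = lr

    def training_step(self, batch: TensorDict, batch_idx: int) -> torch.Tensor:
        # torchrl loss modules return a TensorDict of named loss terms;
        # sum every entry whose key starts with "loss_" into one scalar.
        loss_td = self.loss_module(batch)
        loss = sum(v for k, v in loss_td.items() if k.startswith("loss_"))
        self.log("loss/train", loss)
        return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(self.loss_module.parameters(), lr=self.lr)
```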

Motivation and Context

This PR is inspired by this issue: Lightning-Universe/lightning-bolts#986

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.


pytorch-bot bot commented Apr 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2057

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @svnv-svsv-jm!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot added the CLA Signed label (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.) Apr 4, 2024
vmoens
vmoens previously approved these changes Apr 4, 2024
@vmoens (Contributor) left a comment

Pretty impressive work! It'll take me some time to review it.
Cc'ing people who may be interested
@giadefa @albertbou92 @BY571 @tchaton

@vmoens dismissed their stale review April 5, 2024 05:19

Approved by mistake

@vmoens changed the title from "Lightning integration example" to "[Feature] Lightning integration example" Apr 7, 2024
@vmoens added the enhancement label (New feature or request) Apr 7, 2024
@vmoens (Contributor) left a comment

I'm leaning towards accepting a lightning backend for the trainers, but given my limited bandwidth I won't have much time to scale the PPO trainer to others.

A recipe for other trainers would probably help others code their own, for instance a tutorial (in a second PR)?
Are you considering adding other trainers?

Since this is a new "trainer", I think it should be moved to torchrl/trainers, and we will need to make the two APIs somewhat compatible.

An example in torchrl/examples would be welcome!

I didn't do a very in-depth review, just the pieces I could easily spot, with an eye to more homogeneous formatting within the lib!

__all__ = ["BaseRL"]

import typing as ty
from loguru import logger
Contributor:

Unless this is packaged with lightning, we won't be using it: torchrl has a logger under `torchrl._utils`.

Author:

Yep, I forgot to remove loguru; it's used in the project I copied most of this code from. I will check out torchrl's logger or remove logging here entirely.
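For reference, a minimal sketch of the swap, assuming `torchrl._utils` exposes a standard `logging.Logger` named `logger`:

```python
from torchrl._utils import logger

logger.info("collector initialized")  # replaces loguru's logger.info(...)
```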

@@ -0,0 +1,205 @@
"""Creates a helper class for more complex models."""

__all__ = ["BaseRL"]
Contributor:

We don't use `__all__`; we import relevant classes in `__init__.py` instead.
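A hypothetical sketch of that convention (the `base` module path is a guess; `BaseRL` comes from the snippet above):

```python
# torchrl/lightning/__init__.py: re-export the public class here
# instead of declaring __all__ in the defining module.
from torchrl.lightning.base import BaseRL  # noqa: F401
```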

@@ -0,0 +1,205 @@
"""Creates a helper class for more complex models."""
Contributor:

Missing license header.
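For context, torchrl source files open with a license header along these lines (check the repo for the exact current wording):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
```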

@@ -0,0 +1,33 @@
__all__ = ["find_device"]
Contributor:

Missing license header.
We don't use `__all__` (see above).

torchrl/lightning/accelerators.py (comment outdated, resolved)
this will never be called by the `pl.Trainer` in the `on_train_epoch_end` hook.
We have to call it manually in the `training_step`."""
scheduler = self.lr_schedulers()
assert isinstance(scheduler, LRScheduler)
Contributor:

no assert in the codebase
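A sketch of an assert-free equivalent of the fragment above; the exception type and message are illustrative, not prescribed by the reviewer:

```python
from torch.optim.lr_scheduler import LRScheduler

scheduler = self.lr_schedulers()
if not isinstance(scheduler, LRScheduler):
    raise TypeError(
        f"Expected a single LRScheduler, got {type(scheduler).__name__}."
    )
```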

batch_idx: int = 0,
tag: str = "train",
) -> Tensor:
"""Common step."""
Contributor:

Can we expand this a bit?
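One possible expansion, assuming this is a `common_step` method shared by the train/val/test steps (the wrapper class and docstring wording are guesses from the signature above):

```python
from tensordict import TensorDict
from torch import Tensor


class _DocstringSketch:  # hypothetical host class for the method
    def common_step(
        self,
        batch: TensorDict,
        batch_idx: int = 0,
        tag: str = "train",
    ) -> Tensor:
        """Shared logic for the training/validation/test steps.

        Args:
            batch: TensorDict of collected rollout data.
            batch_idx: Index of the current batch within the epoch.
            tag: Prefix for logged metrics, e.g. ``"train"`` or ``"val"``.

        Returns:
            The total loss, summed over the individual loss terms.
        """
        raise NotImplementedError
```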

loss = loss + value
loss_dict[f"{key}/{tag}"] = value
# Sanity check and return
assert isinstance(loss, torch.Tensor)
Contributor:

no assert in codebase

Comment on lines 1 to 3
"""Template for a PPO model on the pendulum env, for Lightning."""

__all__ = ["PPOPendulum"]
Contributor:

Missing license header, no `__all__`.

import typing as ty

import torch
from tensordict import TensorDict # type: ignore
Contributor:

Why the `# type: ignore`?

Author:

mypy in my VSCode setup was complaining, so I mypy-ignored the line. This is not necessary in the codebase.

@svnv-svsv-jm (Author)

I will take care of the comments and add an example under torchrl/examples, but I will leave moving the code to torchrl/trainers for later.

@vmoens (Contributor) commented Jun 12, 2024

@svnv-svsv-jm I think that just moving it to trainers without a change of API would already be a good thing!

@svnv-svsv-jm (Author)

> @svnv-svsv-jm I think that just moving it to trainers without a change of API would already be a good thing!

I moved it by allowing `from torchrl.trainers import RLTrainingLoop`, even though the actual source code still lives under `torchrl.lightning`.
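Presumably the shim looks something like this (path and spelling assumed from the comment above):

```python
# torchrl/trainers/__init__.py
from torchrl.lightning import RLTrainingLoop  # noqa: F401
```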

@vmoens (Contributor) commented Jun 19, 2024

Not sure that cuts it, unfortunately. The long-term goal is that there is no torchrl.lightning, so if we merge this as-is, deprecating it will be hard on the user side.

@svnv-svsv-jm (Author)

@vmoens No problem! Just wanted to make sure I understood correctly. So I will move torchrl.lightning to torchrl.trainers entirely.
