Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Train] Unify Torch based Trainers on the TorchTrainer API #37

Merged
merged 11 commits into from
Jul 31, 2023

Conversation

matthewdeng
Copy link
Contributor

This REP proposes to remove the LightningTrainer, TransformersTrainer, and AccelerateTrainer APIs and unify them on the TorchTrainer API.

reps/2023-07-20-torch-trainer-apis.md Show resolved Hide resolved
reps/2023-07-20-torch-trainer-apis.md Show resolved Hide resolved
reps/2023-07-20-torch-trainer-apis.md Outdated Show resolved Hide resolved
Copy link
Contributor

@ericl ericl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about let's cover the proposed changes to TorchCheckpoint (and Checkpoint<>Trainer a bit in general) as well in this REP?

In addition, as you brought up we can also add examples of providing datasets=... in this REP for completeness.

Copy link
Contributor

@krfricke krfricke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much in favor for this change

reps/2023-07-20-torch-trainer-apis.md Outdated Show resolved Hide resolved
reps/2023-07-20-torch-trainer-apis.md Outdated Show resolved Hide resolved
3. The user cannot see the internal training loop, which makes it hard to implement and debug their code.
2. The `LightningTrainer` and `TransformersTrainer` APIs are opionated and may not allow the user to fully express their desired training logic (e.g. validation and testing).

This proposal explores the idea of centralizing on a single `TorchTrainer` as the single way running training code for PyTorch-based frameworks in a distributed fashion.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also want to call out that we do this for TensorflowTrainer (Native TF vs. Keras) already (though the surface area is smaller with only 2 APIs).

Signed-off-by: Matthew Deng <[email protected]>
Signed-off-by: Matthew Deng <[email protected]>
callbacks=[checkpoint_callback],
devices=ray.train.lightning.get_devices(),
strategy=ray.train.lightning.RayDDPStrategy(),
plugins=[ray.train.lightning.RayEnvironment()],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
plugins=[ray.train.lightning.RayEnvironment()],
plugins=[ray.train.lightning.RayEnvPlugin()],

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm this is a subclass of the LightningEnvironment class which is a Plugin.

What about RayLightningEnvironment?

Copy link
Contributor

@ericl ericl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to move forward with broader review/vote.

Copy link
Contributor

@krfricke krfricke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG

Copy link
Member

@woshiyyya woshiyyya left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Let's make it happen

Copy link

@YiranJing YiranJing left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Vote for this proposal!

That's great! Migrating a normal PyTorch Lightning model to Ray Trainer will be much easier then 🎉 , and it will require less effort for users familiar with PyTorch Lightning but new to Ray.

Comment on lines +367 to +368
eval_dataset = ray.train.get_dataset_shard("eval")
eval_dataset = RayDataIterableDataset(val_dataset)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
eval_dataset = ray.train.get_dataset_shard("eval")
eval_dataset = RayDataIterableDataset(val_dataset)
eval_dataset_shard = ray.train.get_dataset_shard("eval")
eval_dataset = RayDataIterableDataset(val_dataset_shard)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah in this example I wanted to intentionally show that only train is being sharded (since there is no DataConfig specified).

Copy link

@YiranJing YiranJing Jul 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get it! the function name (ray.train.get_dataset_shard("eval")) is a bit confusing -> makes me feel the example is also sharding the validation dataset -> this issue should be able to addressed in ray-project/ray#37668


train_dataset = ray.data.read_parquet(...).map_batches(...)
eval_dataset = ray.data.read_parquet(...).map_batches(...)
trainer = ray.train.torch.TorchTrainer(train_func, datasets={"train": train_dataset, "eval": eval_dataset})

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, we require the use of DataConfig for correctly validation data sharding, e.g.:

trainer = ray.train.torch.TorchTrainer(
   train_func, 
   datasets={"train": train_dataset, "eval": eval_dataset},
   data_config=DataConfig(datasets_to_split=["train", "eval"]) # required
)

@woshiyyya I agree it would be good and result in fewer bugs if users could easily perform dataset sharding without the need for data_config.

Copy link
Member

@woshiyyya woshiyyya Jul 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @YiranJing , thanks for the feedback! I agree that sharding validation datasets by default make a lot of sense.

But we may still need to keep data_config, but we can slightly change the default behavior so that most users don't have to provide an extra DataConfig in TorchTrainer.

If some users do have special use case, and want to have a unsharded dataset on each worker, they can then provide a DataConfig. e.g.

trainer = ray.train.torch.TorchTrainer(
   train_func, 
   datasets={"train": train_dataset, "eval": eval_dataset, "special": no_split_dataset}
   data_config=DataConfig(datasets_no_split=["special"])
)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for clarifying!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@YiranJing could you share more about your use case and desired behavior in ray-project/ray#37668?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhe-thoughts zhe-thoughts merged commit 6ba5e26 into main Jul 31, 2023
@zhe-thoughts
Copy link
Collaborator

Great work and thanks @matthewdeng

@jjyao jjyao deleted the torch-trainer-apis branch October 2, 2023 20:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants