[Train] Unify Torch based Trainers on the TorchTrainer API #37
Conversation
How about we also cover the proposed changes to `TorchCheckpoint` (and the `Checkpoint`<>`Trainer` relationship in general) in this REP?
In addition, as you brought up, we could also add examples of providing `datasets=...` in this REP for completeness.
Much in favor of this change.
2. The `LightningTrainer` and `TransformersTrainer` APIs are opinionated and may not allow the user to fully express their desired training logic (e.g. validation and testing).
3. The user cannot see the internal training loop, which makes it hard to implement and debug their code.

This proposal explores the idea of centralizing on a single `TorchTrainer` as the single way of running training code for PyTorch-based frameworks in a distributed fashion.
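For reference, a minimal sketch of that single entry point, based on the public Ray Train API (the exact signatures in this REP may differ):

    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_func():
        # Framework-specific code (vanilla PyTorch, Lightning, or
        # Transformers) runs here on every distributed worker.
        ...

    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    )
    result = trainer.fit()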
I also want to call out that we already do this for `TensorflowTrainer` (native TF vs. Keras), though the surface area is smaller with only two APIs.
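For comparison, a rough sketch of that existing `TensorflowTrainer` pattern, assuming the public Ray Train API (details vary by Ray version):

    from ray.train import ScalingConfig
    from ray.train.tensorflow import TensorflowTrainer

    def train_func():
        # Both a native TF training loop and a Keras model.fit()
        # call are expressed through this one entry point.
        ...

    trainer = TensorflowTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=2),
    )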
    callbacks=[checkpoint_callback],
    devices=ray.train.lightning.get_devices(),
    strategy=ray.train.lightning.RayDDPStrategy(),
    plugins=[ray.train.lightning.RayEnvironment()],
Suggested change:

    - plugins=[ray.train.lightning.RayEnvironment()],
    + plugins=[ray.train.lightning.RayEnvPlugin()],
Hmm, this is a subclass of the `LightningEnvironment` class, which is a `Plugin`. What about `RayLightningEnvironment`?
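To illustrate, the surrounding snippet with that suggested name would read roughly as follows (a sketch reusing identifiers from this REP's example; `checkpoint_callback` and `get_devices` come from that example, and `RayLightningEnvironment` is the proposed name, not a settled API):

    import pytorch_lightning as pl
    import ray.train.lightning

    trainer = pl.Trainer(
        callbacks=[checkpoint_callback],  # defined earlier in the REP example
        devices=ray.train.lightning.get_devices(),
        strategy=ray.train.lightning.RayDDPStrategy(),
        # The LightningEnvironment subclass is registered via `plugins`,
        # hence the suggested RayLightningEnvironment name.
        plugins=[ray.train.lightning.RayLightningEnvironment()],
    )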
Good to move forward with broader review/vote.
LG
LGTM. Let's make it happen
Vote for this proposal!
That's great! Migrating a normal PyTorch Lightning model to the Ray Trainer will then be much easier 🎉, and it will require less effort from users who are familiar with PyTorch Lightning but new to Ray.
    eval_dataset = ray.train.get_dataset_shard("eval")
    eval_dataset = RayDataIterableDataset(val_dataset)
Suggested change:

    - eval_dataset = ray.train.get_dataset_shard("eval")
    - eval_dataset = RayDataIterableDataset(val_dataset)
    + eval_dataset_shard = ray.train.get_dataset_shard("eval")
    + eval_dataset = RayDataIterableDataset(val_dataset_shard)
Ah, in this example I wanted to intentionally show that only `train` is being sharded (since there is no `DataConfig` specified).
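Spelled out, the behavior described above would look roughly like this inside the training function (a sketch reusing this REP's `RayDataIterableDataset` name):

    import ray.train

    def train_func():
        # "train" is split across workers by default.
        train_shard = ray.train.get_dataset_shard("train")
        train_dataset = RayDataIterableDataset(train_shard)

        # With no DataConfig specified, "eval" is not split:
        # each worker sees the full validation dataset.
        eval_shard = ray.train.get_dataset_shard("eval")
        eval_dataset = RayDataIterableDataset(eval_shard)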
Got it! The function name (`ray.train.get_dataset_shard("eval")`) is a bit confusing, since it makes it seem like the example is also sharding the validation dataset. This should be addressable in ray-project/ray#37668.
    train_dataset = ray.data.read_parquet(...).map_batches(...)
    eval_dataset = ray.data.read_parquet(...).map_batches(...)
    trainer = ray.train.torch.TorchTrainer(train_func, datasets={"train": train_dataset, "eval": eval_dataset})
Currently, we require the use of `DataConfig` for correct validation-data sharding, e.g.:

    trainer = ray.train.torch.TorchTrainer(
        train_func,
        datasets={"train": train_dataset, "eval": eval_dataset},
        data_config=DataConfig(datasets_to_split=["train", "eval"])  # required
    )
@woshiyyya I agree it would be good, and would result in fewer bugs, if users could easily perform dataset sharding without the need for `data_config`.
Hi @YiranJing, thanks for the feedback! I agree that sharding validation datasets by default makes a lot of sense.
We may still need to keep `data_config`, but we can slightly change the default behavior so that most users don't have to provide an extra `DataConfig` to `TorchTrainer`.
If some users do have a special use case and want an unsharded dataset on each worker, they can then provide a `DataConfig`, e.g.:
    trainer = ray.train.torch.TorchTrainer(
        train_func,
        datasets={"train": train_dataset, "eval": eval_dataset, "special": no_split_dataset},
        data_config=DataConfig(datasets_no_split=["special"])
    )
Thanks for clarifying!
@YiranJing could you share more about your use case and desired behavior in ray-project/ray#37668?
Replied here: ray-project/ray#37668 (comment)
Great work and thanks @matthewdeng
This REP proposes to remove the `LightningTrainer`, `TransformersTrainer`, and `AccelerateTrainer` APIs and unify them on the `TorchTrainer` API.