# [Train] Unify Torch based Trainers on the TorchTrainer API #37

Merged (11 commits) on Jul 31, 2023.
`reps/2023-07-20-torch-trainer-apis.md` (121 additions, 32 deletions)
While this pattern makes it very obvious to the user that Ray Train provides an integration for these frameworks, it comes with a few drawbacks:
1. There is overhead for the user to learn and use these APIs.
    1. They are all different from each other.
    2. They are different from their corresponding framework.
    3. The user cannot see the internal training loop, which makes it hard to implement and debug their code.
2. The `LightningTrainer` and `TransformersTrainer` APIs are opinionated and may not allow the user to fully express their desired training logic (e.g. validation and testing).

This proposal explores the idea of centralizing on `TorchTrainer` as the single way of running training code for PyTorch-based frameworks in a distributed fashion.
> **Contributor comment:** I also want to call out that we do this for `TensorflowTrainer` (native TF vs. Keras) already (though the surface area is smaller, with only 2 APIs).


This change is motivated by the following goals:
1. **Transparency:** Enables users to convert existing code to use Ray Train with minimal changes.
2. **Flexibility:** Allows users to easily leverage the full suite of functionality of their framework (e.g. `lightning.Trainer.test()`, `transformers.Trainer.predict()`) or even convert their code between frameworks (e.g. Transformers → Accelerate).
3. **Simplicity:** Reduces the surface area of the Ray Train interface, lowering overhead for both users and developers.

### Should this change be within `ray` or outside?
```python
trainer = TorchTrainer(train_func, ...)
trainer.fit()
```

**Reporting & Checkpointing:**

A critical part of this API change is ensuring that the user can report metrics as well as save and load checkpoints with their desired framework. In [[REP] Consolidated persistence API for Ray Train/Tune #35](https://github.com/ray-project/enhancements/pull/35), we introduce a simpler API for Checkpoints in which Ray Train will treat the contents of the Checkpoint as an opaque directory. As part of this current REP, we will show how we can utilize this new API along with the frameworks' native checkpointing APIs to fully support checkpointing, and remove the need for the existing `<Framework>Checkpoint` APIs:

- [TorchCheckpoint](https://docs.ray.io/en/releases-2.6.0/train/api/doc/ray.train.torch.TorchCheckpoint.html)
- [LightningCheckpoint](https://docs.ray.io/en/releases-2.6.0/train/api/doc/ray.train.lightning.LightningCheckpoint.html)
- [TransformersCheckpoint](https://docs.ray.io/en/releases-2.6.0/train/api/doc/ray.train.huggingface.transformers.TransformersCheckpoint.html)

To do this, we rely on a few key utility APIs:
```python
def train_func():
    # Restore checkpoint.
    checkpoint = ray.train.get_context().get_checkpoint()
    ...
    # Create checkpoint.
    checkpoint = ray.train.Checkpoint.from_directory(...)
    ray.train.report(metrics, checkpoint)
```

As an example, we can modify a Torch training script that is run with the `TorchTrainer` as follows:
```python
import tempfile
from pathlib import Path

import torch
import ray.train

CHECKPOINT_FILE_NAME = "model.pt"

def train_func():
    # Restore checkpoint.
    checkpoint = ray.train.get_context().get_checkpoint()
    if checkpoint:
        checkpoint_dir = checkpoint.to_directory()
        checkpoint_path = Path(checkpoint_dir) / CHECKPOINT_FILE_NAME
        checkpoint_data = torch.load(checkpoint_path)
    ...
    # Create checkpoint.
    with tempfile.TemporaryDirectory() as temp_dir:
        checkpoint_path = Path(temp_dir) / CHECKPOINT_FILE_NAME
        torch.save(checkpoint_data, checkpoint_path)
        checkpoint = ray.train.Checkpoint.from_directory(temp_dir)
        ray.train.report(metrics, checkpoint)
```

In the following sections, we show how this change will be reflected in each of the individual frameworks by comparing:
1. A (minimal) training script for the framework.
2. The training script rewritten using the current Ray Train APIs.
3. The training script rewritten using the proposed Ray Train APIs.
    1. Additionally, we show how to report metrics and handle checkpoints with Ray Train.

### Lightning

#### Lightning
```python
import lightning

trainer = lightning.Trainer(**trainer_kwargs)
trainer.fit(MyLightningModule(**module_kwargs), **fit_kwargs)
```

#### `LightningTrainer` (Current)

In the existing `LightningTrainer`, we expose a `LightningConfigBuilder` that allows the user to configure the `lightning.Trainer` in a way that is compatible with Ray Train. While similar to the native Lightning interface, this requires a non-trivial amount of rewriting of the user's training code and does not provide a strict 1-1 mapping.

```python
trainer = LightningTrainer(
    ...
)
trainer.fit()
```
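
For reference, a fuller sketch of this builder-based setup is shown below. The specific `LightningConfigBuilder` methods (`.module()`, `.trainer()`, `.fit_params()`, `.build()`) are assumptions based on the Ray 2.6 API, and the `MyLightningModule` / `*_kwargs` names are reused from the native example above:

```python
from ray.train.lightning import LightningTrainer, LightningConfigBuilder

# Mirror the native module/trainer/fit arguments through the builder.
lightning_config = (
    LightningConfigBuilder()
    .module(cls=MyLightningModule, **module_kwargs)  # wraps MyLightningModule(**module_kwargs)
    .trainer(**trainer_kwargs)                       # forwarded to lightning.Trainer(...)
    .fit_params(**fit_kwargs)                        # forwarded to lightning.Trainer.fit(...)
    .build()
)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    # scaling_config / run_config as needed
)
trainer.fit()
```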

#### `TorchTrainer` (Proposed)

In this proposal, we provide a few common utilities that the user can use directly in their training code to configure the `Trainer` object.

- `get_devices` - Returns a list of devices to use for training.
- `prepare_trainer` - Validates and/or makes modifications so that the `Trainer` object is configured correctly to be compatible with Ray Train.
- `RayDDPStrategy`, `RayFSDPStrategy`, `RayDeepSpeedStrategy` - `LightningStrategy`s for different training strategies that are compatible with Ray Train.
- `RayEnvironment` - A `LightningEnvironment` that is compatible with Ray Train.
- `RayModelCheckpoint` - A `LightningCallback` that configures the `Trainer` to report metrics and checkpoints to Ray Train.

With this, the user can directly interact with the Lightning interface, and is able to define their training logic to use additional functionality such as `lightning.Trainer.test()`.

```python
import lightning
import ray.train.lightning
from ray.train.torch import TorchTrainer

def train_func(config):

    ...

    lightning_trainer = lightning.Trainer(
        devices=ray.train.lightning.get_devices(),
        strategy=ray.train.lightning.RayDDPStrategy(),
        plugins=[ray.train.lightning.RayEnvironment()],
        **trainer_kwargs
    )
    lightning_trainer = ray.train.lightning.prepare_trainer(lightning_trainer)
    lightning_trainer.fit(MyLightningModule(**module_kwargs), **fit_kwargs)


trainer = TorchTrainer(
    ...
)
trainer.fit()
```

> **Contributor comment (suggested change):** `plugins=[ray.train.lightning.RayEnvironment()]` → `plugins=[ray.train.lightning.RayEnvPlugin()]`
>
> **Author reply:** Hmm, this is a subclass of the `LightningEnvironment` class, which is a `Plugin`. What about `RayLightningEnvironment`?

**Reporting & Checkpointing:**

For Lightning, we introduce a `RayTrainReportCallback` (a Lightning `Callback`) that reports metrics and checkpoints to Ray Train.

```python
from pathlib import Path

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Checkpoint

import ray.train

CHECKPOINT_FILE_NAME = "checkpoint.ckpt"

class RayTrainReportCallback(Checkpoint):
    # Define logic for calling `ray.train.report` as a Callback.
    ...

def train_func():
    # Create checkpoint.
    report_checkpoint_callback = RayTrainReportCallback()
    trainer = Trainer(..., callbacks=[report_checkpoint_callback])
    # Restore checkpoint.
    checkpoint_path = None
    checkpoint = ray.train.get_context().get_checkpoint()
    if checkpoint:
        checkpoint_dir = checkpoint.to_directory()
        checkpoint_path = Path(checkpoint_dir) / CHECKPOINT_FILE_NAME
    trainer.fit(model, ckpt_path=checkpoint_path, ...)
```
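
The REP leaves the body of the callback elided. Purely as an illustration (not the actual implementation), the reporting logic could look roughly like the following, reusing `CHECKPOINT_FILE_NAME` from the sketch above; the epoch-end hook and the metric source are assumptions:

```python
import tempfile
from pathlib import Path

import ray.train
from lightning.pytorch.callbacks import Checkpoint


class RayTrainReportCallback(Checkpoint):
    # Illustrative only: save a Lightning checkpoint and report it after every epoch.
    def on_train_epoch_end(self, trainer, pl_module):
        # Collect the latest scalar metrics logged by Lightning.
        metrics = {k: v.item() for k, v in trainer.callback_metrics.items()}
        with tempfile.TemporaryDirectory() as tmpdir:
            # Persist the Lightning checkpoint into a temporary directory...
            trainer.save_checkpoint(Path(tmpdir) / CHECKPOINT_FILE_NAME)
            # ...and hand it to Ray Train together with the metrics.
            checkpoint = ray.train.Checkpoint.from_directory(tmpdir)
            ray.train.report(metrics, checkpoint=checkpoint)
```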

### HuggingFace Transformers

#### Transformers
```python
import transformers

...
```

#### `TransformersTrainer` (Current)

```python
trainer = TransformersTrainer(
    ...
)
trainer.fit()
```

#### `TorchTrainer` (Proposed)

In this proposal, we provide a `prepare_trainer` utility that the user can use directly in their training loop code to validate and/or make modifications so that the `Trainer` object is configured correctly to be compatible with Ray Train.

With this, the user is able to define their training logic to use additional functionality such as `transformers.Trainer.evaluate()`.

```python
import transformers
import ray.train.huggingface.transformers
from ray.train.torch import TorchTrainer

def train_func(config):
    ...
    transformers_trainer = transformers.Trainer(**trainer_kwargs)
    transformers_trainer = ray.train.huggingface.transformers.prepare_trainer(transformers_trainer)
    transformers_trainer.train()

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    ...
)
trainer.fit()
```

**Reporting & Checkpointing:**

For Transformers, we introduce a `RayTrainReportCallback` (a `TrainerCallback`) that reports metrics and checkpoints to Ray Train.

```python
from pathlib import Path

from transformers import Trainer
from transformers.trainer_callback import TrainerCallback

import ray.train

class RayTrainReportCallback(TrainerCallback):
    # Define logic for calling `ray.train.report` as a Callback.
    ...

def train_func():
    ...
    trainer = Trainer(...)
    # Create checkpoint.
    report_checkpoint_callback = RayTrainReportCallback()
    trainer.add_callback(report_checkpoint_callback)
    # Restore checkpoint.
    checkpoint_path = None
    checkpoint = ray.train.get_context().get_checkpoint()
    if checkpoint:
        checkpoint_dir = checkpoint.to_directory()
        checkpoint_path = Path(checkpoint_dir) / CHECKPOINT_FILE_NAME
    trainer.train(resume_from_checkpoint=checkpoint_path)
```
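
Again, the callback body is elided in the REP. One possible (purely illustrative) implementation reports the checkpoint directory that `transformers.Trainer` has just written; the `on_save` hook and the `checkpoint-{global_step}` directory naming follow Transformers conventions, while the choice of metrics is an assumption:

```python
from pathlib import Path

import ray.train
from transformers.trainer_callback import TrainerCallback


class RayTrainReportCallback(TrainerCallback):
    # Illustrative only: after each save, report the new checkpoint directory to Ray Train.
    def on_save(self, args, state, control, **kwargs):
        # transformers.Trainer writes checkpoints to "{output_dir}/checkpoint-{global_step}".
        checkpoint_dir = Path(args.output_dir) / f"checkpoint-{state.global_step}"
        # Report the most recently logged metrics, if any.
        metrics = state.log_history[-1] if state.log_history else {}
        checkpoint = ray.train.Checkpoint.from_directory(str(checkpoint_dir))
        ray.train.report(metrics, checkpoint=checkpoint)
```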

### HuggingFace Accelerate

At its core, HuggingFace Accelerate can be used by configuring and instantiating an `Accelerator` object in your training code.
HuggingFace Accelerate also provides two **optional** CLI commands: `accelerate config` and `accelerate launch`.

The current `AccelerateTrainer` provides functionality similar to `accelerate launch`, in which it will read the generated configuration file and apply it to the `Accelerator`. However, as Ray Train already provides a distributed launching mechanism with its own configurability, we find diminishing value in parsing the configuration file simply for configuring the `Accelerator`. Additionally, it has been a common misconception that the user _must_ use `AccelerateTrainer` in order to use Accelerate. As such, in this proposal we simplify the story by recommending that users configure the `Accelerator` directly in their `TorchTrainer` training code.

#### Accelerate
```python
from accelerate import Accelerator

accelerator = Accelerator(**accelerator_kwargs)
...
```

#### `AccelerateTrainer` (Current)

```bash
! accelerate config
```

```python
def train_func(config):
    ...

trainer = AccelerateTrainer(
    train_loop_per_worker=train_func,
    accelerate_config=accelerate_config,
    ...
)
trainer.fit()
```

#### `TorchTrainer` (Proposed)

In this proposal, we recommend configuring the `Accelerator` object directly in the training code, exactly as you would if you were not using Ray Train.

```python
trainer = TorchTrainer(
    ...
)
trainer.fit()
```
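
To make this concrete, a minimal sketch of such a training function is shown below, assuming a typical Accelerate training loop; `build_model_optimizer_dataloader` and `compute_loss` are hypothetical helpers, and the `ScalingConfig` values are illustrative:

```python
from accelerate import Accelerator

import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_func(config):
    # Configure Accelerate directly, exactly as in a non-Ray script.
    accelerator = Accelerator()

    model, optimizer, dataloader = build_model_optimizer_dataloader(config)  # hypothetical helper
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for batch in dataloader:
        optimizer.zero_grad()
        loss = compute_loss(model, batch)  # hypothetical helper
        accelerator.backward(loss)
        optimizer.step()

    # Report metrics (and optionally a checkpoint) back to Ray Train.
    ray.train.report({"loss": loss.item()})


trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
trainer.fit()
```

Because the `Accelerator` is created inside the training function, each Ray Train worker configures Accelerate exactly as a standalone script would, and neither `accelerate launch` nor a generated config file is involved.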


**Reporting & Checkpointing:**

This is done in the same way as the aforementioned Torch example.

## Compatibility, Deprecation, and Migration Plan

Ray 2.7:
- Mark `AccelerateTrainer`, `TransformersTrainer`, and `LightningTrainer` APIs as deprecated.

Ray 2.8:
- Raise error from `AccelerateTrainer`, `TransformersTrainer`, and `LightningTrainer` APIs.

Ray 2.9:
- Remove `AccelerateTrainer`, `TransformersTrainer`, and `LightningTrainer` APIs.

## Test Plan and Acceptance Criteria