
Patch for device placement (Reduce host and device syncs) #17334

Merged 21 commits on Apr 25, 2023
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning when calling methods on `_FabricModule` that bypass the strategy-specific wrappers ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))


+- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))


### Changed

- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
9 changes: 7 additions & 2 deletions src/lightning/fabric/strategies/ddp.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
from datetime import timedelta
from typing import Any, Dict, Generator, List, Literal, Optional, Union

@@ -113,7 +113,12 @@ def setup_environment(self) -> None:

def setup_module(self, module: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
-return DistributedDataParallel(module=module, device_ids=self._determine_ddp_device_ids(), **self._ddp_kwargs)
+# https://pytorch.org/docs/stable/notes/cuda.html#id5
+ctx = torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext()
+with ctx:
+    return DistributedDataParallel(
+        module=module, device_ids=self._determine_ddp_device_ids(), **self._ddp_kwargs
+    )

def module_to_device(self, module: Module) -> None:
module.to(self.root_device)
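For context: entering a fresh side stream means the CUDA work that `DistributedDataParallel` enqueues during construction (such as the initial broadcast of module states) does not serialize with kernels already queued on the default stream. Below is a minimal standalone sketch of the same pattern, assuming a single-process, single-GPU setup rather than Lightning's strategy machinery, which handles the process-group setup itself:

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Sketch only: in Lightning, the strategy initializes the process group.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

model = torch.nn.Linear(32, 32).to("cuda:0")

# Construct DDP on a side stream so its setup work neither blocks nor
# gets blocked by work pending on the default stream.
side_stream = torch.cuda.Stream()
with torch.cuda.stream(side_stream):
    ddp_model = DistributedDataParallel(model, device_ids=[0])

# Per the CUDA semantics notes linked above, the default stream should
# wait on the side stream before consuming the wrapped module.
torch.cuda.current_stream().wait_stream(side_stream)
```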
6 changes: 6 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Update `LightningDataModule.from_datasets` to support arbitrary iterables ([#17402](https://github.com/Lightning-AI/lightning/pull/17402))


+- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))


### Changed

- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))

@@ -53,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Increased the minimum XLA requirement to 1.13 ([#17368](https://github.com/Lightning-AI/lightning/pull/17368))


+- `self.log`ed tensors are now kept on the original device to reduce unnecessary host-to-device synchronizations ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))

### Deprecated

-
6 changes: 1 addition & 5 deletions src/lightning/pytorch/core/module.py
@@ -607,11 +607,7 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")

def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
-value = (
-    value.clone().detach().to(self.device)
-    if isinstance(value, Tensor)
-    else torch.tensor(value, device=self.device)
-)
+value = value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value, device=self.device)
if not torch.numel(value) == 1:
raise ValueError(
f"`self.log({name}, {value})` was called, but the tensor must have a single element."
3 changes: 0 additions & 3 deletions src/lightning/pytorch/loops/evaluation_loop.py
@@ -309,9 +309,6 @@ def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_start`` hooks."""
trainer = self.trainer

-assert self._results is not None
-self._results.to(device=trainer.lightning_module.device)

hook_name = "on_test_start" if trainer.testing else "on_validation_start"
call._call_callback_hooks(trainer, hook_name, *args, **kwargs)
call._call_lightning_module_hook(trainer, hook_name, *args, **kwargs)
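(This eager move of the whole result collection is no longer needed: with the `result.py` changes further down, each `_ResultMetric` is created directly on the device of the value being logged.)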
2 changes: 0 additions & 2 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -309,8 +309,6 @@ def on_run_start(self) -> None:

self._data_fetcher = _select_data_fetcher(trainer)

-self._results.to(device=trainer.lightning_module.device)

call._call_callback_hooks(trainer, "on_train_start")
call._call_lightning_module_hook(trainer, "on_train_start")
call._call_strategy_hook(trainer, "on_train_start")
6 changes: 5 additions & 1 deletion src/lightning/pytorch/strategies/ddp.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from contextlib import nullcontext
from datetime import timedelta
from typing import Any, Callable, Dict, List, Literal, Optional, Union

@@ -182,7 +183,10 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
device_ids = self.determine_ddp_device_ids()
log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
-return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
+# https://pytorch.org/docs/stable/notes/cuda.html#id5
+ctx = torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext()

Review comment on lines +186 to +187 (Contributor): Should we avoid using the stream if not running on CUDA? Even if there are no known side effects, for sanity I would prefer that.

Reply (Contributor): Good point, yes. We can check the root device type instead.

+with ctx:
+    return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

def setup_distributed(self) -> None:
log.debug(f"{self.__class__.__name__}: setting up distributed...")
@@ -23,7 +23,6 @@

from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
-from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch.utilities.data import extract_batch_size
from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -180,7 +179,7 @@ def is_custom_reduction(self) -> bool:
return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction or self.is_sum_reduction)


-class _ResultMetric(Metric, _DeviceDtypeModuleMixin):  # type: ignore[misc]  # torchmetrics methods should return Self
+class _ResultMetric(Metric):
"""Wraps the value provided to `:meth:`~lightning.pytorch.core.module.LightningModule.log`"""

def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
@@ -314,10 +313,9 @@ class _ResultCollection(dict):

DATALOADER_SUFFIX = "/dataloader_idx_{}"

-def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
+def __init__(self, training: bool) -> None:
super().__init__()
self.training = training
-self.device: Optional[Union[str, torch.device]] = device
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.dataloader_idx: Optional[int] = None
@@ -410,13 +408,13 @@ def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:

Value can be provided as a nested collection
"""
-metric = _ResultMetric(meta, isinstance(value, Tensor)).to(self.device)
+metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
self[key] = metric

def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
result_metric = self[key]
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
-result_metric.forward(value.to(self.device), batch_size)
+result_metric.forward(value, batch_size)
result_metric.has_reset = False

@staticmethod
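Taken together, `register_key` now places the metric on the logged value's device and `update_metrics` feeds the value straight through, so the per-step `value.to(self.device)` copy, and any synchronization it could force, disappears. A rough illustration of metric state following the data's device, using torchmetrics directly rather than the private `_ResultMetric` wrapper:

```python
import torch
from torchmetrics import MeanSquaredError

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Keep the metric state on the same device as the values it will see...
mse = MeanSquaredError().to(device)

preds = torch.randn(4, 5, device=device)
target = torch.randn(4, 5, device=device)

# ...so each update runs fully on-device, with no cross-device copy per step.
mse.update(preds, target)
print(mse.compute())
```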
@@ -509,9 +507,6 @@ def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> None:
def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection":
"""Move all data to the given device."""
self.update(apply_to_collection(dict(self), (Tensor, Metric), move_data_to_device, *args, **kwargs))

if "device" in kwargs:
self.device = kwargs["device"]
return self

def cpu(self) -> "_ResultCollection":
@@ -524,4 +519,4 @@ def __str__(self) -> str:
return f"{self.__class__.__name__}({self_str})"

def __repr__(self) -> str:
return f"{{{self.training}, {repr(self.device)}, {super().__repr__()}}}"
return f"{{{self.training}, {super().__repr__()}}}"
9 changes: 4 additions & 5 deletions tests/tests_pytorch/core/test_metric_result_integration.py
@@ -67,7 +67,7 @@ def result_reduce_ddp_fn(strategy):
metric_b = metric_b.to(f"cuda:{rank}")
metric_c = metric_c.to(f"cuda:{rank}")

-result = _ResultCollection(True, torch.device(f"cuda:{rank}"))
+result = _ResultCollection(True)

for _ in range(3):
cumulative_sum = 0
@@ -107,7 +107,7 @@ def test_result_metric_integration():
metric_b = DummyMetric()
metric_c = DummyMetric()

-result = _ResultCollection(True, torch.device("cpu"))
+result = _ResultCollection(True)

for _ in range(3):
cumulative_sum = 0
@@ -148,7 +148,6 @@ def test_result_metric_integration():
assert repr(result) == (
"{"
"True, "
"device(type='cpu'), "
"{'h.a': _ResultMetric('a', value=DummyMetric()), "
"'h.b': _ResultMetric('b', value=DummyMetric()), "
"'h.c': _ResultMetric('c', value=DummyMetric())"
@@ -157,7 +156,7 @@


def test_result_collection_simple_loop():
-result = _ResultCollection(True, torch.device("cpu"))
+result = _ResultCollection(True)
current_fx_name = None
batch_idx = None

@@ -205,7 +204,7 @@ def my_sync_dist(x, *_, **__):
def test_result_collection_restoration(tmpdir):
"""This test make sure metrics are properly reloaded on failure."""

-result = _ResultCollection(True, torch.device("cpu"))
+result = _ResultCollection(True)
metric_a = DummyMetric()
metric_b = DummyMetric()
metric_c = DummyMetric()
@@ -204,9 +204,9 @@ def on_validation_epoch_end(self) -> None:
assert set(trainer.callback_metrics) == {"val_loss", "val_loss_epoch"}

# make sure values are correct
-assert trainer.logged_metrics["val_loss_epoch"] == model.manual_epoch_end_mean
-assert trainer.callback_metrics["val_loss_epoch"] == model.manual_epoch_end_mean
-assert trainer.callback_metrics["val_loss"] == model.manual_epoch_end_mean
+assert torch.allclose(trainer.logged_metrics["val_loss_epoch"], model.manual_epoch_end_mean)
+assert torch.allclose(trainer.callback_metrics["val_loss_epoch"], model.manual_epoch_end_mean)
+assert torch.allclose(trainer.callback_metrics["val_loss"], model.manual_epoch_end_mean)
assert trainer.logged_metrics["val_loss_step"] == model.val_losses[-1]


@@ -605,7 +605,7 @@ def test_result_collection_batch_size_extraction():
fx_name = "training_step"
log_val = torch.tensor(7.0)

-results = _ResultCollection(training=True, device="cpu")
+results = _ResultCollection(training=True)
results.batch = torch.randn(1, 4)
train_mse = MeanSquaredError()
train_mse(torch.randn(4, 5), torch.randn(4, 5))
@@ -615,7 +615,7 @@
assert isinstance(results["training_step.mse"].value, MeanSquaredError)
assert results["training_step.log_val"].value == log_val

-results = _ResultCollection(training=True, device="cpu")
+results = _ResultCollection(training=True)
results.batch = torch.randn(1, 4)
results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True)
assert results.batch_size == 1
@@ -624,7 +624,7 @@


def test_result_collection_no_batch_size_extraction():
-results = _ResultCollection(training=True, device="cpu")
+results = _ResultCollection(training=True)
results.batch = torch.randn(1, 4)
fx_name = "training_step"
batch_size = 10