diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 507cb7936f7b3..471e50a0dd8b6 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a warning when calling methods on `_FabricModule` that bypass the strategy-specific wrappers ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))
 
+- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))
+
+
 ### Changed
 
 - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index 1dee4d3d3b7ba..5998cd638b556 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from typing import Any, Dict, Generator, List, Literal, Optional, Union
 
@@ -113,7 +113,12 @@ def setup_environment(self) -> None:
 
     def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
-        return DistributedDataParallel(module=module, device_ids=self._determine_ddp_device_ids(), **self._ddp_kwargs)
+        # https://pytorch.org/docs/stable/notes/cuda.html#id5
+        ctx = torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext()
+        with ctx:
+            return DistributedDataParallel(
+                module=module, device_ids=self._determine_ddp_device_ids(), **self._ddp_kwargs
+            )
 
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)
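The hunk above (and its twin in `src/lightning/pytorch/strategies/ddp.py` further down) constructs the `DistributedDataParallel` wrapper inside a freshly created CUDA side stream when CUDA is available, falling back to `nullcontext()` on CPU-only machines. A minimal, self-contained sketch of the same pattern; the `wrap_ddp_in_side_stream` helper name is made up for illustration and assumes an already-initialized process group when it is actually called:

```python
from contextlib import nullcontext

import torch
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel


def wrap_ddp_in_side_stream(module: Module, **ddp_kwargs) -> DistributedDataParallel:
    # On CUDA, build DDP under a side stream instead of the default stream;
    # on CPU, nullcontext() makes the `with` block a no-op.
    # Reference: https://pytorch.org/docs/stable/notes/cuda.html#id5
    ctx = torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext()
    with ctx:
        return DistributedDataParallel(module=module, **ddp_kwargs)
```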
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index fc3bf6ba9b4a0..c33e4efbf8bb9 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Update `LightningDataModule.from_datasets` to support arbitrary iterables ([#17402](https://github.com/Lightning-AI/lightning/pull/17402))
 
+- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))
+
+
 ### Changed
 
 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
@@ -53,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Increased the minimum XLA requirement to 1.13 ([#17368](https://github.com/Lightning-AI/lightning/pull/17368))
+
+- `self.log`ed tensors are now kept in the original device to reduce unnecessary host-to-device synchronizations ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))
+
 
 ### Deprecated
 
 -
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 6a580aeafbed5..c150206772e43 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -607,11 +607,7 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
         raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")
 
     def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
-        value = (
-            value.clone().detach().to(self.device)
-            if isinstance(value, Tensor)
-            else torch.tensor(value, device=self.device)
-        )
+        value = value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value, device=self.device)
         if not torch.numel(value) == 1:
             raise ValueError(
                 f"`self.log({name}, {value})` was called, but the tensor must have a single element."
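With the `__to_tensor` change above, a tensor handed to `self.log` is only cloned and detached; it is no longer moved to the module's device before being stored. A small standalone illustration (not Lightning code) of what that means for a value that already lives on an accelerator:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# e.g. a loss returned by a training step, already on the accelerator
value = torch.tensor(0.123, device=device)

# new behaviour: clone + detach only -- no cross-device copy, no extra synchronization
logged = value.clone().detach()
assert logged.device == value.device

# plain Python numbers still become tensors on the module's device, as before
logged_number = torch.tensor(0.5, device=device)
assert logged_number.numel() == 1
```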
diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index 7903a5c4597f3..14f39e5c0fa08 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -309,9 +309,6 @@ def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_{validation/test}_start`` hooks."""
         trainer = self.trainer
 
-        assert self._results is not None
-        self._results.to(device=trainer.lightning_module.device)
-
         hook_name = "on_test_start" if trainer.testing else "on_validation_start"
         call._call_callback_hooks(trainer, hook_name, *args, **kwargs)
         call._call_lightning_module_hook(trainer, hook_name, *args, **kwargs)
diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index c317e92840abe..16c11ba45c677 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -309,8 +309,6 @@ def on_run_start(self) -> None:
 
         self._data_fetcher = _select_data_fetcher(trainer)
 
-        self._results.to(device=trainer.lightning_module.device)
-
         call._call_callback_hooks(trainer, "on_train_start")
         call._call_lightning_module_hook(trainer, "on_train_start")
         call._call_strategy_hook(trainer, "on_train_start")
diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index 01fa3b25264dd..8ac5b16c6ef61 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from contextlib import nullcontext
 from datetime import timedelta
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
@@ -182,7 +183,10 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
-        return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
+        # https://pytorch.org/docs/stable/notes/cuda.html#id5
+        ctx = torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext()
+        with ctx:
+            return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
 
     def setup_distributed(self) -> None:
         log.debug(f"{self.__class__.__name__}: setting up distributed...")
diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index b023d0ef39d38..d7446d1d6f382 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -23,7 +23,6 @@
 
 from lightning.fabric.utilities import move_data_to_device
 from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
-from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_0
 from lightning.pytorch.utilities.data import extract_batch_size
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -180,7 +179,7 @@ def is_custom_reduction(self) -> bool:
         return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction or self.is_sum_reduction)
 
 
-class _ResultMetric(Metric, _DeviceDtypeModuleMixin):  # type: ignore[misc]  # torchmetrics methods should return Self
+class _ResultMetric(Metric):
     """Wraps the value provided to `:meth:`~lightning.pytorch.core.module.LightningModule.log`"""
 
     def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
@@ -314,10 +313,9 @@ class _ResultCollection(dict):
 
     DATALOADER_SUFFIX = "/dataloader_idx_{}"
 
-    def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
+    def __init__(self, training: bool) -> None:
         super().__init__()
         self.training = training
-        self.device: Optional[Union[str, torch.device]] = device
         self.batch: Optional[Any] = None
         self.batch_size: Optional[int] = None
         self.dataloader_idx: Optional[int] = None
@@ -410,13 +408,13 @@ def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:
 
         Value can be provided as a nested collection
         """
-        metric = _ResultMetric(meta, isinstance(value, Tensor)).to(self.device)
+        metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
         self[key] = metric
 
     def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
         result_metric = self[key]
         # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
-        result_metric.forward(value.to(self.device), batch_size)
+        result_metric.forward(value, batch_size)
         result_metric.has_reset = False
 
     @staticmethod
@@ -509,9 +507,6 @@ def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> Non
     def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection":
         """Move all data to the given device."""
         self.update(apply_to_collection(dict(self), (Tensor, Metric), move_data_to_device, *args, **kwargs))
-
-        if "device" in kwargs:
-            self.device = kwargs["device"]
         return self
 
     def cpu(self) -> "_ResultCollection":
@@ -524,4 +519,4 @@ def __str__(self) -> str:
         return f"{self.__class__.__name__}({self_str})"
 
     def __repr__(self) -> str:
-        return f"{{{self.training}, {repr(self.device)}, {super().__repr__()}}}"
+        return f"{{{self.training}, {super().__repr__()}}}"
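After these hunks, `_ResultCollection` no longer carries a device of its own: each `_ResultMetric` is created on the device of the first value registered under its key, and later updates are forwarded without an extra `.to(...)`. A rough usage sketch against this (private) API, mirroring how the updated tests below exercise it; the `"my_metric"` key is made up for illustration:

```python
import torch

from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection

# no `device=...` argument anymore -- the collection itself is device-agnostic
results = _ResultCollection(training=True)
results.batch = torch.randn(1, 4)

# the _ResultMetric for this key is placed on the device of the logged value
# (CPU here; it would stay on the GPU for a CUDA tensor)
results.log("training_step", "my_metric", torch.tensor(1.0), on_step=False, on_epoch=True)
assert results["training_step.my_metric"].value == torch.tensor(1.0)
```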
diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py
index 2cf8a146326df..dbdc03bfbc959 100644
--- a/tests/tests_pytorch/core/test_metric_result_integration.py
+++ b/tests/tests_pytorch/core/test_metric_result_integration.py
@@ -67,7 +67,7 @@ def result_reduce_ddp_fn(strategy):
     metric_b = metric_b.to(f"cuda:{rank}")
     metric_c = metric_c.to(f"cuda:{rank}")
 
-    result = _ResultCollection(True, torch.device(f"cuda:{rank}"))
+    result = _ResultCollection(True)
 
     for _ in range(3):
         cumulative_sum = 0
@@ -107,7 +107,7 @@ def test_result_metric_integration():
     metric_b = DummyMetric()
     metric_c = DummyMetric()
 
-    result = _ResultCollection(True, torch.device("cpu"))
+    result = _ResultCollection(True)
 
     for _ in range(3):
         cumulative_sum = 0
@@ -148,7 +148,6 @@ def test_result_metric_integration():
     assert repr(result) == (
         "{"
         "True, "
-        "device(type='cpu'), "
         "{'h.a': _ResultMetric('a', value=DummyMetric()), "
         "'h.b': _ResultMetric('b', value=DummyMetric()), "
         "'h.c': _ResultMetric('c', value=DummyMetric())"
@@ -157,7 +156,7 @@ def test_result_metric_integration():
 
 
 def test_result_collection_simple_loop():
-    result = _ResultCollection(True, torch.device("cpu"))
+    result = _ResultCollection(True)
     current_fx_name = None
     batch_idx = None
 
@@ -205,7 +204,7 @@ def my_sync_dist(x, *_, **__):
 
 def test_result_collection_restoration(tmpdir):
     """This test make sure metrics are properly reloaded on failure."""
-    result = _ResultCollection(True, torch.device("cpu"))
+    result = _ResultCollection(True)
     metric_a = DummyMetric()
     metric_b = DummyMetric()
     metric_c = DummyMetric()
diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
index 4a2f6a31bd6ad..2b050d0b06eea 100644
--- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
@@ -204,9 +204,9 @@ def on_validation_epoch_end(self) -> None:
     assert set(trainer.callback_metrics) == {"val_loss", "val_loss_epoch"}
 
     # make sure values are correct
-    assert trainer.logged_metrics["val_loss_epoch"] == model.manual_epoch_end_mean
-    assert trainer.callback_metrics["val_loss_epoch"] == model.manual_epoch_end_mean
-    assert trainer.callback_metrics["val_loss"] == model.manual_epoch_end_mean
+    assert torch.allclose(trainer.logged_metrics["val_loss_epoch"], model.manual_epoch_end_mean)
+    assert torch.allclose(trainer.callback_metrics["val_loss_epoch"], model.manual_epoch_end_mean)
+    assert torch.allclose(trainer.callback_metrics["val_loss"], model.manual_epoch_end_mean)
     assert trainer.logged_metrics["val_loss_step"] == model.val_losses[-1]
 
 
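These asserts switch from exact equality to `torch.allclose`, presumably because the epoch-end aggregate is now computed wherever the values were logged and can differ from the manually computed reference by a rounding step. A tiny illustration of why a tolerance-based comparison is the more robust assertion:

```python
import torch

expected = torch.tensor(0.2)
logged = expected + 1e-7  # off from the reference by a few float32 ulps

assert logged != expected                # exact comparison is brittle
assert torch.allclose(logged, expected)  # passes within the default rtol/atol
```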
diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
index 4796552652554..4a9ff8b5eab7f 100644
--- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
+++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
@@ -605,7 +605,7 @@ def test_result_collection_batch_size_extraction():
     fx_name = "training_step"
     log_val = torch.tensor(7.0)
 
-    results = _ResultCollection(training=True, device="cpu")
+    results = _ResultCollection(training=True)
     results.batch = torch.randn(1, 4)
     train_mse = MeanSquaredError()
     train_mse(torch.randn(4, 5), torch.randn(4, 5))
@@ -615,7 +615,7 @@ def test_result_collection_batch_size_extraction():
     assert isinstance(results["training_step.mse"].value, MeanSquaredError)
     assert results["training_step.log_val"].value == log_val
 
-    results = _ResultCollection(training=True, device="cpu")
+    results = _ResultCollection(training=True)
     results.batch = torch.randn(1, 4)
     results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True)
     assert results.batch_size == 1
@@ -624,7 +624,7 @@ def test_result_collection_batch_size_extraction():
 
 
 def test_result_collection_no_batch_size_extraction():
-    results = _ResultCollection(training=True, device="cpu")
+    results = _ResultCollection(training=True)
     results.batch = torch.randn(1, 4)
     fx_name = "training_step"
     batch_size = 10