From cf2d891ead5f6fe703354e944a72beec360efd5b Mon Sep 17 00:00:00 2001
From: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>
Date: Tue, 11 Apr 2023 11:27:52 -0700
Subject: [PATCH 01/15] Update module.py

---
 src/lightning/pytorch/core/module.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index feef9eb87028b..49fffe3e8768d 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -609,9 +609,9 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
 
     def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
         value = (
-            value.clone().detach().to(self.device)
+            value.clone().detach()
             if isinstance(value, Tensor)
-            else torch.tensor(value, device=self.device)
+            else torch.tensor(value)
         )
         if not torch.numel(value) == 1:
             raise ValueError(
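
A minimal illustrative sketch (not from the patches themselves) of what the two branches of `__to_tensor` do once the explicit device moves are gone; the values are hypothetical:

    import torch

    t = torch.tensor(0.5)  # a Tensor handed to `self.log(...)`
    assert t.clone().detach().device == t.device  # tensors keep the device they were produced on

    x = 0.5  # a plain Python number handed to `self.log(...)`
    assert torch.tensor(x).device.type == "cpu"  # numbers land on the CPU here; a later
                                                 # commit in this series restores
                                                 # `device=self.device` for this branch
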
From a47fd2c051a40c0b26de39cc6f434d0ce99445bd Mon Sep 17 00:00:00 2001
From: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>
Date: Tue, 11 Apr 2023 11:30:22 -0700
Subject: [PATCH 02/15] Update ddp.py

---
 src/lightning/pytorch/strategies/ddp.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index a6899b1c13307..2e50c75bed0f7 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -177,7 +177,10 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
-        return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
+        with torch.cuda.stream(torch.cuda.Stream()):
+            ddp_model = DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
+        return ddp_model
+

     def setup_distributed(self) -> None:
         log.debug(f"{self.__class__.__name__}: setting up distributed...")
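
A minimal illustrative sketch of the side-stream pattern introduced above, not taken from the patch. DDP synchronizes the initial module state while it is constructed; here that construction runs under a freshly created, non-default CUDA stream. `wrap_model` is a hypothetical helper, and the sketch assumes the default process group is already initialized and CUDA is available (a CPU fallback is added later in the series):

    import torch
    from torch.nn.parallel import DistributedDataParallel

    def wrap_model(model: torch.nn.Module) -> DistributedDataParallel:
        # run DDP construction (including its initial state broadcast) in a side stream
        with torch.cuda.stream(torch.cuda.Stream()):
            ddp_model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])
        return ddp_model
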
From cd67bf5f576d41d75424a1f77b0fd516b75999f4 Mon Sep 17 00:00:00 2001
From: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>
Date: Tue, 11 Apr 2023 11:33:18 -0700
Subject: [PATCH 03/15] Update result.py

---
 .../pytorch/trainer/connectors/logger_connector/result.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index ce1e4f41b1116..6de8841f01f64 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -417,7 +417,7 @@ def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:
     def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
         result_metric = self[key]
         # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
-        result_metric.forward(value.to(self.device), batch_size)
+        result_metric.forward(value, batch_size)
         result_metric.has_reset = False

     @staticmethod

From 60af42944384b90a3e58d0c8ed914af54cbe1b05 Mon Sep 17 00:00:00 2001
From: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>
Date: Tue, 11 Apr 2023 12:19:08 -0700
Subject: [PATCH 04/15] Update result.py

---
 .../pytorch/trainer/connectors/logger_connector/result.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index 6de8841f01f64..2d858fa1537d4 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -411,7 +411,7 @@ def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:

         Value can be provided as a nested collection
         """
-        metric = _ResultMetric(meta, isinstance(value, Tensor)).to(self.device)
+        metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
         self[key] = metric

     def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
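
A minimal illustrative sketch of what patches 03 and 04 amount to: the wrapper for a logged value is created on the device of the value itself, so `forward(value, ...)` needs no cross-device copy. A plain torchmetrics metric stands in for `_ResultMetric`, and the value is hypothetical:

    import torch
    from torchmetrics import MeanMetric

    value = torch.tensor(0.25)              # wherever the training step produced it
    metric = MeanMetric().to(value.device)  # mirrors `.to(value.device)` in patch 04
    metric.update(value)                    # no host-to-device copy needed
    print(metric.compute())                 # tensor(0.2500)
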
From 322a333e5f4e6a8eb98d6170025d08013a7a6342 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 11 Apr 2023 19:20:37 +0000
Subject: [PATCH 05/15] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/core/module.py    | 6 +-----
 src/lightning/pytorch/strategies/ddp.py | 1 -
 2 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 49fffe3e8768d..ad799f7f105db 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -608,11 +608,7 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
         raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")

     def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
-        value = (
-            value.clone().detach()
-            if isinstance(value, Tensor)
-            else torch.tensor(value)
-        )
+        value = value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value)
         if not torch.numel(value) == 1:
             raise ValueError(
                 f"`self.log({name}, {value})` was called, but the tensor must have a single element."

diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index 2e50c75bed0f7..1662e0a5f5cf1 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -181,7 +181,6 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
             ddp_model = DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
         return ddp_model
-

     def setup_distributed(self) -> None:
         log.debug(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()

From a0ce996ea525b41e4965e34aff7d9bf5b6b1160c Mon Sep 17 00:00:00 2001
From: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>
Date: Fri, 14 Apr 2023 15:48:05 -0700
Subject: [PATCH 06/15] Update module.py

---
 src/lightning/pytorch/core/module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index ad799f7f105db..5785c9adf18a7 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -608,7 +608,7 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
         raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")

     def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
-        value = value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value)
+        value = (value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value))
         if not torch.numel(value) == 1:
             raise ValueError(
                 f"`self.log({name}, {value})` was called, but the tensor must have a single element."

From 8e835211f968d5a259e96300d36d3336a2058cc1 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 14 Apr 2023 22:49:14 +0000
Subject: [PATCH 07/15] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/core/module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 5785c9adf18a7..ad799f7f105db 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -608,7 +608,7 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
         raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")

     def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
-        value = (value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value))
+        value = value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value)
         if not torch.numel(value) == 1:
             raise ValueError(
                 f"`self.log({name}, {value})` was called, but the tensor must have a single element."
From 0aa43cb047aa7c2dcc3487a32518ebf5a4a00b55 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Sat, 15 Apr 2023 05:47:11 +0200
Subject: [PATCH 08/15] Remove device entirely from result

---
 src/lightning/pytorch/core/module.py                       | 2 +-
 .../trainer/connectors/logger_connector/result.py          | 8 ++------
 .../tests_pytorch/core/test_metric_result_integration.py   | 9 ++++-----
 .../trainer/logging_/test_eval_loop_logging.py             | 6 +++---
 .../trainer/logging_/test_logger_connector.py              | 6 +++---
 5 files changed, 13 insertions(+), 18 deletions(-)

diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index ad799f7f105db..67bb30939e7cf 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -608,7 +608,7 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
         raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")

     def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
-        value = value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value)
+        value = value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value, device=self.device)
         if not torch.numel(value) == 1:
             raise ValueError(
                 f"`self.log({name}, {value})` was called, but the tensor must have a single element."

diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index 2d858fa1537d4..c1208872d7184 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -315,10 +315,9 @@ class _ResultCollection(dict):

     DATALOADER_SUFFIX = "/dataloader_idx_{}"

-    def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
+    def __init__(self, training: bool) -> None:
         super().__init__()
         self.training = training
-        self.device: Optional[Union[str, torch.device]] = device
         self.batch: Optional[Any] = None
         self.batch_size: Optional[int] = None
         self.dataloader_idx: Optional[int] = None
@@ -510,9 +509,6 @@ def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> Non
     def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection":
         """Move all data to the given device."""
         self.update(apply_to_collection(dict(self), (Tensor, Metric), move_data_to_device, *args, **kwargs))
-
-        if "device" in kwargs:
-            self.device = kwargs["device"]
         return self

     def cpu(self) -> "_ResultCollection":
@@ -525,4 +521,4 @@ def __str__(self) -> str:
         return f"{self.__class__.__name__}({self_str})"

     def __repr__(self) -> str:
-        return f"{{{self.training}, {repr(self.device)}, {super().__repr__()}}}"
+        return f"{{{self.training}, {super().__repr__()}}}"

diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py
index a74da96a8dce4..7dee679cfb582 100644
--- a/tests/tests_pytorch/core/test_metric_result_integration.py
+++ b/tests/tests_pytorch/core/test_metric_result_integration.py
@@ -67,7 +67,7 @@ def result_reduce_ddp_fn(strategy):
     metric_b = metric_b.to(f"cuda:{rank}")
     metric_c = metric_c.to(f"cuda:{rank}")

-    result = _ResultCollection(True, torch.device(f"cuda:{rank}"))
+    result = _ResultCollection(True)

     for _ in range(3):
         cumulative_sum = 0
@@ -107,7 +107,7 @@ def test_result_metric_integration():
     metric_b = DummyMetric()
     metric_c = DummyMetric()

-    result = _ResultCollection(True, torch.device("cpu"))
+    result = _ResultCollection(True)

     for _ in range(3):
         cumulative_sum = 0
@@ -148,7 +148,6 @@ def test_result_metric_integration():
     assert repr(result) == (
         "{"
         "True, "
-        "device(type='cpu'), "
         "{'h.a': _ResultMetric('a', value=DummyMetric()), "
         "'h.b': _ResultMetric('b', value=DummyMetric()), "
         "'h.c': _ResultMetric('c', value=DummyMetric())"
         "}"
     )


 def test_result_collection_simple_loop():
-    result = _ResultCollection(True, torch.device("cpu"))
+    result = _ResultCollection(True)
     current_fx_name = None
     batch_idx = None
@@ -205,7 +204,7 @@ def my_sync_dist(x, *_, **__):

 def test_result_collection_restoration(tmpdir):
     """This test make sure metrics are properly reloaded on failure."""
-    result = _ResultCollection(True, torch.device("cpu"))
+    result = _ResultCollection(True)
     metric_a = DummyMetric()
     metric_b = DummyMetric()
     metric_c = DummyMetric()

diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
index 4a2f6a31bd6ad..2b050d0b06eea 100644
--- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
@@ -204,9 +204,9 @@ def on_validation_epoch_end(self) -> None:
     assert set(trainer.callback_metrics) == {"val_loss", "val_loss_epoch"}

     # make sure values are correct
-    assert trainer.logged_metrics["val_loss_epoch"] == model.manual_epoch_end_mean
-    assert trainer.callback_metrics["val_loss_epoch"] == model.manual_epoch_end_mean
-    assert trainer.callback_metrics["val_loss"] == model.manual_epoch_end_mean
+    assert torch.allclose(trainer.logged_metrics["val_loss_epoch"], model.manual_epoch_end_mean)
+    assert torch.allclose(trainer.callback_metrics["val_loss_epoch"], model.manual_epoch_end_mean)
+    assert torch.allclose(trainer.callback_metrics["val_loss"], model.manual_epoch_end_mean)
     assert trainer.logged_metrics["val_loss_step"] == model.val_losses[-1]

diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
index e165629dc2bef..4771d74b84259 100644
--- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
+++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
@@ -604,7 +604,7 @@ def test_result_collection_batch_size_extraction():
     fx_name = "training_step"
     log_val = torch.tensor(7.0)

-    results = _ResultCollection(training=True, device="cpu")
+    results = _ResultCollection(training=True)
     results.batch = torch.randn(1, 4)
     train_mse = MeanSquaredError()
     train_mse(torch.randn(4, 5), torch.randn(4, 5))
@@ -614,7 +614,7 @@ def test_result_collection_batch_size_extraction():
     assert isinstance(results["training_step.mse"].value, MeanSquaredError)
     assert results["training_step.log_val"].value == log_val

-    results = _ResultCollection(training=True, device="cpu")
+    results = _ResultCollection(training=True)
     results.batch = torch.randn(1, 4)
     results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True)
     assert results.batch_size == 1
@@ -623,7 +623,7 @@ def test_result_collection_batch_size_extraction():


 def test_result_collection_no_batch_size_extraction():
-    results = _ResultCollection(training=True, device="cpu")
+    results = _ResultCollection(training=True)
     results.batch = torch.randn(1, 4)
     fx_name = "training_step"
     batch_size = 10
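
A minimal illustrative sketch of why the eval-loop test switches from `==` to `torch.allclose`: floating-point means accumulated along different paths are not guaranteed to be bit-identical, and with this patch the values are presumably aggregated where they were produced, so exact equality is brittle. The data below is hypothetical:

    import torch

    losses = torch.rand(10)
    running = torch.tensor(0.0)
    for x in losses:                # one way of accumulating a mean
        running = running + x
    mean_a = running / len(losses)
    mean_b = losses.mean()          # another way
    assert torch.allclose(mean_a, mean_b)  # robust to rounding differences
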
From 6773cf38b743a611ed5d8371c287000849ee8af8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Sat, 15 Apr 2023 05:54:44 +0200
Subject: [PATCH 09/15] Don't move to device in loops

---
 src/lightning/pytorch/loops/evaluation_loop.py | 3 ---
 src/lightning/pytorch/loops/fit_loop.py        | 2 --
 2 files changed, 5 deletions(-)

diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index 32e3472d03473..90889e1aefdc0 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -280,9 +280,6 @@ def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_{validation/test}_start`` hooks."""
         trainer = self.trainer

-        assert self._results is not None
-        self._results.to(device=trainer.lightning_module.device)
-
         hook_name = "on_test_start" if trainer.testing else "on_validation_start"
         call._call_callback_hooks(trainer, hook_name, *args, **kwargs)
         call._call_lightning_module_hook(trainer, hook_name, *args, **kwargs)

diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index c317e92840abe..16c11ba45c677 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -309,8 +309,6 @@ def on_run_start(self) -> None:
         self._data_fetcher = _select_data_fetcher(trainer)

-        self._results.to(device=trainer.lightning_module.device)
-
         call._call_callback_hooks(trainer, "on_train_start")
         call._call_lightning_module_hook(trainer, "on_train_start")
         call._call_strategy_hook(trainer, "on_train_start")
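
A minimal illustrative sketch of why the explicit move can go: after the earlier commits, a logged value and its wrapper already live on the device that produced them, so relocating the whole result collection at the start of the loops is redundant. The tensor below is a hypothetical stand-in for a logged value:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    loss = torch.tensor(1.0, device=device)  # produced by the step on `device`
    print(loss.device)                       # already where its wrapper will be created
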
From 862aa8dd3c0e8eac0f547e48882b2bec3f097afc Mon Sep 17 00:00:00 2001
From: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>
Date: Tue, 18 Apr 2023 12:59:00 -0700
Subject: [PATCH 10/15] Update ddp.py

---
 src/lightning/pytorch/strategies/ddp.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index 0e2eb135b2589..1891695ec8728 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -47,6 +47,7 @@ from lightning.pytorch.utilities.exceptions import _augment_message
 from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_only
 from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
+from contextlib import nullcontext

 if torch.distributed.is_available():
     from torch.distributed.algorithms.model_averaging.averagers import ModelAverager
@@ -183,7 +184,11 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
-        with torch.cuda.stream(torch.cuda.Stream()):
+        if torch.cuda.is_available():
+            ctx = torch.cuda.stream(torch.cuda.Stream())
+        else:
+            ctx = nullcontext()
+        with ctx:
             ddp_model = DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
         return ddp_model

From c0a7337af79ef27d48e57215e5ba41e6ec3d8466 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 18 Apr 2023 19:59:52 +0000
Subject: [PATCH 11/15] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/strategies/ddp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index 1891695ec8728..1af841b9120c0 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from contextlib import nullcontext
 from datetime import timedelta
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
@@ -47,7 +48,6 @@ from lightning.pytorch.utilities.exceptions import _augment_message
 from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_only
 from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
-from contextlib import nullcontext

 if torch.distributed.is_available():
     from torch.distributed.algorithms.model_averaging.averagers import ModelAverager

From 4aef258dd9b24f8190ae44ff835dc8d42dfeef5f Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 24 Apr 2023 22:47:51 +0000
Subject: [PATCH 12/15] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/strategies/ddp.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index 6de8d1bfb82f9..e101c3ac7b4cb 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -183,10 +183,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
-        if torch.cuda.is_available():
-            ctx = torch.cuda.stream(torch.cuda.Stream())
-        else:
-            ctx = nullcontext()
+        ctx = torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext()
         with ctx:
             ddp_model = DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
         return ddp_model
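
A minimal illustrative sketch of the CPU fallback these fix-up commits converge on: `nullcontext()` is a no-op `with` target, so the same code path works whether or not CUDA is present. Not taken from the patch itself:

    from contextlib import nullcontext

    import torch

    ctx = torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext()
    with ctx:
        pass  # construct the DDP wrapper here; on CPU this runs with no stream at all
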
From 684057c4de97874a095fec3dd6290d551103be07 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Tue, 25 Apr 2023 01:04:26 +0200
Subject: [PATCH 13/15] Add comment about stream. Port to fabric

---
 src/lightning/fabric/strategies/ddp.py  | 9 +++++++--
 src/lightning/pytorch/strategies/ddp.py | 4 ++--
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index 1dee4d3d3b7ba..5998cd638b556 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from typing import Any, Dict, Generator, List, Literal, Optional, Union
@@ -113,7 +113,12 @@ def setup_environment(self) -> None:

     def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
-        return DistributedDataParallel(module=module, device_ids=self._determine_ddp_device_ids(), **self._ddp_kwargs)
+        # https://pytorch.org/docs/stable/notes/cuda.html#id5
+        ctx = torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext()
+        with ctx:
+            return DistributedDataParallel(
+                module=module, device_ids=self._determine_ddp_device_ids(), **self._ddp_kwargs
+            )

     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)

diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index e101c3ac7b4cb..8ac5b16c6ef61 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -183,10 +183,10 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
+        # https://pytorch.org/docs/stable/notes/cuda.html#id5
         ctx = torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext()
         with ctx:
-            ddp_model = DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
-        return ddp_model
+            return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

     def setup_distributed(self) -> None:
         log.debug(f"{self.__class__.__name__}: setting up distributed...")

From c02173fc51fc5c7addd147f8ab27253bda245a25 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Tue, 25 Apr 2023 01:04:33 +0200
Subject: [PATCH 14/15] CHANGELOG

---
 src/lightning/fabric/CHANGELOG.md  | 3 +++
 src/lightning/pytorch/CHANGELOG.md | 6 ++++++
 2 files changed, 9 insertions(+)

diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 507cb7936f7b3..471e50a0dd8b6 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a warning when calling methods on `_FabricModule` that bypass the strategy-specific wrappers ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))

+- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))
+
+
 ### Changed

 - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index fc3bf6ba9b4a0..c33e4efbf8bb9 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Update `LightningDataModule.from_datasets` to support arbitrary iterables ([#17402](https://github.com/Lightning-AI/lightning/pull/17402))

+- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))
+
+
 ### Changed

 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
@@ -53,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Increased the minimum XLA requirement to 1.13 ([#17368](https://github.com/Lightning-AI/lightning/pull/17368))

+
+- `self.log`ed tensors are now kept in the original device to reduce unnecessary host-to-device synchronizations ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))
+
 ### Deprecated

 -
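
A minimal illustrative sketch (not from the patch) of how the Fabric path above gets exercised from user code: `fabric.setup(...)` is roughly where `setup_module`, and with it the stream-guarded DDP construction, runs for the DDP strategy. The accelerator and device arguments below are hypothetical:

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="auto", devices=2, strategy="ddp")
    fabric.launch()
    model = torch.nn.Linear(4, 4)
    model = fabric.setup(model)  # wraps the module for the chosen strategy
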
From 973f6e4dfd0046d4815aba9fd4f00c43dd3df601 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Tue, 25 Apr 2023 01:19:38 +0200
Subject: [PATCH 15/15] Remove DeviceDtypeModuleMixin interface

---
 .../pytorch/trainer/connectors/logger_connector/result.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index 1407577b94b82..d7446d1d6f382 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -23,7 +23,6 @@
 from lightning.fabric.utilities import move_data_to_device
 from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
-from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_0
 from lightning.pytorch.utilities.data import extract_batch_size
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -180,7 +179,7 @@ def is_custom_reduction(self) -> bool:
         return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction or self.is_sum_reduction)


-class _ResultMetric(Metric, _DeviceDtypeModuleMixin):  # type: ignore[misc] # torchmetrics methods should return Self
+class _ResultMetric(Metric):
     """Wraps the value provided to `:meth:`~lightning.pytorch.core.module.LightningModule.log`"""

     def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
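
A minimal illustrative sketch of why the extra device/dtype mixin could be dropped: a bare `torchmetrics.Metric` already tracks the device of its registered state and follows `.to(...)`. Not taken from the patch; the metric and values are hypothetical:

    import torch
    from torchmetrics import MeanMetric

    m = MeanMetric()
    m.update(torch.tensor(2.0))
    print(m.device)          # cpu -- torchmetrics reports the metric's own device
    if torch.cuda.is_available():
        m = m.to("cuda")
        print(m.device)      # device(type='cuda', index=0)
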