Skip to content

Commit

Permalink
Patch for device placement (Reduce host and device syncs) (#17334)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
shanmugamr1992 and carmocca authored Apr 25, 2023
1 parent 039891f commit 524608c
Show file tree
Hide file tree
Showing 11 changed files with 37 additions and 34 deletions.
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning when calling methods on `_FabricModule` that bypass the strategy-specific wrappers ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))


- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))


### Changed

- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
Expand Down
9 changes: 7 additions & 2 deletions src/lightning/fabric/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from typing import Any, Dict, Generator, List, Literal, Optional, Union

Expand Down Expand Up @@ -113,7 +113,12 @@ def setup_environment(self) -> None:

def setup_module(self, module: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
return DistributedDataParallel(module=module, device_ids=self._determine_ddp_device_ids(), **self._ddp_kwargs)
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext()
with ctx:
return DistributedDataParallel(
module=module, device_ids=self._determine_ddp_device_ids(), **self._ddp_kwargs
)

def module_to_device(self, module: Module) -> None:
module.to(self.root_device)
Expand Down
6 changes: 6 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Update `LightningDataModule.from_datasets` to support arbitrary iterables ([#17402](https://github.com/Lightning-AI/lightning/pull/17402))


- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))


### Changed

- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
Expand All @@ -53,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Increased the minimum XLA requirement to 1.13 ([#17368](https://github.com/Lightning-AI/lightning/pull/17368))


- `self.log`ed tensors are now kept in the original device to reduce unnecessary host-to-device synchronizations ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))

### Deprecated

-
Expand Down
6 changes: 1 addition & 5 deletions src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,11 +607,7 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")

def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
value = (
value.clone().detach().to(self.device)
if isinstance(value, Tensor)
else torch.tensor(value, device=self.device)
)
value = value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value, device=self.device)
if not torch.numel(value) == 1:
raise ValueError(
f"`self.log({name}, {value})` was called, but the tensor must have a single element."
Expand Down
3 changes: 0 additions & 3 deletions src/lightning/pytorch/loops/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,6 @@ def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_start`` hooks."""
trainer = self.trainer

assert self._results is not None
self._results.to(device=trainer.lightning_module.device)

hook_name = "on_test_start" if trainer.testing else "on_validation_start"
call._call_callback_hooks(trainer, hook_name, *args, **kwargs)
call._call_lightning_module_hook(trainer, hook_name, *args, **kwargs)
Expand Down
2 changes: 0 additions & 2 deletions src/lightning/pytorch/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,6 @@ def on_run_start(self) -> None:

self._data_fetcher = _select_data_fetcher(trainer)

self._results.to(device=trainer.lightning_module.device)

call._call_callback_hooks(trainer, "on_train_start")
call._call_lightning_module_hook(trainer, "on_train_start")
call._call_strategy_hook(trainer, "on_train_start")
Expand Down
6 changes: 5 additions & 1 deletion src/lightning/pytorch/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import nullcontext
from datetime import timedelta
from typing import Any, Callable, Dict, List, Literal, Optional, Union

Expand Down Expand Up @@ -182,7 +183,10 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
device_ids = self.determine_ddp_device_ids()
log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext()
with ctx:
return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

def setup_distributed(self) -> None:
log.debug(f"{self.__class__.__name__}: setting up distributed...")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch.utilities.data import extract_batch_size
from lightning.pytorch.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -180,7 +179,7 @@ def is_custom_reduction(self) -> bool:
return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction or self.is_sum_reduction)


class _ResultMetric(Metric, _DeviceDtypeModuleMixin): # type: ignore[misc] # torchmetrics methods should return Self
class _ResultMetric(Metric):
"""Wraps the value provided to `:meth:`~lightning.pytorch.core.module.LightningModule.log`"""

def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
Expand Down Expand Up @@ -314,10 +313,9 @@ class _ResultCollection(dict):

DATALOADER_SUFFIX = "/dataloader_idx_{}"

def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
def __init__(self, training: bool) -> None:
super().__init__()
self.training = training
self.device: Optional[Union[str, torch.device]] = device
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.dataloader_idx: Optional[int] = None
Expand Down Expand Up @@ -410,13 +408,13 @@ def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:
Value can be provided as a nested collection
"""
metric = _ResultMetric(meta, isinstance(value, Tensor)).to(self.device)
metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
self[key] = metric

def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
result_metric = self[key]
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
result_metric.forward(value.to(self.device), batch_size)
result_metric.forward(value, batch_size)
result_metric.has_reset = False

@staticmethod
Expand Down Expand Up @@ -509,9 +507,6 @@ def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> Non
def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection":
"""Move all data to the given device."""
self.update(apply_to_collection(dict(self), (Tensor, Metric), move_data_to_device, *args, **kwargs))

if "device" in kwargs:
self.device = kwargs["device"]
return self

def cpu(self) -> "_ResultCollection":
Expand All @@ -524,4 +519,4 @@ def __str__(self) -> str:
return f"{self.__class__.__name__}({self_str})"

def __repr__(self) -> str:
return f"{{{self.training}, {repr(self.device)}, {super().__repr__()}}}"
return f"{{{self.training}, {super().__repr__()}}}"
9 changes: 4 additions & 5 deletions tests/tests_pytorch/core/test_metric_result_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def result_reduce_ddp_fn(strategy):
metric_b = metric_b.to(f"cuda:{rank}")
metric_c = metric_c.to(f"cuda:{rank}")

result = _ResultCollection(True, torch.device(f"cuda:{rank}"))
result = _ResultCollection(True)

for _ in range(3):
cumulative_sum = 0
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_result_metric_integration():
metric_b = DummyMetric()
metric_c = DummyMetric()

result = _ResultCollection(True, torch.device("cpu"))
result = _ResultCollection(True)

for _ in range(3):
cumulative_sum = 0
Expand Down Expand Up @@ -148,7 +148,6 @@ def test_result_metric_integration():
assert repr(result) == (
"{"
"True, "
"device(type='cpu'), "
"{'h.a': _ResultMetric('a', value=DummyMetric()), "
"'h.b': _ResultMetric('b', value=DummyMetric()), "
"'h.c': _ResultMetric('c', value=DummyMetric())"
Expand All @@ -157,7 +156,7 @@ def test_result_metric_integration():


def test_result_collection_simple_loop():
result = _ResultCollection(True, torch.device("cpu"))
result = _ResultCollection(True)
current_fx_name = None
batch_idx = None

Expand Down Expand Up @@ -205,7 +204,7 @@ def my_sync_dist(x, *_, **__):
def test_result_collection_restoration(tmpdir):
"""This test make sure metrics are properly reloaded on failure."""

result = _ResultCollection(True, torch.device("cpu"))
result = _ResultCollection(True)
metric_a = DummyMetric()
metric_b = DummyMetric()
metric_c = DummyMetric()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ def on_validation_epoch_end(self) -> None:
assert set(trainer.callback_metrics) == {"val_loss", "val_loss_epoch"}

# make sure values are correct
assert trainer.logged_metrics["val_loss_epoch"] == model.manual_epoch_end_mean
assert trainer.callback_metrics["val_loss_epoch"] == model.manual_epoch_end_mean
assert trainer.callback_metrics["val_loss"] == model.manual_epoch_end_mean
assert torch.allclose(trainer.logged_metrics["val_loss_epoch"], model.manual_epoch_end_mean)
assert torch.allclose(trainer.callback_metrics["val_loss_epoch"], model.manual_epoch_end_mean)
assert torch.allclose(trainer.callback_metrics["val_loss"], model.manual_epoch_end_mean)
assert trainer.logged_metrics["val_loss_step"] == model.val_losses[-1]


Expand Down
6 changes: 3 additions & 3 deletions tests/tests_pytorch/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def test_result_collection_batch_size_extraction():
fx_name = "training_step"
log_val = torch.tensor(7.0)

results = _ResultCollection(training=True, device="cpu")
results = _ResultCollection(training=True)
results.batch = torch.randn(1, 4)
train_mse = MeanSquaredError()
train_mse(torch.randn(4, 5), torch.randn(4, 5))
Expand All @@ -615,7 +615,7 @@ def test_result_collection_batch_size_extraction():
assert isinstance(results["training_step.mse"].value, MeanSquaredError)
assert results["training_step.log_val"].value == log_val

results = _ResultCollection(training=True, device="cpu")
results = _ResultCollection(training=True)
results.batch = torch.randn(1, 4)
results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True)
assert results.batch_size == 1
Expand All @@ -624,7 +624,7 @@ def test_result_collection_batch_size_extraction():


def test_result_collection_no_batch_size_extraction():
results = _ResultCollection(training=True, device="cpu")
results = _ResultCollection(training=True)
results.batch = torch.randn(1, 4)
fx_name = "training_step"
batch_size = 10
Expand Down

0 comments on commit 524608c

Please sign in to comment.