From 703d4cff801a7bb1e15d009f8ee1de302ff1faf1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Tue, 20 Dec 2022 16:58:28 +0100
Subject: [PATCH 1/6] Deprecate the `HorovodStrategy`

---
 src/pytorch_lightning/strategies/horovod.py    | 19 +++++++++++++++++--
 .../connectors/accelerator_connector.py        |  5 +++--
 src/pytorch_lightning/utilities/__init__.py    |  1 -
 src/pytorch_lightning/utilities/imports.py     |  1 -
 tests/README.md                                |  3 +--
 tests/tests_lite/conftest.py                   |  1 -
 tests/tests_pytorch/conftest.py                |  2 +-
 .../deprecated_api/test_remove_1-10.py         |  6 ++++++
 tests/tests_pytorch/helpers/runif.py           | 19 +++----------------
 tests/tests_pytorch/models/test_horovod.py     | 10 +++++-----
 tests/tests_pytorch/utilities/test_imports.py  |  3 ++-
 11 files changed, 38 insertions(+), 32 deletions(-)

diff --git a/src/pytorch_lightning/strategies/horovod.py b/src/pytorch_lightning/strategies/horovod.py
index 105dd3d049f1a..3aae34288a620 100644
--- a/src/pytorch_lightning/strategies/horovod.py
+++ b/src/pytorch_lightning/strategies/horovod.py
@@ -16,6 +16,7 @@
 import torch
 import torch.nn as nn
+from lightning_utilities.core.imports import module_available
 from torch import Tensor
 from torch.optim import Optimizer
 
@@ -29,12 +30,22 @@
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
-from pytorch_lightning.utilities.rank_zero import rank_zero_only
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
 
+_HOROVOD_AVAILABLE = module_available("horovod.torch")
+_HOROVOD_NCCL_AVAILABLE = False
 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
 
+    try:
+
+        # `nccl_built` returns an integer
+        _HOROVOD_NCCL_AVAILABLE = bool(hvd.nccl_built())
+    except AttributeError:
+        # AttributeError can be raised if MPI is not available:
+        # https://github.com/horovod/horovod/blob/v0.23.0/horovod/torch/__init__.py#L33-L34
+        pass
+
 
 class HorovodStrategy(ParallelStrategy):
     """Plugin for Horovod distributed training integration."""
@@ -48,6 +59,10 @@ def __init__(
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
     ):
+        rank_zero_deprecation(
+            "The `HorovodStrategy`: `Trainer(strategy='horovod')` has been deprecated in v1.9.0 and will be removed"
+            " in v1.10.0. You can try using `Trainer(strategy='ddp')` instead."
+        )
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
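The availability flags added above follow the usual optional-dependency pattern: compute module-level booleans once at import time, so downstream code can branch on plain flags instead of repeating try/except imports. A minimal self-contained sketch of the same idea, probing a hypothetical `somepkg` rather than Horovod:

```python
from importlib.util import find_spec

# Cheap check: is the package importable at all?
_SOMEPKG_AVAILABLE = find_spec("somepkg") is not None
_SOMEPKG_GPU_AVAILABLE = False  # pessimistic default

if _SOMEPKG_AVAILABLE:
    import somepkg

    try:
        # Optional capability probe; older builds may not expose the attribute.
        _SOMEPKG_GPU_AVAILABLE = bool(somepkg.gpu_built())
    except AttributeError:
        pass
```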
diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 1a6193c04653f..9081f09ccd76a 100644
--- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -78,9 +78,10 @@
     TPUSpawnStrategy,
 )
 from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES
+from pytorch_lightning.strategies.horovod import _HOROVOD_AVAILABLE
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _IPU_AVAILABLE
+from pytorch_lightning.utilities.imports import _IPU_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
 
 log = logging.getLogger(__name__)
@@ -653,7 +654,7 @@ def _handle_horovod(self) -> None:
         if not _HOROVOD_AVAILABLE:
             raise MisconfigurationException(
                 'Requested `strategy="horovod"`, but Horovod is not installed.'
-                "Install with \n $HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]"
+                " Install with \n $HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]"
             )
 
         hvd.init()
diff --git a/src/pytorch_lightning/utilities/__init__.py b/src/pytorch_lightning/utilities/__init__.py
index 27107bc8b81f8..dee1e3363a739 100644
--- a/src/pytorch_lightning/utilities/__init__.py
+++ b/src/pytorch_lightning/utilities/__init__.py
@@ -23,7 +23,6 @@
 from pytorch_lightning.utilities.grads import grad_norm  # noqa: F401
 from pytorch_lightning.utilities.imports import (  # noqa: F401
     _HIVEMIND_AVAILABLE,
-    _HOROVOD_AVAILABLE,
     _HPU_AVAILABLE,
     _IPU_AVAILABLE,
     _OMEGACONF_AVAILABLE,
diff --git a/src/pytorch_lightning/utilities/imports.py b/src/pytorch_lightning/utilities/imports.py
index d365135e81364..6c8ee24159bb7 100644
--- a/src/pytorch_lightning/utilities/imports.py
+++ b/src/pytorch_lightning/utilities/imports.py
@@ -27,7 +27,6 @@
 _DALI_AVAILABLE = module_available("nvidia.dali")
 _HABANA_FRAMEWORK_AVAILABLE = package_available("habana_frameworks")
 _HIVEMIND_AVAILABLE = package_available("hivemind")
-_HOROVOD_AVAILABLE = module_available("horovod.torch")
 _KINETO_AVAILABLE = torch.profiler.kineto_available()
 _OMEGACONF_AVAILABLE = package_available("omegaconf")
 _POPTORCH_AVAILABLE = package_available("poptorch")
diff --git a/tests/README.md b/tests/README.md
index 723a47d4a483a..ffc5181b7cea8 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -57,7 +57,6 @@ To test models that require GPU make sure to run the above command on a GPU mach
 The GPU machine must have at least 2 GPUs to run distributed tests.
 
 Note that this setup will not run tests that require specific packages installed
-such as Horovod, FairScale, NVIDIA/apex, NVIDIA/DALI, etc.
 You can rely on our CI to make sure all these tests pass.
 
 ### Standalone Tests
@@ -72,7 +71,7 @@ There are certain standalone tests, which you can run using:
 
 ## Running Coverage
 
-Make sure to run coverage on a GPU machine with at least 2 GPUs and NVIDIA apex installed.
+Make sure to run coverage on a GPU machine with at least 2 GPUs.
 
 ```bash
 cd pytorch-lightning
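The one-character change in `_handle_horovod` above fixes a classic implicit-concatenation pitfall: adjacent Python string literals are joined with no separator, so a missing leading space runs the sentences together. A quick illustration:

```python
# Adjacent string literals concatenate with no implicit separator:
message = (
    'Requested `strategy="horovod"`, but Horovod is not installed.'
    "Install with HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]"
)
assert "installed.Install" in message  # the words run together

# Adding the leading space restores readable output:
fixed = (
    'Requested `strategy="horovod"`, but Horovod is not installed.'
    " Install with HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]"
)
assert "installed. Install" in fixed
```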
diff --git a/tests/tests_lite/conftest.py b/tests/tests_lite/conftest.py
index af023504b5473..b668bd581757d 100644
--- a/tests/tests_lite/conftest.py
+++ b/tests/tests_lite/conftest.py
@@ -51,7 +51,6 @@ def restore_env_variables():
         "MASTER_PORT",
         "PL_GLOBAL_SEED",
         "PL_SEED_WORKERS",
-        "HOROVOD_FUSION_THRESHOLD",
         "RANK",  # set by DeepSpeed
         "POPLAR_ENGINE_OPTIONS",  # set by IPUStrategy
         "CUDA_MODULE_LOADING",  # leaked since PyTorch 1.13
diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py
index 39b97cb16d006..21ebd85548168 100644
--- a/tests/tests_pytorch/conftest.py
+++ b/tests/tests_pytorch/conftest.py
@@ -69,7 +69,7 @@ def restore_env_variables():
         "WANDB_MODE",
         "WANDB_REQUIRE_SERVICE",
         "WANDB_SERVICE",
-        "HOROVOD_FUSION_THRESHOLD",
+        "HOROVOD_FUSION_THRESHOLD",  # set by HorovodStrategy  # TODO: remove in v1.10.0
         "RANK",  # set by DeepSpeed
         "POPLAR_ENGINE_OPTIONS",  # set by IPUStrategy
         "CUDA_MODULE_LOADING",  # leaked since PyTorch 1.13
diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py
index 715475bcacb4d..98eff8b5b7e19 100644
--- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py
+++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py
@@ -403,3 +403,9 @@ def optimizer_step(
     trainer = Trainer()
     with pytest.deprecated_call(match="amp_backend` will not be supported"):
         trainer.amp_backend
+
+
+@RunIf(horovod=True)
+def test_horovod_deprecation_warnings(*_):
+    with pytest.deprecated_call(match=r"horovod'\)` has been deprecated in v1.9"):
+        Trainer(strategy="horovod")
diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py
index 72c529dcc91b6..17b380ed51e58 100644
--- a/tests/tests_pytorch/helpers/runif.py
+++ b/tests/tests_pytorch/helpers/runif.py
@@ -29,9 +29,9 @@
 from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE
 from pytorch_lightning.strategies.colossalai import _COLOSSALAI_AVAILABLE
 from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
+from pytorch_lightning.strategies.horovod import _HOROVOD_AVAILABLE, _HOROVOD_NCCL_AVAILABLE
 from pytorch_lightning.utilities.imports import (
     _HIVEMIND_AVAILABLE,
-    _HOROVOD_AVAILABLE,
     _HPU_AVAILABLE,
     _IPU_AVAILABLE,
     _OMEGACONF_AVAILABLE,
@@ -40,19 +40,6 @@
 )
 from tests_pytorch.helpers.datamodules import _SKLEARN_AVAILABLE
 
-_HOROVOD_NCCL_AVAILABLE = False
-if _HOROVOD_AVAILABLE:
-    import horovod
-
-    try:
-
-        # `nccl_built` returns an integer
-        _HOROVOD_NCCL_AVAILABLE = bool(horovod.torch.nccl_built())
-    except AttributeError:
-        # AttributeError can be raised if MPI is not available:
-        # https://github.com/horovod/horovod/blob/v0.23.0/horovod/torch/__init__.py#L33-L34
-        pass
-
 
 class RunIf:
     """RunIf wrapper for simple marking specific cases, fully compatible with pytest.mark::
@@ -77,8 +64,8 @@ def __new__(
         ipu: bool = False,
         hpu: bool = False,
         mps: Optional[bool] = None,
-        horovod: bool = False,
-        horovod_nccl: bool = False,
+        horovod: bool = False,  # TODO: remove in v1.10.0
+        horovod_nccl: bool = False,  # TODO: remove in v1.10.0
         skip_windows: bool = False,
         standalone: bool = False,
         fairscale: bool = False,
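The new test relies on `pytest.deprecated_call(match=...)`, a context manager that fails the test unless a `DeprecationWarning` (or a subclass, such as Lightning's `LightningDeprecationWarning`) matching the regex is emitted inside the block. A self-contained sketch of the idiom, with a stand-in for the deprecated constructor:

```python
import warnings

import pytest


def make_strategy() -> str:
    # Stand-in for a deprecated entry point such as `Trainer(strategy="horovod")`
    warnings.warn("`make_strategy` has been deprecated in v1.9", DeprecationWarning)
    return "strategy"


def test_deprecation_warning_is_emitted():
    with pytest.deprecated_call(match=r"deprecated in v1\.9"):
        make_strategy()
```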
diff --git a/tests/tests_pytorch/models/test_horovod.py b/tests/tests_pytorch/models/test_horovod.py
index 7963bde389f4d..950113b7c69cb 100644
--- a/tests/tests_pytorch/models/test_horovod.py
+++ b/tests/tests_pytorch/models/test_horovod.py
@@ -28,7 +28,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
+from pytorch_lightning.strategies.horovod import _HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.advanced_models import BasicGAN
 from tests_pytorch.helpers.runif import RunIf
@@ -40,7 +40,7 @@
 
 @RunIf(min_cuda_gpus=1, horovod=True)
 def test_nccl_is_available_on_gpu_environment():
-    from tests_pytorch.helpers.runif import _HOROVOD_NCCL_AVAILABLE
+    from pytorch_lightning.strategies.horovod import _HOROVOD_NCCL_AVAILABLE
 
     # the GPU environment should always install Horovod NCCL
     assert _HOROVOD_NCCL_AVAILABLE
@@ -293,7 +293,6 @@ def get_optimizer_params(optimizer):
     assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1])
 
 
-# todo: need to be fixed :]
 @pytest.mark.skip(reason="TODO: CI agent.jobstatus=Succeeded: Permission denied")
 @RunIf(horovod=True, skip_windows=True)
 def test_result_reduce_horovod(tmpdir):
@@ -413,8 +412,9 @@ def configure_optimizers(self):
     num_workers = 8
     init_lr = 0.1 * num_workers
 
-    with patch("horovod.torch.size", return_value=8):
-
+    with patch("horovod.torch.size", return_value=8), pytest.deprecated_call(
+        match=r"horovod'\)` has been deprecated in v1.9"
+    ):
         # fit model
         trainer = Trainer(
             default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5, limit_train_batches=0.2, strategy="horovod"
diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py
index 3a22e8aeb6e7f..23e5e60a2965c 100644
--- a/tests/tests_pytorch/utilities/test_imports.py
+++ b/tests/tests_pytorch/utilities/test_imports.py
@@ -25,7 +25,8 @@
 
 from pytorch_lightning.plugins.precision.apex_amp import _APEX_AVAILABLE
 from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _OMEGACONF_AVAILABLE, _POPTORCH_AVAILABLE
+from pytorch_lightning.strategies.horovod import _HOROVOD_AVAILABLE
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _POPTORCH_AVAILABLE
 from tests_pytorch.helpers.runif import RunIf
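A pattern worth noting across patch 1: the `_HOROVOD_AVAILABLE` flag moves from the shared `pytorch_lightning.utilities.imports` module to the strategy module that owns it, and every call site is updated accordingly. Roughly, downstream code changes from the old import to the new one:

```python
# Before this series (the flag lived with all the other availability checks):
# from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE

# After this series (the flag lives next to the code that owns it):
from pytorch_lightning.strategies.horovod import _HOROVOD_AVAILABLE

if _HOROVOD_AVAILABLE:
    import horovod.torch as hvd  # safe: guarded by the flag
```

Keeping the flag next to the strategy means it can be deleted together with the strategy in v1.10.0 without touching the shared utilities module.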
From 1b85d7c3cbb9585fa54598bfccfff4c372403f65 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Tue, 20 Dec 2022 17:09:31 +0100
Subject: [PATCH 2/6] Fixes

---
 src/pytorch_lightning/strategies/horovod.py     | 15 +++++----------
 .../trainer/connectors/accelerator_connector.py |  2 +-
 tests/tests_pytorch/helpers/runif.py            | 15 ++++++++++++++-
 .../models/data/horovod/train_default_model.py  |  2 +-
 tests/tests_pytorch/models/test_horovod.py      |  2 +-
 5 files changed, 22 insertions(+), 14 deletions(-)

diff --git a/src/pytorch_lightning/strategies/horovod.py b/src/pytorch_lightning/strategies/horovod.py
index 3aae34288a620..5e4e289ce7003 100644
--- a/src/pytorch_lightning/strategies/horovod.py
+++ b/src/pytorch_lightning/strategies/horovod.py
@@ -33,19 +33,9 @@
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
 
 _HOROVOD_AVAILABLE = module_available("horovod.torch")
-_HOROVOD_NCCL_AVAILABLE = False
 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
 
-    try:
-
-        # `nccl_built` returns an integer
-        _HOROVOD_NCCL_AVAILABLE = bool(hvd.nccl_built())
-    except AttributeError:
-        # AttributeError can be raised if MPI is not available:
-        # https://github.com/horovod/horovod/blob/v0.23.0/horovod/torch/__init__.py#L33-L34
-        pass
-
 
 class HorovodStrategy(ParallelStrategy):
     """Plugin for Horovod distributed training integration."""
@@ -63,6 +53,11 @@ def __init__(
             "The `HorovodStrategy`: `Trainer(strategy='horovod')` has been deprecated in v1.9.0 and will be removed"
             " in v1.10.0. You can try using `Trainer(strategy='ddp')` instead."
         )
+        if not _HOROVOD_AVAILABLE:
+            raise MisconfigurationException(
+                'Requested `strategy="horovod"`, but Horovod is not installed.'
+                " Install with `HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]`"
+            )
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 9081f09ccd76a..03f7d53732a6e 100644
--- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -654,7 +654,7 @@ def _handle_horovod(self) -> None:
         if not _HOROVOD_AVAILABLE:
             raise MisconfigurationException(
                 'Requested `strategy="horovod"`, but Horovod is not installed.'
-                " Install with \n $HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]"
+                " Install with `HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]`"
             )
 
         hvd.init()
diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py
index 17b380ed51e58..2bfb44ca5e117 100644
--- a/tests/tests_pytorch/helpers/runif.py
+++ b/tests/tests_pytorch/helpers/runif.py
@@ -29,7 +29,7 @@
 from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE
 from pytorch_lightning.strategies.colossalai import _COLOSSALAI_AVAILABLE
 from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
-from pytorch_lightning.strategies.horovod import _HOROVOD_AVAILABLE, _HOROVOD_NCCL_AVAILABLE
+from pytorch_lightning.strategies.horovod import _HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.imports import (
     _HIVEMIND_AVAILABLE,
     _HPU_AVAILABLE,
@@ -40,6 +40,19 @@
 )
 from tests_pytorch.helpers.datamodules import _SKLEARN_AVAILABLE
 
+_HOROVOD_NCCL_AVAILABLE = False
+if _HOROVOD_AVAILABLE:
+    import horovod.torch as hvd
+
+    try:
+
+        # `nccl_built` returns an integer
+        _HOROVOD_NCCL_AVAILABLE = bool(hvd.nccl_built())
+    except AttributeError:
+        # AttributeError can be raised if MPI is not available:
+        # https://github.com/horovod/horovod/blob/v0.23.0/horovod/torch/__init__.py#L33-L34
+        pass
+
 
 class RunIf:
     """RunIf wrapper for simple marking specific cases, fully compatible with pytest.mark::
diff --git a/tests/tests_pytorch/models/data/horovod/train_default_model.py b/tests/tests_pytorch/models/data/horovod/train_default_model.py
index 26e1e8c2f8f95..9e2fb34355e4c 100644
--- a/tests/tests_pytorch/models/data/horovod/train_default_model.py
+++ b/tests/tests_pytorch/models/data/horovod/train_default_model.py
@@ -29,7 +29,7 @@
 
 from pytorch_lightning import Trainer  # noqa: E402
 from pytorch_lightning.callbacks import ModelCheckpoint  # noqa: E402
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE  # noqa: E402
+from pytorch_lightning.strategies.horovod import _HOROVOD_AVAILABLE  # noqa: E402
 
 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
diff --git a/tests/tests_pytorch/models/test_horovod.py b/tests/tests_pytorch/models/test_horovod.py
index 950113b7c69cb..75ba489f4232b 100644
--- a/tests/tests_pytorch/models/test_horovod.py
+++ b/tests/tests_pytorch/models/test_horovod.py
@@ -40,7 +40,7 @@
 
 @RunIf(min_cuda_gpus=1, horovod=True)
 def test_nccl_is_available_on_gpu_environment():
-    from pytorch_lightning.strategies.horovod import _HOROVOD_NCCL_AVAILABLE
+    from tests_pytorch.helpers.runif import _HOROVOD_NCCL_AVAILABLE
 
     # the GPU environment should always install Horovod NCCL
     assert _HOROVOD_NCCL_AVAILABLE
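For readers unfamiliar with the test helper being edited: `RunIf(horovod=True)` ultimately expands to a `pytest.mark.skipif` built from these module-level flags. A rough sketch of how such a marker can be assembled (simplified from the real `RunIf.__new__`):

```python
import pytest

_HOROVOD_AVAILABLE = False  # stand-in for the real probe shown in the diff


def run_if(horovod: bool = False) -> "pytest.MarkDecorator":
    conditions, reasons = [], []
    if horovod:
        conditions.append(not _HOROVOD_AVAILABLE)
        reasons.append("Horovod")
    return pytest.mark.skipif(
        condition=any(conditions),
        reason=f"Requires: [{' + '.join(reasons)}]",
    )


@run_if(horovod=True)
def test_something_with_horovod():
    ...
```

This mirrors the real helper's behavior: the test is collected but skipped whenever any required dependency is missing.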
From 44d086912d8cc18627b467546acbb3822f21c26f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Tue, 20 Dec 2022 17:20:44 +0100
Subject: [PATCH 3/6] Fixes

---
 tests/tests_pytorch/models/test_horovod.py | 74 ++++++++++++----------
 1 file changed, 40 insertions(+), 34 deletions(-)

diff --git a/tests/tests_pytorch/models/test_horovod.py b/tests/tests_pytorch/models/test_horovod.py
index 75ba489f4232b..fa9f15572719c 100644
--- a/tests/tests_pytorch/models/test_horovod.py
+++ b/tests/tests_pytorch/models/test_horovod.py
@@ -165,19 +165,20 @@ def test_horovod_multi_gpu_accumulate_grad_batches(tmpdir):
     _run_horovod(trainer_options)
 
 
-@RunIf(horovod=True, skip_windows=True, min_cuda_gpus=2)
+@RunIf(horovod=True, skip_windows=True, min_cuda_gpus=1)
 def test_horovod_raises_unsupported_accumulate_grad_batches(tmpdir):
     """Ensure MisConfigurationException for different `accumulate_grad_batches` at different epochs for Horovod
     Strategy on multi-gpus."""
     model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        enable_progress_bar=False,
-        accumulate_grad_batches={0: 4, 2: 2},
-        accelerator="auto",
-        devices=2,
-        strategy="horovod",
-    )
+    with pytest.deprecated_call(match=r"horovod'\)` has been deprecated in v1.9"):
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            enable_progress_bar=False,
+            accumulate_grad_batches={0: 4, 2: 2},
+            accelerator="auto",
+            devices=1,
+            strategy="horovod",
+        )
     with pytest.raises(MisconfigurationException, match="Horovod.*does not support.*accumulate_grad_batches"):
         trainer.fit(model)
 
@@ -267,14 +268,15 @@ def test_horovod_multi_optimizer(tmpdir):
     model = BasicGAN()
 
     # fit model
-    trainer = Trainer(
-        default_root_dir=str(tmpdir),
-        enable_progress_bar=False,
-        max_epochs=1,
-        limit_train_batches=0.4,
-        limit_val_batches=0.2,
-        strategy="horovod",
-    )
+    with pytest.deprecated_call(match=r"horovod'\)` has been deprecated in v1.9"):
+        trainer = Trainer(
+            default_root_dir=str(tmpdir),
+            enable_progress_bar=False,
+            max_epochs=1,
+            limit_train_batches=0.4,
+            limit_val_batches=0.2,
+            strategy="horovod",
+        )
     trainer.fit(model)
     assert trainer.state.finished, f"Training failed with {trainer.state}"
 
@@ -326,15 +328,16 @@ def training_epoch_end(self, outputs) -> None:
     model = TestModel()
     model.val_dataloader = None
 
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_train_batches=2,
-        limit_val_batches=2,
-        max_epochs=1,
-        log_every_n_steps=1,
-        enable_model_summary=False,
-        logger=False,
-    )
+    with pytest.deprecated_call(match=r"horovod'\)` has been deprecated in v1.9"):
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            limit_train_batches=2,
+            limit_val_batches=2,
+            max_epochs=1,
+            log_every_n_steps=1,
+            enable_model_summary=False,
+            logger=False,
+        )
 
     trainer.fit(model)
 
@@ -360,7 +363,8 @@ def sk_metric(preds, target):
     target = torch.randint(high=2, size=(num_batches, batch_size))
 
     def _compute_batch():
-        trainer = Trainer(fast_dev_run=True, strategy="horovod", logger=False)
+        with pytest.deprecated_call(match=r"horovod'\)` has been deprecated in v1.9"):
+            trainer = Trainer(fast_dev_run=True, strategy="horovod", logger=False)
 
         assert isinstance(trainer.accelerator, CPUAccelerator)
         # TODO: test that we selected the correct strategy based on horovod flags
@@ -413,13 +416,15 @@ def configure_optimizers(self):
     num_workers = 8
     init_lr = 0.1 * num_workers
 
-    with patch("horovod.torch.size", return_value=8), pytest.deprecated_call(
-        match=r"horovod'\)` has been deprecated in v1.9"
-    ):
-        # fit model
-        trainer = Trainer(
-            default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5, limit_train_batches=0.2, strategy="horovod"
-        )
+    with patch("horovod.torch.size", return_value=8):
+        with pytest.deprecated_call(match=r"horovod'\)` has been deprecated in v1.9"):
+            trainer = Trainer(
+                default_root_dir=tmpdir,
+                max_epochs=1,
+                limit_val_batches=0.5,
+                limit_train_batches=0.2,
+                strategy="horovod",
+            )
         trainer.fit(model)
 
     adjusted_lr1 = [pg["lr"] for pg in trainer.optimizers[0].param_groups][0]
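Patch 3 also rewrites a combined `with patch(...), pytest.deprecated_call(...)` statement into two nested `with` blocks. Both spellings are equivalent at runtime; the nesting just narrows each manager's scope visually. A runnable sketch of the combination, using a hypothetical `world_size`/`build_trainer` pair instead of Horovod:

```python
import warnings
from unittest.mock import patch

import pytest


def world_size() -> int:
    return 1  # pretend single-process default


def build_trainer() -> int:
    warnings.warn("'horovod') has been deprecated in v1.9", DeprecationWarning)
    return world_size()  # the mocked value is visible here


def test_mocked_size_and_deprecation():
    with patch(f"{__name__}.world_size", return_value=8):
        with pytest.deprecated_call(match=r"deprecated in v1\.9"):
            assert build_trainer() == 8
```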
match=r"horovod'\)` has been deprecated in v1.9" - ): - # fit model - trainer = Trainer( - default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5, limit_train_batches=0.2, strategy="horovod" - ) + with patch("horovod.torch.size", return_value=8): + with pytest.deprecated_call(match=r"horovod'\)` has been deprecated in v1.9"): + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.5, + limit_train_batches=0.2, + strategy="horovod", + ) trainer.fit(model) adjusted_lr1 = [pg["lr"] for pg in trainer.optimizers[0].param_groups][0] From e24fed8dc7bcd3108c6b1fb3be9ec9b44e4c2711 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 20 Dec 2022 17:31:55 +0100 Subject: [PATCH 4/6] Catch --- tests/tests_pytorch/models/test_horovod.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/models/test_horovod.py b/tests/tests_pytorch/models/test_horovod.py index fa9f15572719c..fc71dc42cba40 100644 --- a/tests/tests_pytorch/models/test_horovod.py +++ b/tests/tests_pytorch/models/test_horovod.py @@ -260,7 +260,8 @@ def validation_step(self, batch, *args, **kwargs): devices=2, strategy="horovod", ) - tpipes.run_model_test_without_loggers(trainer_options, model) + with pytest.deprecated_call(match=r"horovod'\)` has been deprecated in v1.9"): + tpipes.run_model_test_without_loggers(trainer_options, model) @RunIf(horovod=True, skip_windows=True) From c3eb7135bc072124e5f24cf8dbaeaf9cb7896d14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 20 Dec 2022 17:33:09 +0100 Subject: [PATCH 5/6] CHANGELOG --- src/pytorch_lightning/CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 3ea2888ffdaf3..b1c0a812d4738 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -86,6 +86,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Deprecates the `pytorch_lightning.utilities.enum.sAMPType` enum * Deprecates the `DeepSpeedPrecisionPlugin(amp_type=..., amp_level=...)` arguments +- `horovod` deprecation ([#16141](https://github.com/PyTorchLightning/pytorch-lightning/pull/16141)) + * Deprecated `Trainer(strategy="horovod")` + * Deprecated the `HorovodStrategy` class + ### Removed From dcf914b3887496c289c2ded764f1ad0b689c3b80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 20 Dec 2022 17:36:49 +0100 Subject: [PATCH 6/6] Docs --- docs/source-pytorch/accelerators/gpu_faq.rst | 4 +- .../accelerators/gpu_intermediate.rst | 41 +------------------ docs/source-pytorch/api_references.rst | 1 - docs/source-pytorch/common/trainer.rst | 1 - docs/source-pytorch/extensions/strategy.rst | 3 -- .../source-pytorch/starter/lightning_lite.rst | 2 +- 6 files changed, 3 insertions(+), 49 deletions(-) diff --git a/docs/source-pytorch/accelerators/gpu_faq.rst b/docs/source-pytorch/accelerators/gpu_faq.rst index 8302665591f4b..89ea8ec3b8f8a 100644 --- a/docs/source-pytorch/accelerators/gpu_faq.rst +++ b/docs/source-pytorch/accelerators/gpu_faq.rst @@ -20,7 +20,7 @@ Let's say you have a batch size of 7 in your dataloader. def train_dataloader(self): return Dataset(..., batch_size=7) -In DDP, DDP_SPAWN, Deepspeed, DDP_SHARDED, or Horovod your effective batch size will be 7 * devices * num_nodes. +In DDP, DDP_SPAWN, Deepspeed, DDP_SHARDED your effective batch size will be 7 * devices * num_nodes. .. 
From dcf914b3887496c289c2ded764f1ad0b689c3b80 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Tue, 20 Dec 2022 17:36:49 +0100
Subject: [PATCH 6/6] Docs

---
 docs/source-pytorch/accelerators/gpu_faq.rst   |  4 +-
 .../accelerators/gpu_intermediate.rst          | 41 +-------------------
 docs/source-pytorch/api_references.rst         |  1 -
 docs/source-pytorch/common/trainer.rst         |  1 -
 docs/source-pytorch/extensions/strategy.rst    |  3 ---
 .../source-pytorch/starter/lightning_lite.rst  |  2 +-
 6 files changed, 3 insertions(+), 49 deletions(-)

diff --git a/docs/source-pytorch/accelerators/gpu_faq.rst b/docs/source-pytorch/accelerators/gpu_faq.rst
index 8302665591f4b..89ea8ec3b8f8a 100644
--- a/docs/source-pytorch/accelerators/gpu_faq.rst
+++ b/docs/source-pytorch/accelerators/gpu_faq.rst
@@ -20,7 +20,7 @@ Let's say you have a batch size of 7 in your dataloader.
     def train_dataloader(self):
         return Dataset(..., batch_size=7)
 
-In DDP, DDP_SPAWN, Deepspeed, DDP_SHARDED, or Horovod your effective batch size will be 7 * devices * num_nodes.
+In DDP, DDP_SPAWN, Deepspeed, or DDP_SHARDED, your effective batch size will be 7 * devices * num_nodes.
 
 .. code-block:: python
 
@@ -28,13 +28,11 @@ In DDP, DDP_SPAWN, Deepspeed, DDP_SHARDED your effective batch size
     Trainer(accelerator="gpu", devices=8, strategy="ddp")
     Trainer(accelerator="gpu", devices=8, strategy="ddp_spawn")
     Trainer(accelerator="gpu", devices=8, strategy="ddp_sharded")
-    Trainer(accelerator="gpu", devices=8, strategy="horovod")
 
     # effective batch size = 7 * 8 * 10
     Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy="ddp")
     Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy="ddp_spawn")
    Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy="ddp_sharded")
-    Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy="horovod")
 
 .. note:: Huge batch sizes are actually really bad for convergence. Check out:
diff --git a/docs/source-pytorch/accelerators/gpu_intermediate.rst b/docs/source-pytorch/accelerators/gpu_intermediate.rst
index 9e2e7a4071ce0..b8b5822c0aa35 100644
--- a/docs/source-pytorch/accelerators/gpu_intermediate.rst
+++ b/docs/source-pytorch/accelerators/gpu_intermediate.rst
@@ -25,7 +25,6 @@ Lightning supports multiple ways of doing distributed training.
 - Regular (``strategy='ddp'``)
 - Spawn (``strategy='ddp_spawn'``)
 - Notebook/Fork (``strategy='ddp_notebook'``)
-- Horovod (``strategy='horovod'``) (multi-machine, multi-gpu, configured at runtime)
 - Bagua (``strategy='bagua'``) (multiple-gpus across many machines with advanced training algorithms)
 
 .. note::
@@ -236,44 +235,6 @@ Comparison of DDP variants and tradeoffs
     - Fast
 
 
-Horovod
-^^^^^^^
-`Horovod `_ allows the same training script to be used for single-GPU,
-multi-GPU, and multi-node training.
-
-Like Distributed Data Parallel, every process in Horovod operates on a single GPU with a fixed
-subset of the data. Gradients are averaged across all GPUs in parallel during the backward pass,
-then synchronously applied before beginning the next step.
-
-The number of worker processes is configured by a driver application (`horovodrun` or `mpirun`). In
-the training script, Horovod will detect the number of workers from the environment, and automatically
-scale the learning rate to compensate for the increased total batch size.
-
-Horovod can be configured in the training script to run with any number of GPUs / processes as follows:
-
-.. code-block:: python
-
-    # train Horovod on GPU (number of GPUs / machines provided on command-line)
-    trainer = Trainer(strategy="horovod", accelerator="gpu", devices=1)
-
-    # train Horovod on CPU (number of processes / machines provided on command-line)
-    trainer = Trainer(strategy="horovod")
-
-When starting the training job, the driver application will then be used to specify the total
-number of worker processes:
-
-.. code-block:: bash
-
-    # run training with 4 GPUs on a single machine
-    horovodrun -np 4 python train.py
-
-    # run training with 8 GPUs on two machines (4 GPUs each)
-    horovodrun -np 8 -H hostname1:4,hostname2:4 python train.py
-
-See the official `Horovod documentation `_ for details
-on installation and performance tuning.
-
-
 Bagua
 ^^^^^
 `Bagua `_ is a deep learning training acceleration framework which supports
@@ -284,7 +245,7 @@ multiple advanced distributed training algorithms including:
 - `ByteGrad `_ and `QAdam `_ for low precision communication, where data is compressed into low precision before communication.
 - `Asynchronous Model Average `_ for asynchronous communication, where workers are not required to be synchronized in the same iteration in a lock-step style.
 
-By default, Bagua uses *Gradient AllReduce* algorithm, which is also the algorithm implemented in Distributed Data Parallel and Horovod,
+By default, Bagua uses the *Gradient AllReduce* algorithm, which is also the algorithm implemented in DDP,
 but Bagua can usually produce a higher training throughput due to its backend written in Rust.
 
 .. code-block:: python
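The effective batch size rule in the FAQ above is plain multiplication, since every distributed worker draws its own batches from the dataloader. A worked example matching the docs' numbers:

```python
# Per-process batch size comes from the DataLoader; each of the
# `devices * num_nodes` processes consumes its own batches.
batch_size = 7
devices = 8
num_nodes = 10

effective_batch_size = batch_size * devices * num_nodes
assert effective_batch_size == 560
```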
diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst
index 141cfe0d67615..7394156eb39da 100644
--- a/docs/source-pytorch/api_references.rst
+++ b/docs/source-pytorch/api_references.rst
@@ -295,7 +295,6 @@ strategies
     DataParallelStrategy
     DeepSpeedStrategy
     HivemindStrategy
-    HorovodStrategy
     HPUParallelStrategy
     IPUStrategy
     ParallelStrategy
diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index 8d5e35206b988..613da787c4c1a 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -424,7 +424,6 @@ deterministic
 
 This flag sets the ``torch.backends.cudnn.deterministic`` flag.
 Might make your system slower, but ensures reproducibility.
-Also sets ``$HOROVOD_FUSION_THRESHOLD=0``.
 
 For more info check `PyTorch docs `_.
diff --git a/docs/source-pytorch/extensions/strategy.rst b/docs/source-pytorch/extensions/strategy.rst
index 807de1b02e47c..3d97a14946ebd 100644
--- a/docs/source-pytorch/extensions/strategy.rst
+++ b/docs/source-pytorch/extensions/strategy.rst
@@ -102,9 +102,6 @@ The below table lists all relevant strategies available in Lightning with their
    * - deepspeed
      - :class:`~pytorch_lightning.strategies.DeepSpeedStrategy`
      - Provides capabilities to run training using the DeepSpeed library, with training optimizations for large billion parameter models. :ref:`Learn more. `
-   * - horovod
-     - :class:`~pytorch_lightning.strategies.HorovodStrategy`
-     - Strategy for Horovod distributed training integration. :ref:`Learn more. `
    * - hpu_parallel
      - :class:`~pytorch_lightning.strategies.HPUParallelStrategy`
      - Strategy for distributed training on multiple HPU devices. :doc:`Learn more. <../accelerators/hpu>`
diff --git a/docs/source-pytorch/starter/lightning_lite.rst b/docs/source-pytorch/starter/lightning_lite.rst
index bc097a02571d5..de9a7ccc3d3bf 100644
--- a/docs/source-pytorch/starter/lightning_lite.rst
+++ b/docs/source-pytorch/starter/lightning_lite.rst
@@ -276,7 +276,7 @@ Additionally, you can pass in your custom strategy by configuring additional par
 
     lite = Lite(strategy=DeepSpeedStrategy(stage=2), accelerator="gpu", devices=2)
 
-Support for Horovod and Fully Sharded training strategies are coming soon.
+Support for Fully Sharded training strategies is coming soon.
 
 devices