Remove deprecated sync_batchnorm and num_nodes attributes in DDP plugins (#10357)

* Remove deprecated sync_batchnorm and num_nodes attributes in DDPPlugin

Part of #10312

test_v1_6_0_ddp_num_nodes()
test_v1_6_0_ddp_sync_batchnorm()

* remove deprecation warnings

* apply removal to spawn plugin

* update changelog

* remove num_nodes in deepspeed

* remove unused imports

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Nov 5, 2021
1 parent 037fd5e commit 9c4112c
Showing 5 changed files with 9 additions and 58 deletions.
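
For downstream code, the migration is the one the old deprecation warnings pointed to: pass both values to the Trainer, which overrides the plugin settings anyway. A minimal before/after sketch (the `strategy=` wiring shown reflects the 1.5-era Trainer API and is illustrative, not part of this commit):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# Before (deprecated in v1.4, removed by this commit):
#   plugin = DDPPlugin(num_nodes=2, sync_batchnorm=True)

# After: the Trainer owns both settings and propagates them to the plugin.
trainer = Trainer(
    num_nodes=2,
    sync_batchnorm=True,
    strategy=DDPPlugin(),
)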

CHANGELOG.md (4 changes: 3 additions & 1 deletion)

@@ -72,8 +72,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `has_prepared_data`, `has_setup_fit`, `has_setup_validate`, `has_setup_test`, `has_setup_predict`, `has_teardown_fit`, `has_teardown_validate`, `has_teardown_test` and `has_teardown_predict` datamodule lifecycle properties ([#10350](https://github.com/PyTorchLightning/pytorch-lightning/pull/10350))
 
 
-- Removed deprecated property `is_slurm_managing_tasks` from AcceleratorConnector ([#10353](https://github.com/PyTorchLightning/pytorch-lightning/pull/10353))
+- Removed deprecated arguments `num_nodes` and `sync_batchnorm` from `DDPPlugin`, `DDPSpawnPlugin`, `DeepSpeedPlugin` ([#10357](https://github.com/PyTorchLightning/pytorch-lightning/pull/10357))
 
 
+- Removed deprecated property `is_slurm_managing_tasks` from AcceleratorConnector ([#10353](https://github.com/PyTorchLightning/pytorch-lightning/pull/10353))
 
 
 ### Fixed

pytorch_lightning/plugins/training_type/ddp.py (16 changes: 2 additions & 14 deletions)

@@ -94,10 +94,8 @@ class DDPPlugin(ParallelPlugin):
     def __init__(
         self,
         parallel_devices: Optional[List[torch.device]] = None,
-        num_nodes: Optional[int] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
-        sync_batchnorm: Optional[bool] = None,
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
         ddp_comm_wrapper: Optional[callable] = None,

@@ -110,18 +108,8 @@ def __init__(
             checkpoint_io=checkpoint_io,
         )
         self.interactive_ddp_procs = []
-        if num_nodes is not None:
-            rank_zero_deprecation(
-                "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
-                " Notice that it will be overriden by the trainer setting."
-            )
-        self._num_nodes = num_nodes or 1
-        if sync_batchnorm is not None:
-            rank_zero_deprecation(
-                "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
-                " Notice that it will be overriden by the trainer setting."
-            )
-        self._sync_batchnorm = sync_batchnorm or False
+        self._num_nodes = 1
+        self._sync_batchnorm = False
         self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._ddp_kwargs = kwargs
         self._task_idx = None
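
With the constructor arguments gone, `_num_nodes` and `_sync_batchnorm` start at fixed defaults and are filled in by the trainer afterwards. A self-contained sketch of that handshake, assuming plain property setters on the plugin (names illustrative; the real assignment lives in Lightning's accelerator connector):

class PluginSketch:
    """Mirrors the new constructor: neutral defaults, set post-construction."""

    def __init__(self) -> None:
        self._num_nodes = 1           # matches `self._num_nodes = 1` above
        self._sync_batchnorm = False  # matches `self._sync_batchnorm = False` above

    @property
    def num_nodes(self) -> int:
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: int) -> None:
        self._num_nodes = num_nodes

    @property
    def sync_batchnorm(self) -> bool:
        return self._sync_batchnorm

    @sync_batchnorm.setter
    def sync_batchnorm(self, sync_batchnorm: bool) -> None:
        self._sync_batchnorm = sync_batchnorm


# Trainer side (illustrative): values flow from Trainer arguments to the plugin.
plugin = PluginSketch()
plugin.num_nodes = 2          # e.g. Trainer(num_nodes=2)
plugin.sync_batchnorm = True  # e.g. Trainer(sync_batchnorm=True)
assert plugin.num_nodes == 2 and plugin.sync_batchnorm is True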

pytorch_lightning/plugins/training_type/ddp_spawn.py (23 changes: 3 additions & 20 deletions)

@@ -32,12 +32,7 @@
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import (
-    _TORCH_GREATER_EQUAL_1_7,
-    _TORCH_GREATER_EQUAL_1_8,
-    rank_zero_deprecation,
-    rank_zero_warn,
-)
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load

@@ -69,10 +64,8 @@ class DDPSpawnPlugin(ParallelPlugin):
     def __init__(
         self,
         parallel_devices: Optional[List[torch.device]] = None,
-        num_nodes: Optional[int] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
-        sync_batchnorm: Optional[bool] = None,
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
         ddp_comm_wrapper: Optional[callable] = None,

@@ -83,18 +76,8 @@
             cluster_environment=cluster_environment,
             checkpoint_io=checkpoint_io,
         )
-        if num_nodes is not None:
-            rank_zero_deprecation(
-                "Argument `num_nodes` in `DDPSpawnPlugin` is deprecated in v1.4, and will be removed in v1.6. "
-                "Notice that it will be overriden by the trainer setting."
-            )
-        self._num_nodes = num_nodes or 1
-        if sync_batchnorm is not None:
-            rank_zero_deprecation(
-                "Argument `sync_batchnorm` in `DDPSpawnPlugin` is deprecated in v1.4, and will be removed in v1.6. "
-                "Notice that it will be overriden by the trainer setting."
-            )
-        self._sync_batchnorm = sync_batchnorm or False
+        self._num_nodes = 1
+        self._sync_batchnorm = False
         self._ddp_kwargs = kwargs
         self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
         self.mp_queue = None

pytorch_lightning/plugins/training_type/deepspeed.py (2 changes: 0 additions & 2 deletions)

@@ -116,7 +116,6 @@ def __init__(
         logging_batch_size_per_gpu: Union[str, int] = "auto",
         config: Optional[Union[Path, str, dict]] = None,
         logging_level: int = logging.WARN,
-        num_nodes: Optional[int] = None,
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         loss_scale: float = 0,

@@ -273,7 +272,6 @@ def __init__(
 
         super().__init__(
             parallel_devices=parallel_devices,
-            num_nodes=num_nodes,
             cluster_environment=cluster_environment,
         )
 

tests/deprecated_api/test_remove_1-6.py (22 changes: 1 addition & 21 deletions)

@@ -22,7 +22,7 @@
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.plugins import PrecisionPlugin
-from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
+from pytorch_lightning.plugins.training_type import DDPPlugin
 from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.model_summary import ModelSummary

@@ -40,26 +40,6 @@ def transfer_batch_to_device(self, batch, device):
     trainer.fit(OldModel())
 
 
-def test_v1_6_0_ddp_num_nodes():
-    with pytest.deprecated_call(match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"):
-        DDPPlugin(num_nodes=1)
-
-
-def test_v1_6_0_ddp_sync_batchnorm():
-    with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"):
-        DDPPlugin(sync_batchnorm=False)
-
-
-def test_v1_6_0_ddp_spawn_num_nodes():
-    with pytest.deprecated_call(match="Argument `num_nodes` in `DDPSpawnPlugin` is deprecated in v1.4"):
-        DDPSpawnPlugin(num_nodes=1)
-
-
-def test_v1_6_0_ddp_spawn_sync_batchnorm():
-    with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPSpawnPlugin` is deprecated in v1.4"):
-        DDPSpawnPlugin(sync_batchnorm=False)
-
-
 def test_v1_6_0_reload_dataloaders_every_epoch(tmpdir):
     model = BoringModel()
 
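
The deleted tests all followed the same deprecation-test pattern: construct the plugin with the deprecated argument and assert that a DeprecationWarning surfaces. A minimal runnable sketch of that pattern, with a hypothetical `old_api` standing in for the removed constructor arguments:

import warnings

import pytest


def old_api() -> None:
    # Stand-in for DDPPlugin(num_nodes=...); before this commit,
    # rank_zero_deprecation emitted a DeprecationWarning subclass here.
    warnings.warn("Argument `num_nodes` is deprecated in v1.4", DeprecationWarning)


def test_old_api_warns() -> None:
    # pytest.deprecated_call passes only if a (Pending)DeprecationWarning is raised.
    with pytest.deprecated_call(match="deprecated in v1.4"):
        old_api()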
