Remove deprecated sync_batchnorm and num_nodes attributes in DDP plugins #10357

Merged · 10 commits · Nov 5, 2021
CHANGELOG.md: 2 additions & 1 deletion

@@ -73,7 +73,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `has_prepared_data`, `has_setup_fit`, `has_setup_validate`, `has_setup_test`, `has_setup_predict`, `has_teardown_fit`, `has_teardown_validate`, `has_teardown_test` and `has_teardown_predict` datamodule lifecycle properties ([#10350](https://github.com/PyTorchLightning/pytorch-lightning/pull/10350))


-
+- Removed deprecated arguments `num_nodes` and `sync_batchnorm` from `DDPPlugin`, `DDPSpawnPlugin`, `DeepSpeedPlugin` ([#10357](https://github.com/PyTorchLightning/pytorch-lightning/pull/10357))


 ### Fixed

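For users upgrading across this change, the migration is to pass both settings to the `Trainer`, which now solely owns them. A minimal before/after sketch (`num_nodes` and `sync_batchnorm` are real `Trainer` arguments; treat the rest of the setup as illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# Before (deprecated in v1.4, removed by this PR):
#     DDPPlugin(num_nodes=2, sync_batchnorm=True)

# After: both settings live on the Trainer, which propagates them
# to the training-type plugin during setup.
trainer = Trainer(
    num_nodes=2,
    sync_batchnorm=True,
    plugins=DDPPlugin(),
)
```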
pytorch_lightning/plugins/training_type/ddp.py: 2 additions & 14 deletions

@@ -94,10 +94,8 @@ class DDPPlugin(ParallelPlugin):
     def __init__(
         self,
         parallel_devices: Optional[List[torch.device]] = None,
-        num_nodes: Optional[int] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
-        sync_batchnorm: Optional[bool] = None,
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
         ddp_comm_wrapper: Optional[callable] = None,
@@ -110,18 +108,8 @@ def __init__(
             checkpoint_io=checkpoint_io,
         )
         self.interactive_ddp_procs = []
-        if num_nodes is not None:
-            rank_zero_deprecation(
-                "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
-                " Notice that it will be overriden by the trainer setting."
-            )
-        self._num_nodes = num_nodes or 1
-        if sync_batchnorm is not None:
-            rank_zero_deprecation(
-                "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6."
-                " Notice that it will be overriden by the trainer setting."
-            )
-        self._sync_batchnorm = sync_batchnorm or False
+        self._num_nodes = 1
+        self._sync_batchnorm = False
         self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._ddp_kwargs = kwargs
         self._task_idx = None
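The deleted `if` blocks were the standard Lightning deprecation shim: accept the old argument, emit a rank-zero deprecation warning, and keep honoring the value until the removal release. A stripped-down sketch of that pattern (`SomePlugin` is a hypothetical stand-in, not the real class):

```python
from typing import Optional

from pytorch_lightning.utilities import rank_zero_deprecation

class SomePlugin:  # hypothetical stand-in for illustration
    def __init__(self, num_nodes: Optional[int] = None) -> None:
        if num_nodes is not None:
            # warn only on the rank-0 process to avoid duplicate messages
            rank_zero_deprecation(
                "Argument `num_nodes` in `SomePlugin` is deprecated"
                " and will be removed in a future release."
            )
        # keep honoring the old argument until it is actually removed
        self._num_nodes = num_nodes or 1
```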
pytorch_lightning/plugins/training_type/ddp_spawn.py: 2 additions & 15 deletions

@@ -35,7 +35,6 @@
 from pytorch_lightning.utilities import (
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
-    rank_zero_deprecation,
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
@@ -69,10 +68,8 @@ class DDPSpawnPlugin(ParallelPlugin):
     def __init__(
         self,
         parallel_devices: Optional[List[torch.device]] = None,
-        num_nodes: Optional[int] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
-        sync_batchnorm: Optional[bool] = None,
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
         ddp_comm_wrapper: Optional[callable] = None,
@@ -83,18 +80,8 @@ def __init__(
             cluster_environment=cluster_environment,
             checkpoint_io=checkpoint_io,
         )
-        if num_nodes is not None:
-            rank_zero_deprecation(
-                "Argument `num_nodes` in `DDPSpawnPlugin` is deprecated in v1.4, and will be removed in v1.6. "
-                "Notice that it will be overriden by the trainer setting."
-            )
-        self._num_nodes = num_nodes or 1
-        if sync_batchnorm is not None:
-            rank_zero_deprecation(
-                "Argument `sync_batchnorm` in `DDPSpawnPlugin` is deprecated in v1.4, and will be removed in v1.6. "
-                "Notice that it will be overriden by the trainer setting."
-            )
-        self._sync_batchnorm = sync_batchnorm or False
+        self._num_nodes = 1
+        self._sync_batchnorm = False
         self._ddp_kwargs = kwargs
         self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
         self.mp_queue = None
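Both attributes can be plain defaults here because the trainer-side accelerator connector overwrites them during setup; that is what the removed warnings meant by "it will be overriden by the trainer setting". Roughly, and only as a paraphrase of that connector logic rather than a verbatim excerpt:

```python
# Paraphrased sketch (assumed, not copied from the codebase): the Trainer
# pushes its own values onto the training-type plugin after construction,
# so anything set in the plugin constructor would be clobbered anyway.
def resolve_training_type_plugin(trainer, plugin):
    if hasattr(plugin, "num_nodes"):
        plugin.num_nodes = trainer.num_nodes  # the trainer value wins
    if hasattr(plugin, "sync_batchnorm"):
        plugin.sync_batchnorm = trainer.sync_batchnorm
    return plugin
```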
pytorch_lightning/plugins/training_type/deepspeed.py: 0 additions & 2 deletions

@@ -116,7 +116,6 @@ def __init__(
         logging_batch_size_per_gpu: Union[str, int] = "auto",
         config: Optional[Union[Path, str, dict]] = None,
         logging_level: int = logging.WARN,
-        num_nodes: Optional[int] = None,
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         loss_scale: float = 0,
@@ -273,7 +272,6 @@

         super().__init__(
             parallel_devices=parallel_devices,
-            num_nodes=num_nodes,
             cluster_environment=cluster_environment,
         )

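`DeepSpeedPlugin` only forwarded `num_nodes` to its parent plugin, so once the parent constructors above dropped the parameter, the forwarding had to go too; otherwise every construction would fail. A generic illustration of that failure mode:

```python
class Parent:
    def __init__(self, parallel_devices=None):  # `num_nodes` parameter removed
        self.parallel_devices = parallel_devices

class Child(Parent):
    def __init__(self, parallel_devices=None):
        # forwarding the removed kwarg, e.g. super().__init__(num_nodes=1),
        # would now raise: TypeError: unexpected keyword argument 'num_nodes'
        super().__init__(parallel_devices=parallel_devices)

Child()  # constructs cleanly once the stale kwarg is gone
```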
tests/deprecated_api/test_remove_1-6.py: 1 addition & 21 deletions

@@ -22,7 +22,7 @@
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.plugins import PrecisionPlugin
-from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
+from pytorch_lightning.plugins.training_type import DDPPlugin
 from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.model_summary import ModelSummary
@@ -40,26 +40,6 @@ def transfer_batch_to_device(self, batch, device):
     trainer.fit(OldModel())


-def test_v1_6_0_ddp_num_nodes():
-    with pytest.deprecated_call(match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"):
-        DDPPlugin(num_nodes=1)
-
-
-def test_v1_6_0_ddp_sync_batchnorm():
-    with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"):
-        DDPPlugin(sync_batchnorm=False)
-
-
-def test_v1_6_0_ddp_spawn_num_nodes():
-    with pytest.deprecated_call(match="Argument `num_nodes` in `DDPSpawnPlugin` is deprecated in v1.4"):
-        DDPSpawnPlugin(num_nodes=1)
-
-
-def test_v1_6_0_ddp_spawn_sync_batchnorm():
-    with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPSpawnPlugin` is deprecated in v1.4"):
-        DDPSpawnPlugin(sync_batchnorm=False)
-
-
 def test_v1_6_0_reload_dataloaders_every_epoch(tmpdir):
     model = BoringModel()

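With the arguments gone from the signatures, the four `pytest.deprecated_call` tests have nothing left to assert and are removed outright. If one wanted to pin down the new behavior, the failure mode is now a plain `TypeError` (an illustrative test, not part of this PR):

```python
import pytest

from pytorch_lightning.plugins.training_type import DDPPlugin

def test_removed_ddp_plugin_arguments():
    # unknown keyword arguments now fail at construction time
    with pytest.raises(TypeError):
        DDPPlugin(num_nodes=1)
    with pytest.raises(TypeError):
        DDPPlugin(sync_batchnorm=False)
```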