From e7a4a122241d2cfd05ce661f34cc41354fd5c570 Mon Sep 17 00:00:00 2001
From: puhuk
Date: Tue, 9 Nov 2021 02:36:37 +0900
Subject: [PATCH] Remove deprecated accelerator pass through functions in
 Accelerator (#10403)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Rohit Gupta
Co-authored-by: Adrian Wälchli
Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md                                  |   3 +
 pytorch_lightning/accelerators/accelerator.py | 334 +-----------
 tests/deprecated_api/test_remove_1-6.py       |  79 -----
 3 files changed, 4 insertions(+), 412 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 23e9667cddbe4e..d88a299704525b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -90,6 +90,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed deprecated method `master_params` from PrecisionPlugin ([#10372](https://github.com/PyTorchLightning/pytorch-lightning/pull/10372))
 
+- Removed deprecated passthrough methods and properties from `Accelerator` base class ([#10403](https://github.com/PyTorchLightning/pytorch-lightning/pull/10403))
+
+
 ### Fixed
 
 - Fixed `apply_to_collection(defaultdict)` ([#10316](https://github.com/PyTorchLightning/pytorch-lightning/issues/10316))

diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 6fb56ea802b1dc..eb1452ef1dfaea 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -13,14 +13,13 @@
 # limitations under the License.
 import contextlib
 from abc import abstractmethod
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Union
 
 import torch
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
@@ -59,21 +58,6 @@ def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: Trai
         self.lr_schedulers: List = []
         self.optimizer_frequencies: List = []
 
-    def connect(self, model: "pl.LightningModule") -> None:
-        """Transfers ownership of the model to this plugin.
-
-        See deprecation warning below.
-
-        .. deprecated:: v1.5
-            This method is deprecated in v1.5 and will be removed in v1.6.
-            Please call `training_type_plugin.connect` directly.
-        """
-        rank_zero_deprecation(
-            "`Accelerator.connect` is deprecated in v1.5 and will be removed in v1.6. "
-            "`connect` logic is implemented directly in the `TrainingTypePlugin` implementations."
-        )
-        self.training_type_plugin.connect(model)
-
     def setup_environment(self) -> None:
         """Setup any processes or distributed connections.
 
@@ -215,18 +199,6 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         with self.precision_plugin.train_step_context():
             return self.training_type_plugin.training_step(*step_kwargs.values())
 
-    def post_training_step(self) -> None:
-        """
-        .. deprecated:: v1.5
-            This method is deprecated in v1.5 and will be removed in v1.6.
-            Please call `training_type_plugin.post_training_step` directly.
-        """
-        rank_zero_deprecation(
-            "`Accelerator.post_training_step` is deprecated in v1.5 and will be removed in v1.6. "
-            "`post_training_step` logic is implemented directly in the `TrainingTypePlugin` implementations."
-        )
-        self.training_type_plugin.post_training_step()
-
     def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
         """The actual validation step.
 
@@ -251,54 +223,6 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         with self.precision_plugin.predict_step_context():
             return self.training_type_plugin.predict_step(*step_kwargs.values())
 
-    def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
-        """A hook to do something at the end of the training step.
-
-        .. deprecated:: v1.5
-            This method is deprecated in v1.5 and will be removed in v1.6.
-            Please call `training_type_plugin.training_step_end` directly.
-
-        Args:
-            output: the output of the training step
-        """
-        rank_zero_deprecation(
-            "`Accelerator.training_step_end` is deprecated in v1.5 and will be removed in v1.6. "
-            "`training_step_end` logic is implemented directly in the `TrainingTypePlugin` implementations."
-        )
-        return self.training_type_plugin.training_step_end(output)
-
-    def test_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]:
-        """A hook to do something at the end of the test step.
-
-        .. deprecated:: v1.5
-            This method is deprecated in v1.5 and will be removed in v1.6.
-            Please call `training_type_plugin.test_step_end` directly.
-
-        Args:
-            output: the output of the test step
-        """
-        rank_zero_deprecation(
-            "`Accelerator.test_step_end` is deprecated in v1.5 and will be removed in v1.6. "
-            "`test_step_end` logic is implemented directly in the `TrainingTypePlugin` implementations."
-        )
-        return self.training_type_plugin.test_step_end(output)
-
-    def validation_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]:
-        """A hook to do something at the end of the validation step.
-
-        .. deprecated:: v1.5
-            This method is deprecated in v1.5 and will be removed in v1.6.
-            Please call `training_type_plugin.validation_step_end` directly.
-
-        Args:
-            output: the output of the validation step
-        """
-        rank_zero_deprecation(
-            "`Accelerator.validation_step_end` is deprecated in v1.5 and will be removed in v1.6. "
-            "`validation_step_end` logic is implemented directly in the `TrainingTypePlugin` implementations."
-        )
-        return self.training_type_plugin.validation_step_end(output)
-
     def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
         """Forwards backward-calls to the precision plugin.
 
@@ -389,104 +313,6 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """
         return getattr(self.training_type_plugin, "optimizer_state", lambda x: x.state_dict())(optimizer)
 
-    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
-        """Returns state of model.
-
-        .. deprecated:: v1.5
-            This method is deprecated in v1.5 and will be removed in v1.6.
-            Please call `training_type_plugin.lightning_module_state_dict` directly.
-
-        Allows for syncing/collating model state from processes in custom plugins.
-        """
-        rank_zero_deprecation(
-            "`Accelerator.lightning_module_state_dict` is deprecated in v1.5 and will be removed in v1.6. "
-            "`lightning_module_state_dict` logic is implemented directly in the `TrainingTypePlugin` implementations."
-        )
-        return self.training_type_plugin.lightning_module_state_dict()
-
-    def barrier(self, name: Optional[str] = None) -> None:
-        """
-        .. deprecated:: v1.5
-            This method is deprecated in v1.5 and will be removed in v1.6.
-            Please call `training_type_plugin.barrier` directly.
- """ - rank_zero_deprecation( - "`Accelerator.barrier` is deprecated in v1.5 and will be removed in v1.6. " - "`Barrier` logic is implemented directly in the `TrainingTypePlugin` implementations." - ) - self.training_type_plugin.barrier(name=name) - - def broadcast(self, obj: object, src: int = 0) -> object: - """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if - needed. - - .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.6. - Please call `training_type_plugin.broadcast` directly. - - Args: - obj: Object to broadcast to all process, usually a tensor or collection of tensors. - src: The source rank of which the object will be broadcast from - """ - rank_zero_deprecation( - "`Accelerator.broadcast` is deprecated in v1.5 and will be removed in v1.6. " - "`Broadcast` logic is implemented directly in the `TrainingTypePlugin` implementations." - ) - return self.training_type_plugin.broadcast(obj, src) - - def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: - """Function to gather a tensor from several distributed processes. - - .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.6. - Please call `training_type_plugin.all_gather` directly. - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - rank_zero_deprecation( - "`Accelerator.all_gather` is deprecated in v1.5 and will be removed in v1.6. " - "`all_gather` logic is implemented directly in the `TrainingTypePlugin` implementations." - ) - return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) - - def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - """Wraps the dataloader if necessary. - - .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.6. - Please call `training_type_plugin.process_dataloader` directly. - - Args: - dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` - """ - rank_zero_deprecation( - "`Accelerator.process_dataloader` is deprecated in v1.5 and will be removed in v1.6. " - "`process_dataloader` logic is implemented directly in the `TrainingTypePlugin` implementations." - ) - return self.training_type_plugin.process_dataloader(dataloader) - - @property - def results(self) -> Any: - """The results of the last run will be cached within the training type plugin. - - .. deprecated:: v1.5 - This property is deprecated in v1.5 and will be removed in v1.6. - Please call `training_type_plugin.results` directly. - - In distributed training, we make sure to transfer the results to the appropriate main process. - """ - rank_zero_deprecation( - "`Accelerator.results` is deprecated in v1.5 and will be removed in v1.6. " - "Accesse results directly from the `TrainingTypePlugin`." - ) - return self.training_type_plugin.results - @contextlib.contextmanager def model_sharded_context(self) -> Generator[None, None, None]: """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to. 
@@ -517,43 +343,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         )
         self.training_type_plugin.save_checkpoint(checkpoint, filepath)
 
-    @property
-    def setup_optimizers_in_pre_dispatch(self) -> bool:
-        """Override to delay setting optimizers and schedulers until after dispatch. This is useful when the
-        `TrainingTypePlugin` requires operating on the wrapped accelerator model. However, this may break certain
-        precision plugins such as APEX which require optimizers to be set.
-
-        .. deprecated:: v1.5
-            This property is deprecated in v1.5 and will be removed in v1.6.
-            Please call `training_type_plugin.setup_optimizers_in_pre_dispatch` directly.
-
-        Returns:
-            If True, delay setup optimizers until `pre_dispatch`, else call within `setup`.
-        """
-        rank_zero_deprecation(
-            "`Accelerator.setup_optimizers_in_pre_dispatch` is deprecated in v1.5 and will be removed in v1.6. "
-            "Access `setup_optimizers_in_pre_dispatch` directly from the `TrainingTypePlugin`."
-        )
-        return self.training_type_plugin.setup_optimizers_in_pre_dispatch
-
-    @property
-    def restore_checkpoint_after_pre_dispatch(self) -> bool:
-        """Override to delay restoring from checkpoint until after pre-dispatch. This is useful when the plugin
-        requires all the setup hooks to run before loading checkpoint.
-
-        .. deprecated:: v1.5
-            This property is deprecated in v1.5 and will be removed in v1.6.
-            Please call `training_type_plugin.restore_checkpoint_after_pre_dispatch` directly.
-
-        Returns:
-            If True, restore checkpoint after pre_dispatch.
-        """
-        rank_zero_deprecation(
-            "`Accelerator.restore_checkpoint_after_pre_dispatch` is deprecated in v1.5 and will be removed in v1.6."
-            " Access `restore_checkpoint_after_pre_dispatch` directly from the `TrainingTypePlugin`."
-        )
-        return self.training_type_plugin.restore_checkpoint_after_pre_dispatch
-
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """Gets stats for a given device.
 
@@ -569,127 +358,6 @@ def on_train_start(self) -> None:
         """Called when train begins."""
         return self.training_type_plugin.on_train_start()
 
-    def on_validation_start(self) -> None:
-        """Called when validation begins.
-
-        See deprecation warning below.
-
-        .. deprecated:: v1.5
-            This method is deprecated in v1.5 and will be removed in v1.6.
-            Please call `training_type_plugin.on_validation_start` directly.
-        """
-        rank_zero_deprecation(
-            "`Accelerator.on_validation_start` is deprecated in v1.5 and will be removed in v1.6. "
-            "`on_validation_start` logic is implemented directly in the `TrainingTypePlugin` implementations."
-        )
-        return self.training_type_plugin.on_validation_start()
-
-    def on_test_start(self) -> None:
-        """Called when test begins.
-
-        See deprecation warning below.
-
-        .. deprecated:: v1.5
-            This method is deprecated in v1.5 and will be removed in v1.6.
-            Please call `training_type_plugin.on_test_start` directly.
-        """
-        rank_zero_deprecation(
-            "`Accelerator.on_test_start` is deprecated in v1.5 and will be removed in v1.6. "
-            "`on_test_start` logic is implemented directly in the `TrainingTypePlugin` implementations."
-        )
-        return self.training_type_plugin.on_test_start()
-
-    def on_predict_start(self) -> None:
-        """Called when predict begins.
-
-        See deprecation warning below.
-
-        .. deprecated:: v1.5
-            This method is deprecated in v1.5 and will be removed in v1.6.
-            Please call `training_type_plugin.on_predict_start` directly.
- """ - rank_zero_deprecation( - "`Accelerator.on_predict_start` is deprecated in v1.5 and will be removed in v1.6. " - "`on_predict_start` logic is implemented directly in the `TrainingTypePlugin` implementations." - ) - return self.training_type_plugin.on_predict_start() - - def on_validation_end(self) -> None: - """Called when validation ends. - - See deprecation warning below. - - .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.6. - Please call `training_type_plugin.on_validation_end` directly. - """ - rank_zero_deprecation( - "`Accelerator.on_validation_end` is deprecated in v1.5 and will be removed in v1.6. " - "`on_validation_end` logic is implemented directly in the `TrainingTypePlugin` implementations." - ) - return self.training_type_plugin.on_validation_end() - - def on_test_end(self) -> None: - """Called when test end. - - See deprecation warning below. - - .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.6. - Please call `training_type_plugin.on_test_end` directly. - """ - rank_zero_deprecation( - "`Accelerator.on_test_end` is deprecated in v1.5 and will be removed in v1.6. " - "`on_test_end` logic is implemented directly in the `TrainingTypePlugin` implementations." - ) - return self.training_type_plugin.on_test_end() - - def on_predict_end(self) -> None: - """Called when predict ends. - - See deprecation warning below. - - .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.6. - Please call `training_type_plugin.on_predict_end` directly. - """ - rank_zero_deprecation( - "`Accelerator.on_predict_end` is deprecated in v1.5 and will be removed in v1.6. " - "`on_predict_end` logic is implemented directly in the `TrainingTypePlugin` implementations." - ) - return self.training_type_plugin.on_predict_end() - - def on_train_end(self) -> None: - """Called when train ends. - - See deprecation warning below. - - .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.6. - Please call `training_type_plugin.on_train_end` directly. - """ - rank_zero_deprecation( - "`Accelerator.on_train_end` is deprecated in v1.5 and will be removed in v1.6. " - "`on_train_end` logic is implemented directly in the `TrainingTypePlugin` implementations." - ) - return self.training_type_plugin.on_train_end() - - # TODO: Update this in v1.7 (deprecation: #9816) - def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: - """Called in the training loop before anything happens for that batch. - - See deprecation warning below. - - .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.6. - Please call `training_type_plugin.on_train_batch_start` directly. - """ - rank_zero_deprecation( - "`Accelerator.on_train_batch_start` is deprecated in v1.5 and will be removed in v1.6. " - "`on_train_batch_start` logic is implemented directly in the `TrainingTypePlugin` implementations." 
-        )
-        return self.training_type_plugin.on_train_batch_start(batch, batch_idx)
-
     @staticmethod
     @abstractmethod
     def auto_device_count() -> int:
diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py
index 44b58908349acb..0ed1a774db1c46 100644
--- a/tests/deprecated_api/test_remove_1-6.py
+++ b/tests/deprecated_api/test_remove_1-6.py
@@ -15,7 +15,6 @@
 from unittest.mock import call, Mock
 
 import pytest
-import torch
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
@@ -179,81 +178,3 @@ def test_v1_6_0_deprecated_device_dtype_mixin_import():
     _soft_unimport_module("pytorch_lightning.utilities.device_dtype_mixin")
     with pytest.deprecated_call(match="will be removed in v1.6"):
         from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin  # noqa: F401
-
-
-def test_v1_6_0_deprecated_accelerator_pass_through_functions():
-    from pytorch_lightning.plugins.precision import PrecisionPlugin
-    from pytorch_lightning.plugins.training_type import SingleDevicePlugin
-
-    plugin = SingleDevicePlugin(torch.device("cpu"))
-    from pytorch_lightning.accelerators.accelerator import Accelerator
-
-    accelerator = Accelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.barrier()
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.broadcast(1)
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        tensor = torch.rand(2, 2, requires_grad=True)
-        accelerator.all_gather(tensor)
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        model = BoringModel()
-        accelerator.connect(model)
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.post_training_step()
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        tensor = torch.rand(2, 2, requires_grad=True)
-        accelerator.training_step_end(tensor)
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        tensor = torch.rand(2, 2, requires_grad=True)
-        accelerator.test_step_end(tensor)
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        tensor = torch.rand(2, 2, requires_grad=True)
-        accelerator.validation_step_end(tensor)
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.lightning_module_state_dict()
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        dl = model.train_dataloader()
-        accelerator.process_dataloader(dl)
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.results
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.setup_optimizers_in_pre_dispatch
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.restore_checkpoint_after_pre_dispatch
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.on_validation_start()
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.on_test_start()
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.on_predict_start()
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.on_validation_end()
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.on_test_end()
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.on_predict_end()
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.on_train_end()
-
-    with pytest.deprecated_call(match="will be removed in v1.6"):
-        accelerator.on_train_batch_start(batch=None, batch_idx=0)
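
Migration note for downstream callers: every removed method and property was a thin
passthrough, so the replacement is to invoke the same member on
`accelerator.training_type_plugin` directly, as the deprecation messages instruct.
A minimal sketch of the before/after, assuming the `SingleDevicePlugin` /
`PrecisionPlugin` pairing used in the deleted test (any `TrainingTypePlugin`
behaves the same way):

    import torch

    from pytorch_lightning.accelerators.accelerator import Accelerator
    from pytorch_lightning.plugins.precision import PrecisionPlugin
    from pytorch_lightning.plugins.training_type import SingleDevicePlugin

    # Build an Accelerator the same way the deleted test did.
    accelerator = Accelerator(
        training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
        precision_plugin=PrecisionPlugin(),
    )

    # Before (deprecated passthroughs removed by this patch):
    #     accelerator.barrier()
    #     accelerator.broadcast(1)
    #     accelerator.results
    # After: go through the training type plugin directly.
    accelerator.training_type_plugin.barrier()
    obj = accelerator.training_type_plugin.broadcast(1)
    results = accelerator.training_type_plugin.results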