diff --git a/CHANGELOG.md b/CHANGELOG.md index c2c9ef7cb7f70..482525fd3542a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -158,8 +158,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `ModelIO.on_hpc_{save/load}` in favor of `CheckpointHooks.on_{save/load}_checkpoint` ([#10911](https://github.com/PyTorchLightning/pytorch-lightning/pull/10911)) -- Deprecated `Trainer.should_rank_save_checkpoint` Trainer property ([#11068](https://github.com/PyTorchLightning/pytorch-lightning/pull/11068)) +- Deprecated `Trainer.run_stage` in favor of `Trainer.{fit,validate,test,predict}` ([#11000](https://github.com/PyTorchLightning/pytorch-lightning/pull/11000)) + +- Deprecated `Trainer.should_rank_save_checkpoint` Trainer property ([#11068](https://github.com/PyTorchLightning/pytorch-lightning/pull/11068)) ### Removed diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fe19accb7dbc0..6aafb03791529 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1130,7 +1130,7 @@ def _run( setup accelerator || and strategy || LIGHTNING | || - {self.run_stage} || FLOW + {self._run_stage} || FLOW | || {self._run_train} || DIRECTION or {self._run_evaluate} || @@ -1165,7 +1165,7 @@ def _run( self.checkpoint_connector.resume_end() - results = self.run_stage() + results = self._run_stage() self._teardown() # ---------------------------- @@ -1238,7 +1238,14 @@ def _teardown(self): self.logger_connector.teardown() self.signal_connector.teardown() - def run_stage(self): + def run_stage(self) -> None: + rank_zero_deprecation( + "`Trainer.run_stage` is deprecated in v1.6 and will be removed in v1.8. Use" + " `Trainer.{fit,validate,test,predict}` instead." + ) + return self._run_stage() + + def _run_stage(self): self.training_type_plugin.barrier("run-stage") self.training_type_plugin.dispatch(self) self.__setup_profiler() diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 20c65ff81a1d6..5a14782d33036 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Test deprecated functionality which will be removed in v1.8.0.""" +from unittest.mock import Mock import pytest import torch @@ -108,6 +109,13 @@ def on_hpc_load(self): trainer.fit(load_model) +def test_v1_8_0_deprecated_run_stage(): + trainer = Trainer() + trainer._run_stage = Mock() + with pytest.deprecated_call(match="`Trainer.run_stage` is deprecated in v1.6 and will be removed in v1.8."): + trainer.run_stage() + + def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir): trainer = Trainer() with pytest.deprecated_call(