Skip to content

Commit

Permalink
Mark Trainer.run_stage as protected (#11000)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <[email protected]>

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
4 people authored Dec 17, 2021
1 parent c66cd12 commit 210ff84
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `ModelIO.on_hpc_{save/load}` in favor of `CheckpointHooks.on_{save/load}_checkpoint` ([#10911](https://github.com/PyTorchLightning/pytorch-lightning/pull/10911))


- Deprecated `Trainer.should_rank_save_checkpoint` Trainer property ([#11068](https://github.com/PyTorchLightning/pytorch-lightning/pull/11068))
- Deprecated `Trainer.run_stage` in favor of `Trainer.{fit,validate,test,predict}` ([#11000](https://github.com/PyTorchLightning/pytorch-lightning/pull/11000))


- Deprecated `Trainer.should_rank_save_checkpoint` Trainer property ([#11068](https://github.com/PyTorchLightning/pytorch-lightning/pull/11068))

### Removed

Expand Down
13 changes: 10 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ def _run(
setup accelerator ||
and strategy || LIGHTNING
| ||
{self.run_stage} || FLOW
{self._run_stage} || FLOW
| ||
{self._run_train} || DIRECTION
or {self._run_evaluate} ||
Expand Down Expand Up @@ -1165,7 +1165,7 @@ def _run(

self.checkpoint_connector.resume_end()

results = self.run_stage()
results = self._run_stage()
self._teardown()

# ----------------------------
Expand Down Expand Up @@ -1238,7 +1238,14 @@ def _teardown(self):
self.logger_connector.teardown()
self.signal_connector.teardown()

def run_stage(self):
def run_stage(self) -> None:
rank_zero_deprecation(
"`Trainer.run_stage` is deprecated in v1.6 and will be removed in v1.8. Use"
" `Trainer.{fit,validate,test,predict}` instead."
)
return self._run_stage()

def _run_stage(self):
self.training_type_plugin.barrier("run-stage")
self.training_type_plugin.dispatch(self)
self.__setup_profiler()
Expand Down
8 changes: 8 additions & 0 deletions tests/deprecated_api/test_remove_1-8.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.8.0."""
from unittest.mock import Mock

import pytest
import torch
Expand Down Expand Up @@ -108,6 +109,13 @@ def on_hpc_load(self):
trainer.fit(load_model)


def test_v1_8_0_deprecated_run_stage():
trainer = Trainer()
trainer._run_stage = Mock()
with pytest.deprecated_call(match="`Trainer.run_stage` is deprecated in v1.6 and will be removed in v1.8."):
trainer.run_stage()


def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir):
trainer = Trainer()
with pytest.deprecated_call(
Expand Down

0 comments on commit 210ff84

Please sign in to comment.