diff --git a/CHANGELOG.md b/CHANGELOG.md index 629b28e392792..ebf7cec509cff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -205,6 +205,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369)) +- Fixed early schedule reset logic in PyTorch profiler that was causing data leak ([#10837](https://github.com/PyTorchLightning/pytorch-lightning/pull/10837)) + + +- + + +- + + ## [1.5.4] - 2021-11-30 ### Fixed diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 969a038776f94..d9f6ea8c0b181 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -32,7 +32,7 @@ def __init__(self) -> None: self._results = ResultCollection(training=False) self._outputs: List[EPOCH_OUTPUT] = [] - self._max_batches: List[Union[int, float]] = [] + self._max_batches: List[int] = [] self._has_run: bool = False @property @@ -141,7 +141,7 @@ def teardown(self) -> None: self._results.cpu() self.epoch_loop.teardown() - def _get_max_batches(self) -> List[Union[int, float]]: + def _get_max_batches(self) -> List[int]: """Returns the max number of batches for each dataloader.""" if self.trainer.testing: max_batches = self.trainer.num_test_batches diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 903fe4b26e3f0..11a1fc9b5cdaa 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -53,10 +53,7 @@ def num_dataloaders(self) -> int: @property def max_batches(self) -> List[int]: """The max number of batches this loop will run for each dataloader.""" - max_batches = self.trainer.num_predict_batches - if isinstance(max_batches, int): - max_batches = [max_batches] * len(self.dataloaders) - return max_batches + return self.trainer.num_predict_batches @property def dataloaders(self) -> Sequence[DataLoader]: diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index 58f4a18895498..92bb9965dac4a 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -335,9 +335,24 @@ def _init_kineto(self, profiler_kwargs: Any) -> None: with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph self._profiler_kwargs["with_stack"] = with_stack + @property + def _total_steps(self) -> int: + trainer = self._lightning_module.trainer + if self._schedule.is_training: + return trainer.num_training_batches + if self._schedule._current_action == "validation_step": + return sum(trainer.num_val_batches) + sum(trainer.num_sanity_val_batches) + if self._schedule._current_action == "test_step": + return sum(trainer.num_test_batches) + if self._schedule._current_action == "predict_step": + return sum(trainer.num_predict_batches) + def _should_override_schedule(self) -> bool: - return (self._lightning_module is not None and self._lightning_module.trainer.limit_train_batches < 5) and ( - self._schedule is not None and self._schedule._schedule == self._default_schedule() + return ( + self._lightning_module is not None + and self._schedule is not None + and self._total_steps < 5 + and self._schedule._schedule == self._default_schedule() ) @staticmethod @@ -410,6 +425,9 @@ def stop(self, action_name: str) -> None: action_name in self.STEP_FUNCTIONS or action_name.startswith(self.STEP_FUNCTION_PREFIX) ): + if self._schedule is not None: + self._schedule.pre_step(action_name) + # the default schedule requires a minimum of 5 steps to properly work: `wait=1, warmup=1, active=3`. # otherwise, this will raise a `segmentation fault`. if self._should_override_schedule(): @@ -420,9 +438,6 @@ def stop(self, action_name: str) -> None: self._schedule = None self.profiler.schedule = torch.profiler.profiler._default_schedule_fn - if self._schedule is not None: - self._schedule.pre_step(action_name) - def on_trace_ready(profiler): if self.dirpath is not None: if self._export_to_chrome: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 455c2719b124a..0b1efc535f00c 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -50,13 +50,18 @@ class TrainerDataLoadingMixin(ABC): val_check_interval: float tpu_local_core_rank: int train_dataloader: DataLoader - num_training_batches: Union[int, float] - val_check_batch: float - val_dataloaders: Optional[List[DataLoader]] - num_val_batches: List[Union[int, float]] - test_dataloaders: Optional[List[DataLoader]] - num_test_batches: List[Union[int, float]] limit_train_batches: Union[int, float] + num_training_batches: int + val_check_batch: float + val_dataloaders: List[DataLoader] + limit_val_batches: Union[int, float] + num_val_batches: List[int] + test_dataloaders: List[DataLoader] + limit_test_batches: Union[int, float] + num_test_batches: List[int] + predict_dataloaders: List[DataLoader] + limit_predict_batches: Union[int, float] + num_predict_batches: List[int] log_every_n_steps: int overfit_batches: Union[int, float] distributed_sampler_kwargs: dict diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index 126a9a6d1dee6..3108c83d4da0a 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -26,7 +26,7 @@ from pytorch_lightning.loggers.base import LoggerCollection from pytorch_lightning.loggers.tensorboard import TensorBoardLogger from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler -from pytorch_lightning.profiler.pytorch import RegisterRecordFunction +from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE from tests.helpers import BoringModel, ManualOptimBoringModel @@ -524,3 +524,31 @@ def test_trainer_profiler_incorrect_str_arg(): match=r"When passing string value for the `profiler` parameter of `Trainer`, it can only be one of.*", ): Trainer(profiler="unknown_profiler") + + +@pytest.mark.skipif(not _KINETO_AVAILABLE, reason="Requires PyTorch Profiler Kineto") +@pytest.mark.parametrize( + ["trainer_config", "trainer_fn"], + [ + ({"limit_train_batches": 4, "limit_val_batches": 7}, "fit"), + ({"limit_train_batches": 7, "limit_val_batches": 4, "num_sanity_val_steps": 0}, "fit"), + ( + { + "limit_train_batches": 7, + "limit_val_batches": 2, + }, + "fit", + ), + ({"limit_val_batches": 4}, "validate"), + ({"limit_test_batches": 4}, "test"), + ({"limit_predict_batches": 4}, "predict"), + ], +) +def test_pytorch_profiler_raises_warning_for_limited_steps(tmpdir, trainer_config, trainer_fn): + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", max_epochs=1, **trainer_config) + warning_cache.clear() + with pytest.warns(UserWarning, match="not enough steps to properly record traces"): + getattr(trainer, trainer_fn)(model) + assert trainer.profiler._schedule is None + warning_cache.clear()