diff --git a/python/ray/tune/experiment/trial.py b/python/ray/tune/experiment/trial.py
index 0834181fdfb8..5ce42f808af4 100644
--- a/python/ray/tune/experiment/trial.py
+++ b/python/ray/tune/experiment/trial.py
@@ -793,11 +793,11 @@ def get_error(self) -> Optional[TuneError]:
         return None
 
     def _handle_restore_error(self, exc: Exception):
+        # Only count a restore error as a trial failure once the number of
+        # restore failures reaches TUNE_RESTORE_RETRY_NUM; retry until then.
         if self.temporary_state.num_restore_failures >= int(
            os.environ.get("TUNE_RESTORE_RETRY_NUM", 0)
         ):
-            # Restore was unsuccessful, try again without checkpoint.
-            self.clear_checkpoint()
             self.run_metadata.num_failures += 1
         else:
             self.temporary_state.num_restore_failures += 1
@@ -883,12 +883,6 @@ def should_checkpoint(self):
     def has_checkpoint(self) -> bool:
         return self.checkpoint is not None
 
-    def clear_checkpoint(self):
-        if self.latest_checkpoint_result:
-            self.latest_checkpoint_result.checkpoint = None
-        self.temporary_state.restoring_from = None
-        self.run_metadata.invalidate_cache()
-
     def on_checkpoint(self, checkpoint_result: _TrainingResult):
         """Hook for handling checkpoints taken by the Trainable.
 
diff --git a/python/ray/tune/tests/test_tuner_restore.py b/python/ray/tune/tests/test_tuner_restore.py
index 158038f505a3..c2664eb217cd 100644
--- a/python/ray/tune/tests/test_tuner_restore.py
+++ b/python/ray/tune/tests/test_tuner_restore.py
@@ -537,7 +537,21 @@ def test_tuner_restore_latest_available_checkpoint(
 
 @pytest.mark.parametrize("retry_num", [0, 2])
 def test_restore_retry(ray_start_2_cpus, tmpdir, retry_num):
-    """Test retrying restore on a trial level by setting `TUNE_RESTORE_RETRY_NUM`."""
+    """
+    Test retrying restore at the trial level by setting `TUNE_RESTORE_RETRY_NUM`.
+
+    This test uses the following parameters:
+    - `retry_num`: Maximum number of retry attempts for restoring a trial,
+      assigned to the environment variable `TUNE_RESTORE_RETRY_NUM`. If
+      restoration still fails after `retry_num` attempts, the trial's total
+      failure count is incremented by 1.
+
+    - `retry_num_to_fail`: Number of restore attempts that fail. It is set to
+      2 here, so the first two restore attempts fail.
+
+    - `max_failures`: Maximum number of failures allowed during training. It is
+      set to 2 here, so training terminates after two total failures.
+    """
 
     class MockTrainable(Trainable):
         """A trainable that can generate one failure during training and
@@ -546,7 +560,7 @@ class MockTrainable(Trainable):
         def setup(self, config):
             self.idx = 0
             self.tag_file_path = config["tag_file_path"]
-            self.retry_num_to_fail = config.get("retry_num_to_fail", 2)
+            self.retry_num_to_fail = 2
             self._is_restored = False
 
         def step(self):
@@ -592,7 +606,7 @@ def load_checkpoint(self, checkpoint_dir):
             name="tryout_restore",
             stop={"training_iteration": 5},
             storage_path=str(tmpdir),
-            failure_config=FailureConfig(max_failures=1),
+            failure_config=FailureConfig(max_failures=2),
             checkpoint_config=CheckpointConfig(checkpoint_frequency=1),
         ),
         param_space={"tag_file_path": tag_file},
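
Note: for reviewers, here is a minimal sketch of the retry accounting the patched `_handle_restore_error` performs. The standalone function and its parameter names are hypothetical (the real logic lives on `Trial` and mutates `temporary_state.num_restore_failures` / `run_metadata.num_failures`); it only mirrors the branching shown in the diff above.

```python
import os


def handle_restore_error(num_restore_failures: int, num_failures: int):
    """Return updated (num_restore_failures, num_failures) after a failed restore.

    Hypothetical helper, not part of Ray Tune's API; it only mirrors the
    branching in the patched Trial._handle_restore_error.
    """
    retry_limit = int(os.environ.get("TUNE_RESTORE_RETRY_NUM", 0))
    if num_restore_failures >= retry_limit:
        # Retries exhausted: count a regular trial failure. The checkpoint is
        # kept (the old clear_checkpoint() fallback is removed), so a later
        # attempt still restores from the latest checkpoint.
        return num_restore_failures, num_failures + 1
    # Below the limit: only bump the restore-failure counter and retry.
    return num_restore_failures + 1, num_failures
```

With `TUNE_RESTORE_RETRY_NUM=2`, the first two restore errors only increment the restore counter; subsequent ones count against `FailureConfig(max_failures=...)`, which is what the updated test exercises.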